Bootstrap

gorm源码阅读之callback

基于 https://github.com/go-gorm/gorm v1.21.x

gorm执行sql语句都是通过注册callback函数实现的,最终的curd语句都是到callback这里才真正的得到执行,这个callback在gorm的怎么初始化到执行的,分析一下。

一、初始化callback

1、调用堆栈

main.go 创建db

//main.go 创建db 
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})

//gorm.go 这里Dialector是驱动 
config.Dialector.Initialize(db)

//mysql.go 这里我用的mysql的驱动,包是gorm.io/driver/mysql
func (dialector Dialector) Initialize(db *gorm.DB) (err error)

//注册db的callbacks
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
  //...
  queryCallback := db.Callback().Query()
  queryCallback.Register("gorm:query", Query)
  queryCallback.Register("gorm:preload", Preload)
  queryCallback.Register("gorm:after_query", AfterQuery)
  //...
}

2、callback结构

callbacks就是一个map,map里面是一个processor

//callbacks就是一个map,map里面是一个processor
type callbacks struct {
  processors map[string]*processor
}

//这里要区分db.callbacks和processor.callbacks,两个是不同的东西
type processor struct {
  db        *DB
  fns       []func(*DB)
  callbacks []*callback
}

//这个是processor里面存的callback
type callback struct {
  name      string
  before    string
  after     string
  remove    bool
  replace   bool
  match     func(*DB) bool
  handler   func(*DB)
  processor *processor
}

//初始化的时候,就是为curd等几个操作分别创建一个processor
func initializeCallbacks(db *DB) *callbacks {
  return &callbacks{
    processors: map[string]*processor{
      "create": {db: db},
      "query":  {db: db},
      "update": {db: db},
      "delete": {db: db},
      "row":    {db: db},
      "raw":    {db: db},
    },
  }
}

3、如何注册callback

//参考 <1、调用堆栈>,注册db默认的callbacks
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
  //...
    //先取出db.callbacks这个map里query的processor
  queryCallback := db.Callback().Query()
  queryCallback.Register("gorm:query", Query)
  queryCallback.Register("gorm:preload", Preload)
  queryCallback.Register("gorm:after_query", AfterQuery)
  //...
}

//往processor里注册回调函数fn
func (p *processor) Register(name string, fn func(*DB)) error {
    //这里是新建了一个临时的callback,然后将这个callback保存到processor里面
    //这里的callback和db.callbacks不是同一个东西哦
  return (&callback{processor: p}).Register(name, fn)
}

//将name和回调函数存在新建的临时callback里,同时将这个callback保存到processor里面
func (c *callback) Register(name string, fn func(*DB)) error {
  c.name = name
  c.handler = fn
  c.processor.callbacks = append(c.processor.callbacks, c)
  return c.processor.compile()
}

//将processor里面的callback进行优先级排序,同时将排序后的回调保存在p.fns里
func (p *processor) compile() (err error) {
  var callbacks []*callback
  for _, callback := range p.callbacks {
    if callback.match == nil || callback.match(p.db) {
      callbacks = append(callbacks, callback)
    }
  }
  p.callbacks = callbacks

    //对processor.callbacks进行排序
  if p.fns, err = sortCallbacks(p.callbacks); err != nil {
    p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
  }
  return
}

//如何自定义注册带顺序的回调
db.Callback().Create().Before("gorm:create").Register("update_created_at", updateCreated)
db.Callback().Create().After("gorm:create").Register("update_created_at", updateCreated)
db.Callback().Query().After("gorm:query").Register("my_plugin:after_query", afterQuery)
db.Callback().Delete().After("gorm:delete").Register("my_plugin:after_delete", afterDelete)
db.Callback().Update().Before("gorm:update").Register("my_plugin:before_update", beforeUpdate)

二、调用callback

1、调用堆栈

//以下面这个查询为例
var user User
if err := db.Preload("Clazz").Where("id = ?", 1).First(&user).Error; nil != err {
  fmt.Printf("err:%v\n", err)
}

//db.First()才开始真正的执行查询
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
  //...
    //这里是真正的执行查询
  tx.callbacks.Query().Execute(tx)
  //...
}

//tx.callbacks.Query()就是取出query对应的processor
func (cs *callbacks) Query() *processor {
  return cs.processors["query"]
}

//执行processor的Execute()
func (p *processor) Execute(db *DB) {
    //...
    //这里调用p.fns,也就是排过序的callbacks
  for _, f := range p.fns {
    f(db)
  }
}

2、真正执行Query地方

//上面的堆栈执行到p.fns,在一开始的时候,db就注册了到fns了,
//就是下面这个Query函数
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
  //...
    //先取出db.callbacks这个map里query的processor
  queryCallback := db.Callback().Query()
  queryCallback.Register("gorm:query", Query)
  queryCallback.Register("gorm:preload", Preload)
  queryCallback.Register("gorm:after_query", AfterQuery)
  //...
}

//callbacks/query.go里面定义了真正的查询的地方
func Query(db *gorm.DB) {
  if db.Error == nil {
    BuildQuerySQL(db)

    if !db.DryRun && db.Error == nil {
      rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
      if err != nil {
        db.AddError(err)
        return
      }
      defer rows.Close()

      gorm.Scan(rows, db, false)
    }
  }
}