Bootstrap

gorm源码阅读之schema

基于 https://github.com/go-gorm/gorm v1.21.x

数据映射

我们先来看看不用gorm、直接使用MySQL驱动查询一个user的代码:

// User is the example model used throughout this article; each value
// corresponds to one row of the users table queried below.
// The json tags only affect encoding/json output. The field is spelled Id
// (not ID) so existing references such as user.Id keep compiling.
type User struct {
	Id   int    `json:"id"`
	Name string `json:"name"`
	Age  int    `json:"age"`
}
func queryUser(db *sql.DB){
        fmt.Println("query times:",i)
        user := new(User)
        row := db.QueryRow("select * from users where id=?", 1)
        //row.scan中的字段必须是按照数据库存入字段的顺序,否则报错
        if err := row.Scan(&user.Id,&user.Name,&user.Age); err != nil{
            fmt.Printf("scan failed, err:%v",err)
            return
        }
        fmt.Println(*user)
    }
}

再来看看gorm是怎么做的:

// queryUser loads a user with gorm's First. gorm maps the result columns onto
// User's fields itself, so no column list or Scan destinations are needed here.
// Errors from the query are reported instead of being silently dropped.
func queryUser(db *gorm.DB) {
	var user User
	if err := db.First(&user).Error; err != nil {
		fmt.Printf("query failed, err:%v\n", err)
		return
	}
	fmt.Printf("user:%+v\n", user)
}

gorm帮我们完成了数据库字段与struct字段之间的映射,这也是一个ORM的关键所在。

gorm实际上是通过反射来实现这一点的。在gorm源码里,Schema是数据映射这一块的核心。

Schema实际上保存的就是目标对象(也就是user)的数据结构:

// DB is the handle that gorm's chainable API (First, Model, ...) operates on.
// Excerpt of gorm v1.21's DB struct.
type DB struct {
  // NOTE(review): fields read later as db.cacheStore, db.NamingStrategy and
  // db.DryRun are presumably promoted from the embedded *Config — confirm upstream.
  *Config
  // Error holds errors recorded via AddError; callbacks.Query only runs when it is nil.
  Error        error
  // RowsAffected is reset to 0 and incremented per row by Scan.
  RowsAffected int64
  // Statement carries per-query state; Statement.Schema is the parsed model.
  Statement    *Statement
  clone        int
}
// Statement holds per-query state (excerpt). db.Statement.Schema is the
// Schema object describing the target model. It is not set by Open and is
// filled in later by Statement.Parse.
type Statement struct {
  // ... (other fields elided)
  Schema               *schema.Schema
  // ... (other fields elided)
}

初始化Schema

初始化db的时候,虽然初始化了db.Statement,但并没有初始化其中的Schema:

// Open initializes a *DB (excerpt). It creates db.Statement but does NOT set
// Statement.Schema; that stays nil until Statement.Parse assigns it (see
// processor.Execute).
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
  // ... (elided). Note that Statement.Schema is left unset in the literal below.
  db.Statement = &Statement{
    DB:       db,
    ConnPool: db.ConnPool,
    Context:  context.Background(),
    Clauses:  map[string]clause.Clause{},
  }
  return
}

// Entry point for tracing where Schema gets initialized: the db.First(&user)
// call below is the one followed down the call stack in this article.
{
	var user User
	err := db.First(&user).Error
	if err != nil {
		fmt.Printf("err:%v\n", err)
	}
	fmt.Printf("user:%+v\n", user)
}

// First (excerpt) records dest as Statement.Dest and then runs the query
// callback chain via tx.callbacks.Query().Execute(tx).
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
  // ... (elided; tx, the named result, must be set up from db here since it
  // is dereferenced below)
  // Statement.Dest is where the result is finally stored, i.e. &user.
  tx.Statement.Dest = dest
  tx.callbacks.Query().Execute(tx)
  return
}

// Execute (excerpt) is where Statement.Schema gets initialized. It resolves
// the model, parses it into a Schema, then runs the registered callbacks.
func (p *processor) Execute(db *DB) {
  // ... (elided; stmt is set up here)
  // There are two ways to tell gorm the target object's structure:
  //   1. db.Model(&user).Update("name", "hello") — pass a Model explicitly;
  //   2. db.First(&user) — inside First, &user becomes Dest.
  // Whichever of Model/Dest is missing is filled in from the other.
  if stmt.Model == nil {
    stmt.Model = stmt.Dest
  } else if stmt.Dest == nil {
    stmt.Dest = stmt.Model
  }

  if stmt.Model != nil {
    // stmt.Parse reads the Model's structure via reflection.
    // A parse error is recorded unless it is ErrUnsupportedDataType and the
    // caller already provided a table name or raw SQL.
    if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.SQL.Len() == 0)) {
      if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" {
        db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
      } else {
        db.AddError(err)
      }
    }
  }
  // ... (elided)
  // Run the registered callbacks — this is where the actual query happens.
  for _, f := range p.fns {
    f(db)
  }
  // ... (elided)
}

// Parse parses value (here &user) into stmt.Schema, using the DB's cache
// store and naming strategy. stmt.Schema is overwritten even when parsing
// fails.
// If no table was set explicitly, stmt.Table is taken from the schema.
// A dotted name with exactly two parts (e.g. "db.users") is split:
// TableExpr gets the quoted full name and Table gets only the part after
// the dot.
func (stmt *Statement) Parse(value interface{}) (err error) {
  if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
    if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
      stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
      stmt.Table = tables[1]
      return
    }

    stmt.Table = stmt.Schema.Table
  }
  return err
}

要理解Parse里面的代码,需要先理解下面这段关于反射的测试代码。

也可以参考Go标准库reflect包的官方文档(https://pkg.go.dev/reflect)。

  // Reflection warm-up for reading schema.Parse: reflect.ValueOf on a struct
  // value vs. on a pointer to it.
  // NOTE(review): Blog is not defined in this article; from the printed zero
  // value {0  0} it presumably has int, string, int fields.
  var blog Blog
  structValue := reflect.ValueOf(blog)
  fmt.Printf("structValue value:%v\n", structValue)              // structValue value:{0  0}
  fmt.Printf("structValue type:%v\n", structValue.Type())        // structValue type:main.Blog
  fmt.Printf("structValue kind:%v\n", structValue.Type().Kind()) // structValue kind:struct
  fmt.Printf("structValue CanSet:%t\n", structValue.CanSet())    // structValue CanSet:false

  // For a pointer, Type().Elem() yields the struct type. schema.Parse does the
  // same unwrapping. Only Elem() of the pointer Value is settable, which is
  // why a pointer such as &user must be passed.
  structPtrValue := reflect.ValueOf(&blog)
  fmt.Printf("structPtrValue value:%v\n", structPtrValue)                       // structPtrValue value:&{0  0}
  fmt.Printf("structPtrValue type:%v\n", structPtrValue.Type())                 // structPtrValue type:*main.Blog
  fmt.Printf("structPtrValue type Elem:%v\n", structPtrValue.Type().Elem())     // structPtrValue type Elem:main.Blog
  fmt.Printf("structPtrValue kind:%v\n", structPtrValue.Type().Kind())          // structPtrValue kind:ptr
  fmt.Printf("structPtrValue CanSet:%t\n", structPtrValue.CanSet())             // structPtrValue CanSet:false
  fmt.Printf("structPtrValue Elem CanSet:%t\n", structPtrValue.Elem().CanSet()) // structPtrValue Elem CanSet:true

现在来看schema.Parse()。它利用反射将user里的每一个属性解析成一个Field,记录属性名称、属性索引等信息。

// Parse (excerpt) uses reflection to build the Schema for dest (here &user).
// Each exported field becomes a Field, recording its name, index and so on.
// Returns ErrUnsupportedDataType (wrapped) when dest is nil or does not
// resolve to a struct type.
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
  if dest == nil {
    return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
  }

  modelType := reflect.ValueOf(dest).Type()//*main.User
  // For &user the Kind is reflect.Ptr. Strip pointers, slices and arrays
  // (e.g. *User, []User, *[]User) down to the element type.
  for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
    modelType = modelType.Elem()//main.User
  }

  // What dest resolves to must be a struct.
  if modelType.Kind() != reflect.Struct {
    if modelType.PkgPath() == "" {
      return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
    }
    return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
  }

  // If a schema for this type is already cached, return it. The receive
  // blocks until the cached schema's initialized channel is closed (or
  // receives), presumably by whoever is still building it.
  if v, ok := cacheStore.Load(modelType); ok {
    s := v.(*Schema)
    <-s.initialized
    return s, s.err
  }

  // reflect.New returns a pointer, so modelValue is a *User.
  modelValue := reflect.New(modelType)
  // modelType.Name() == "User"; the namer turns it into the default table name.
  tableName := namer.TableName(modelType.Name())

  // A custom TableName() (via Tabler) overrides the default name. Because
  // modelValue is a pointer, both value- and pointer-receiver methods count.
  if tabler, ok := modelValue.Interface().(Tabler); ok {
    tableName = tabler.TableName()
  }
  // An embeddedNamer forces its own table name (presumably used when parsing
  // embedded structs — confirm upstream).
  if en, ok := namer.(embeddedNamer); ok {
    tableName = en.Table
  }

  // Initialize the Schema object.
  schema := &Schema{
    Name:           modelType.Name(),
    ModelType:      modelType,
    Table:          tableName,
    FieldsByName:   map[string]*Field{},
    FieldsByDBName: map[string]*Field{},
    Relationships:  Relationships{Relations: map[string]*Relationship{}},
    cacheStore:     cacheStore,
    namer:          namer,
    initialized:    make(chan struct{}),
  }

  // On failure, log the error and evict this type from the cache.
  // NOTE(review): storing the schema into cacheStore and closing
  // schema.initialized are in the elided part of this excerpt — confirm upstream.
  defer func() {
    if schema.err != nil {
      logger.Default.Error(context.Background(), schema.err.Error())
      cacheStore.Delete(modelType)
    }
  }()

  // Walk every struct field of User.
  for i := 0; i < modelType.NumField(); i++ {
    // ast.IsExported: only exported fields (names starting with an upper-case letter).
    if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
      // Parse the StructField into a Field. An embedded struct's fields are
      // flattened into this schema's Fields.
      if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
        schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
      } else {
        schema.Fields = append(schema.Fields, field)
      }
    }
  }
  // ... (elided)
  return schema, schema.err
}

如何给user赋值

现在user的Schema已经解析完了,那么又是在哪里把数据库中的数据设置到user里面去的呢?

// Query is the callback (defined in callbacks/query.go) that actually runs
// the query. It builds the SELECT SQL, executes it on the statement's
// connection pool, and hands the rows to gorm.Scan to fill Statement.Dest.
func Query(db *gorm.DB) {
  if db.Error == nil {
    BuildQuerySQL(db)

    // With DryRun set the SQL is built but never sent to the database.
    if !db.DryRun && db.Error == nil {
      rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
      if err != nil {
        db.AddError(err)
        return
      }
      defer rows.Close()

      // initialized=false: Scan calls rows.Next() itself before reading.
      gorm.Scan(rows, db, false)
    }
  }
}

// Scan copies the rows returned by the database into Statement.Dest
// (excerpt). For a struct destination such as &user, only one row is read.
// initialized reports whether the caller already advanced rows with
// rows.Next(); Query passes false, so Scan calls rows.Next() itself.
// If no row was counted and RaiseErrorOnNotFound is set, ErrRecordNotFound
// is added.
func Scan(rows *sql.Rows, db *DB, initialized bool) {
  // NOTE(review): the error from rows.Columns() is discarded.
  columns, _ := rows.Columns()
  // One scan destination per result column.
  values := make([]interface{}, len(columns))
  db.RowsAffected = 0

  // Statement.Dest is what First stored via tx.Statement.Dest = dest.
  switch dest := db.Statement.Dest.(type) {
  case map[string]interface{}, *map[string]interface{}:
    // ... (map destinations; elided)
  default:
    // A struct destination like &user ends up here.
    Schema := db.Statement.Schema
    switch db.Statement.ReflectValue.Kind() {
    case reflect.Slice, reflect.Array:
      // ... (slice/array destinations; elided)
    case reflect.Struct, reflect.Ptr:
      // If Dest's type differs from the schema parsed from Model, parse a
      // schema for Dest instead (the parse error is ignored here).
      if db.Statement.ReflectValue.Type() != Schema.ModelType {
        Schema, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
      }

      if initialized || rows.Next() {
        // Pass 1: build a destination in values for each column.
        // reflect.New(reflect.PtrTo(T)) yields a **T, so a SQL NULL can be
        // scanned as a nil *T (this is checked with value.IsNil() below).
        for idx, column := range columns {
          if field := Schema.LookUpField(column); field != nil && field.Readable {
            values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
          } else if names := strings.Split(column, "__"); len(names) > 1 {
            // No field matches the column directly. A column named
            // "<Relation>__<column>" is routed to <column> in the schema of
            // relationship <Relation>. NOTE(review): presumably these are
            // columns aliased this way by Joins — confirm upstream.
            if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
              if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
                values[idx] = reflect.New(reflect.PtrTo(field.IndirectFieldType)).Interface()
                continue
              }
            }
            // Unknown relation or field: read into RawBytes and discard.
            values[idx] = &sql.RawBytes{}
          } else {
            // Column not mapped to any field: read into RawBytes and discard.
            values[idx] = &sql.RawBytes{}
          }
        }

        // Counted before rows.Scan, so a failed Scan still counts the row.
        db.RowsAffected++
        // Read the values the database returned for this row.
        db.AddError(rows.Scan(values...))

        // Pass 2: copy the scanned values into the destination's fields.
        for idx, column := range columns {
          if field := Schema.LookUpField(column); field != nil && field.Readable {
            // ReflectValue is set in processor.Execute as
            // reflect.ValueOf(stmt.Dest); assign the value to user's field.
            field.Set(db.Statement.ReflectValue, values[idx])
          } else if names := strings.Split(column, "__"); len(names) > 1 {
            if rel, ok := Schema.Relationships.Relations[names[0]]; ok {
              if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable {
                relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
                value := reflect.ValueOf(values[idx]).Elem()

                // If the relation field is still a nil pointer, skip NULL
                // values; allocate the related struct only for a non-NULL value.
                if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
                  if value.IsNil() {
                    continue
                  }
                  relValue.Set(reflect.New(relValue.Type().Elem()))
                }

                field.Set(relValue, values[idx])
              }
            }
          }
        }
      }
    }
  }

  if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound {
    db.AddError(ErrRecordNotFound)
  }
}