GORM Playground Link

https://github.com/go-gorm/playground/pull/619

Description

我对 gorm 做了一层符合我们业务场景的封装,手动使用 relfect 构造了 Find 所需的 Dest ,可是当使用 Preload 时,Preload功能失效了。 整体代码大致如下。

type ScanRsp struct {
    Items []any    `json:"items"`
    Page  null.Int `json:"page"`
    Size  null.Int `json:"size"`
    Total null.Int `json:"total"`
    req   *ScanReq
}

type ScanReq struct {
    Page      null.Int `json:"page" validate:"required,min=1"`
    Size      null.Int `json:"size" validate:"required,min=1,max=10000"`
    Wheres    []*Where `json:"wheres" validate:"omitempty,dive"`
    Orders    []*Order `json:"orders" validate:"omitempty,dive"`
    Fields    []string `json:"fields"`
    Model     any      `json:"model"`
    SelectWay string   `json:"select_way" validate:"oneof=scan find first last pluck,default=find"`
    SkipCount bool     `json:"skip_count"`
}

// Do 更具请求条件执行查询
// opts 为其他选项,为查询添加一些自定义选择。
// 如需要连表加载,请使用 Joins 函数:
//
//  db.Joins("User")
//
// 如果需要指定查询表,请在 ScanReq 中指定 Model
// Do 函数中的 model  参数具体作用为为返回值指定类型
//
// example:
//
//  rsp, _ := scan.Do(&User{})
//  rsp.Items 中的数据类型为 *User
func (req *ScanReq) Do(model any, opts ...func(tx *gorm.DB) (*gorm.DB, error)) (*ScanRsp, error) {
    db := DB.GetDB()
    if req.Model != nil {
        db = db.Model(req.Model)
    } else {
        req.Model = model
        db = db.Model(model)
    }

    if req.Page.Valid { // 分页
        db = db.Offset(int(req.Page.Int64-1) * int(req.Size.Int64))
    }

    if req.Size.Valid { // 分页
        db = db.Limit(int(req.Size.Int64))
    }

    if len(req.Fields) > 0 { // 指定字段
        db = db.Select(req.Fields)
    }

    for _, where := range req.Wheres { // where 条件
        if where == nil {
            continue
        }
        if m, ok := req.Model.(TableNameInterface); ok {
            db = where.Model(m).Where(db)
        } else {
            db = where.Where(db)
        }
    }

    for _, order := range req.Orders { // order 条件
        if order == nil {
            continue
        }
        db = order.Order(db)
    }

    for _, opt := range opts { // 其他选项
        var err error
        if db, err = opt(db); err != nil {
            return nil, err
        }
    }

    items := NewModelSlice(model).Interface()
    one := NewEmptyModel(model).Interface()
    var err error
    switch strings.ToLower(req.SelectWay) {
    case "scan":
        err = db.Scan(&items).Error
    case "first":
        err = db.First(&one).Error
        if err == nil {
            items = []any{one}
        }
    case "last":
        err = db.Last(&one).Error
        if err == nil {
            items = []any{one}
        }
    case "pluck":
        if len(req.Fields) == 0 ||
            len(req.Fields) > 1 ||
            req.Fields[0] == "" ||
            req.Fields[0] == "*" {
            return nil, errors.New("invalid fields")
        }
        field := req.Fields[0]
        if err := db.Pluck(field, &items).Error; err != nil {
            return nil, err
        }
    case "find":
        fallthrough
    default:
        err = db.Find(&items).Error
    }

    if err != nil {
        if !errors.Is(err, gorm.ErrRecordNotFound) {
            return nil, errors.Wrap(err, "do err")
        }
    }

    var (
        count      int64
        countValid bool
    )
    if !req.SkipCount {
        if err := db.
            Limit(-1).  // cancel limit
            Offset(-1). // cancel offset
            Count(&count).Error; err != nil {
            return nil, err
        }
        countValid = true
    }

    return &ScanRsp{
        Items: AnyToAnySlice(items),
        Page:  null.IntFrom(req.Page.Int64),
        Size:  null.IntFrom(req.Size.Int64),
        Total: null.NewInt(count, countValid),
        req:   req,
    }, nil
}