Go 中的泛型,日常如何使用
泛型从 go 的 1.18 开始支持
什么是泛型编程
在泛型出现之前,如果需要计算两数之和,可能会这样写:
func Add(a, b int) int { returb a + b }
这个很简单,但是只能两个参数都是 int 类型的时候才能调用
如果想要计算两个浮点数的和,就需要再定义一个函数
func AddFloat32(a, b float32) float32 { return a + b }
如果需要计算两个字符串的和,就需要再定义一个函数,这样太过麻烦
在一个函数中,行参只是类似占位符的东西,只有调用函数传入实参之后才有具体的值
如果把行参、实参的概念推广一下,给变量的类型也引入类似行参实参的概念,那就可以接收各种类型的值
函数就类似于:
func Add(a, b T) T { return a + b }
在这段代码中,T 可以称为 类型行参(type parameter),它不是具体的类型,在定义函数时参数的类型不确定,在传入时才确定,传入参数的具体类型称为 类型实参(type argument)
这样通过 类型行参 和 类型实参 进行编码的方式就称为 泛型编程
Go 的泛型
除了上面提到的类型行参和类型实参,Go 还引入了其他概念:
-
类型形参 (Type parameter)
-
类型实参(Type argument)
-
类型形参列表( Type parameter list)
-
类型约束(Type constraint)
-
实例化(Instantiations)
-
泛型类型(Generic type)
-
泛型接收器(Generic receiver)
-
泛型函数(Generic function)
一个一个说
类型行参、类型实参、类型约束和泛型类型
假如现在要定义一个切片,可以容纳 int、float32、string 等多种类型,不使用泛型常规做法是为每种类型各自定义一个切片
使用泛型,就可以这样定义:
type AllTypeSlice[T int | float32 | string] []T
在这行代码中:
-
T 就是上面介绍的 类型参数,在定义 slice 时 T 的类型不确定,类似于一个占位符
-
int|float|string
称为 类型约束,中间的|
就是告诉编译器只能接收这几种类型中的一种 -
中括号[] 中的
T int|float32|float64
这一整串定义了所有的类型实参(这里只有 T 这一个类型行参),称为 类型行参列表
泛型类型不能拿来使用,必须传入类型实参,能确定类型后才能使用,这个过程称为 实例化
func main() { type AllTypeSlice[T int | float32 | string] []T intSlice := AllTypeSlice[int]{} fmt.Printf("%T\n", intSlice) // main.AllTypeSlice[int] floatSlice := AllTypeSlice[float32]{} fmt.Printf("%T\n", floatSlice) // main.AllTypeSlice[float32] }
类型行参的个数可以有多个,比如:
type AllTypeMap[KEY int | string, VALUE float32 | float64] map[KEY]VALUE
这样就定义了一个泛型 map,key 和 value 都可以是多种
func main() { type AllTypeMap[KEY int | string, VALUE float32 | float64] map[KEY]VALUE var a AllTypeMap[string, float64] = map[string]float64{ "jack_score": 9.6, "bob_score": 8.4, } fmt.Printf("%T\n", a) // main.AllTypeMap[string,float64] }
其他的泛型类型
所有类型定义都可以使用类型行参,包括结构体、接口、通道
// 泛型结构体 type AllTypeStruct[T int | string] struct { Name string Data T } // 泛型接口 type PrintData[T int | float32] interface { Print(data T) } // 泛型通道 type AllTypeChan[T int | string] chan T
类型行参的嵌套
类型行参是可以嵌套使用的,比如:
type AllTypeStruct[T int | float32, S []T] struct { Data S MaxValue T MinValue T }
使用:
s1 := AllTypeStruct[int, []int]{} s2 := AllTypeStruct[int, []float32]{} // 报错 类型需要一致
即使 T 可以从 int 和 float32 中选,但是传入时类型已经确定了,所以 T 的类型和 []T 中 T 的类型要保持一致
泛型的使用
泛型一般用来实现一些不需要关注具体类型就可以使用的通用方法
给出一个比较常用的场景:Gorm 中使用泛型写一些基础方法,不同类型调用这些方法,会使用对应的 model,通过相同的逻辑查不同的表,而不需要再分别写各自的方法
常见的一些通用方法:
func IsErrRecordNotFound(err error) bool { return errors.Is(err, gorm.ErrRecordNotFound) } type baseRepo[T any] struct { db *gorm.DB } func NewBaseRepo[T any](db *gorm.DB) baseRepo[T] { return baseRepo[T]{db: db} } func (r *baseRepo[T]) GetDB() *gorm.DB { return r.db } // GetByID 通过 id 查单条记录 func (r *baseRepo[T]) GetByID(id int, preloads ...string) (*T, error) { var result T db := r.db for _, preload := range preloads { db = db.Preload(preload) } if err := db.First(&result, id).Error; err != nil { return nil, err } return &result, nil } // GetByIds 通过 ids 用 IN 查结果集 func (r *baseRepo[T]) GetByIds(ids []int, preloads ...string) (list []*T, err error) { db := r.db for _, preload := range preloads { db = db.Preload(preload) } err = db.Unscoped().Where("id IN ?", ids).Find(&list).Error return } // GetFirst 根据条件查第一条记录 传入对应类型结构体 func (r *baseRepo[T]) GetFirst(cond T, preloads ...string) (*T, error) { var result T db := r.db for _, preload := range preloads { db = db.Preload(preload) } if err := db.First(&result, cond).Error; err != nil { if IsErrRecordNotFound(err) { return nil, nil } return nil, err } return &result, nil } // GetList 通过条件查所有记录 func (r *baseRepo[T]) GetList(cond T, preloads ...string) ([]*T, error) { var list []*T db := r.db for _, preload := range preloads { db = db.Preload(preload) } if err := db.Find(&list, cond).Error; err != nil { return nil, err } return list, nil } // GetListWithOrder 根据条件查询所有记录并排序 func (r *baseRepo[T]) GetListWithOrder(cond T, order string, preloads ...string) ([]*T, error) { var list []*T db := r.db for _, preload := range preloads { db = db.Preload(preload) } if err := db.Order(order).Find(&list, cond).Error; err != nil { return nil, err } return list, nil } // GetPage 分页查询记录 func (r *baseRepo[T]) GetPage(cond *T, order string, pageNo, pageSize int, preloads ...string) ([]*T, error) { db := r.db for _, preload := range preloads { db = db.Preload(preload) } offset := (pageNo - 1) * pageSize limit := pageSize db = db.Order(order).Offset(offset).Limit(limit) if cond != nil { db = db.Where(cond) } var list []*T if err := db.Find(&list).Error; err != nil { return nil, err } return list, nil } // LikeWithOrder 模糊查询并排序 func (r *baseRepo[T]) LikeWithOrder(columns []string, keyword, order string, preloads ...string) ([]*T, error) { var list []*T db := r.db for _, preload := range preloads { db = db.Preload(preload) } like := "%" + keyword + "%" for _, column := range columns { query := fmt.Sprintf("%s like ?", column) db = db.Or(query, like) } if err := db.Order(order).Find(&list).Error; err != nil { return nil, err } return list, nil } // GetAll 查询所有记录 func (r *baseRepo[T]) GetAll(preloads ...string) ([]*T, error) { var list []*T db := r.db for _, preload := range preloads { db = db.Preload(preload) } if err := db.Find(&list).Error; err != nil { return nil, err } return list, nil } // GetIds 通过条件查 ids func (r *baseRepo[T]) GetIds(cond T) ([]int, error) { var ids []int model := new(T) if err := r.db.Model(model).Where(cond).Pluck("id", &ids).Error; err != nil { return nil, err } return ids, nil } // UpdateById 更新单个属性 func (r *baseRepo[T]) UpdateById(id int, column string, value any) error { m := new(T) return r.db.Model(m).Where("id = ?", id).Update(column, value).Error } // UpdatesById 更新多个属性值,可以传 map 或者结构体 func (r *baseRepo[T]) UpdatesById(id int, updateInfo interface{}) error { m := new(T) return r.db.Model(m).Where("id = ?", id).Updates(updateInfo).Error } // DeleteByID 删除一条记录 func (r *baseRepo[T]) DeleteByID(id int, force ...bool) error { m := new(T) session := r.db.Model(m) if len(force) > 0 && force[0] == true { session = session.Unscoped() } return session.Delete(m, id).Error } // DeleteBatch 删除多条记录 func (r *baseRepo[T]) DeleteBatch(ids []int, force ...bool) error { m := new(T) session := r.db.Where("id IN ?", ids) if len(force) > 0 && force[0] == true { session = session.Unscoped() } return session.Delete(m).Error } // DeleteWith 通过条件删除记录 func (r *baseRepo[T]) DeleteWith(cond T, force ...bool) error { m := new(T) session := r.db.Where(cond) if len(force) > 0 && force[0] == true { session = session.Unscoped() } return session.Delete(m).Error } // Count 查询记录条目数 func (r *baseRepo[T]) Count(cond T) (int, error) { var num int64 m := new(T) err := r.db.Model(m).Where(cond).Count(&num).Error return int(num), err }
这些方法都是一些逻辑比较简单的通用方法,在使用时,只需要通过结构体嵌套的方式注入 baseRepo
,在查询时就可以直接使用,很方便
结语
对于泛型,我认为只需要掌握基础的使用方法,一些细节和高级用法需要用到时再去查询比较合适,重点在于日常使用
参考(真的很详细,推荐阅读):后端 - Go 1.18 泛型全面讲解:一篇讲清泛型的全部 - 个人文章 - SegmentFault 思否