gorm源码解析(四):事务,预编译
文章目录
- 前言
- 事务
- 自己控制事务
- 用 Transaction方法包装事务
- 预编译
- 事务结合预编译
- 总结
前言
前几篇文章介绍gorm的整体设计,增删改查的具体实现流程。本文将聚焦与事务和预编译部分
事务
自己控制事务
用gorm框架,可以自己控制事务的Begin,Commit和Rollback,如下所示:
// 开始事务
tx := db.Begin()
// 在事务中执行一些 db 操作(从这里开始,您应该使用 'tx' 而不是 'db')
tx.Create(...)
// ...
// 遇到错误时回滚事务
tx.Rollback()
// 否则,提交事务
tx.Commit()
下面看看每个api的源码实现
Beigin:
- 新建一个db实例,作为本次事务操作的会话
- 非预编译模式下:调sql.DB的BeginTx方法,让
tx.Statement.ConnPool
持有其返回的sql.Tx- 之后的增删改查操作,都用这个sql.Tx执行,会用同一个db连接(也就是调begin的那个连接)
预编译模式本文后面再介绍,这里只关注非预编译模式
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
var (
// clone statement
// 新建的tx.clone = 1
tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
opt *sql.TxOptions
err error
)
if len(opts) > 0 {
opt = opts[0]
}
switch beginner := tx.Statement.ConnPool.(type) {
// 非预编译模式下:
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
// 预编译模式下:
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
default:
err = ErrInvalidTransaction
}
if err != nil {
tx.AddError(err)
}
return tx
}
Commit和Rollback:调sql.Tx的Commit和Rollback方法,内部会调mysql驱动mysqlTx的Commit和Rollback方法,完成事务的提交和回滚操作
func (db *DB) Commit() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Commit())
} else {
db.AddError(ErrInvalidTransaction)
}
return db
}
func (db *DB) Rollback() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Rollback())
}
} else {
db.AddError(ErrInvalidTransaction)
}
return db
}
用 Transaction方法包装事务
然而,自己控制事务有以下问题:
- 很多重复代码,样板代码
- 自己控制事务生命周期容易出问题,例如可能忘记commit或rollback,或者说发生panic但没有recover,导致没有触发commit或rollback
于是gorm提供了Transaction
方法,帮我们调了Begin,事务执行成功后Commit,事务执行失败或发生panic时执行Rollback操作,也就是帮我们控制事务的生命周期,我们只用关注业务逻辑即可
使用方法为将业务逻辑func传进去,例如:
db.Transaction(func(tx *gorm.DB) error {
// 在事务中执行一些 db 操作(从这里开始,您应该使用 'tx' 而不是 'db')
if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil {
// 返回任何错误都会回滚事务
return err
}
if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil {
return err
}
// 返回 nil 提交事务
return nil
})
Transaction方法流程如下:
- db.Begin开启事务
- 执行fc,把事务tx传进去
- 如果fc执行成功,执行
tx.Commit
提交事务 - 如果fc执行出错或发生panic,调
tx.Rollback
回滚事务
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// ...
} else {
// 开启事务
tx := db.Begin(opts...)
if tx.Error != nil {
return tx.Error
}
defer func() {
// 如果发生panic,或者执行出错,回滚
if panicked || err != nil {
tx.Rollback()
}
}()
if err = fc(tx); err == nil {
panicked = false
return tx.Commit().Error
}
}
panicked = false
return
}
预编译
只要用database/sql配合mysql驱动,执行sql时一定走了预编译。要么是客户端预编译,要么是mysql服务端预编译
在gorm层面,PrepareStmt的目的是提高性能,而不是执行sql时从非预编译变成预编译
默认gorm层面的PrepareStmt为false,这里假设使用服务端预编译(连接mysql的dsn中interpolateParams为false),执行sql模板 + 参数
类型的操作时有如下流程:
- 往mysql发送sql模板,获得
stmtId
- 往msqyl发送stmtId + 参数,得到执行结果
- 往mysql发送释放stmt命令
如果使用gorm层面的PrepareStmt,会对sql.Stmt进行缓存,如果当前连接预编译过该Stmt,就能直接用
接下来看看gorm层面怎么处理预编译的
初始化db时(gorm.Open),如果config.PrepareStmt为true,使用预编译模式:
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
// ...
// 预编译模式
if config.PrepareStmt {
preparedStmt := NewPreparedStmtDB(db.ConnPool)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
db.ConnPool = preparedStmt
}
db.Statement = &Statement{
DB: db,
ConnPool: db.ConnPool,
Context: context.Background(),
Clauses: map[string]clause.Clause{},
}
// ...
return
}
ConnPool被替换成PreparedStmtDB,其结构如下:
type PreparedStmtDB struct {
// key: sql模板,value:Stmt
Stmts map[string]*Stmt
Mux *sync.RWMutex
// 内置的 ConnPool 字段通常为 database/sql 中的 *DB
ConnPool
}
初始化PreparedStmtDB,内部持有sql.DB
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
return &PreparedStmtDB{
// 持有sql.DB
ConnPool: connPool,
// sql到Stmt的映射
Stmts: make(map[string]*Stmt),
Mux: &sync.RWMutex{},
}
}
可以看出PreparedStmtDB拥有sql模板到Stmt的缓存,那么遇到相同sql时,如果之前已经预编译过,就能用该Stmt执行db操作
后续基于这个db执行任何操作时,分为两个步骤:
- 通过
PreparedStmtDB.prepare(...)
操作创建/复用 stmt,后续相同 sql 模板可以复用此 stmt - 通过
stmt.Query(...)/Exec(...)
执行 sql
例如在执行PreparedStmtDB.QueryContext时:
- 先调PreparedStmtDB.prepare看有没有可复用的sql.Stmt
- 再调sql.Stmt执行QueryContext操作
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
// 先获取stmt
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
// 再用stmt执行sql
rows, err = stmt.QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
delete(db.Stmts, query)
}
}
return rows, err
}
PreparedStmtDB.prepare:
- 尝试从缓存Stmts中,根据query模板找sql.Stmt,如果有就返回
- 否则调sql.DB,根据query模板生成一个sql.Stmt,加入缓存中
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
db.Mux.RLock()
// 以sql为模板,先查有没有可复用的stmt
// 如果stmt.Transaction为false,可以复用
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return *stmt, nil
}
db.Mux.RUnlock()
db.Mux.Lock()
// double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
// wait for other goroutines prepared
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return *stmt, nil
}
// ...
// 到这里没有可复用的模板
// 创建stmt实例,加到map中
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts[query] = &cacheStmt
db.Mux.Unlock()
// prepare completed
defer close(cacheStmt.prepared)
// 调sql.DB的prepareContext方法,创建sql.stmt
stmt, err := conn.PrepareContext(ctx, query)
if err != nil {
cacheStmt.prepareErr = err
db.Mux.Lock()
delete(db.Stmts, query)
db.Mux.Unlock()
return Stmt{}, err
}
db.Mux.Lock()
cacheStmt.Stmt = stmt
db.Mux.Unlock()
return cacheStmt, nil
}
这里是ORM层面对Stmt做了缓存
而在go标准库database/sql层面,stmt要能执行的前提是,在当前连接预编过
我们看sql.Stmt源码,注释上有这么一段话:
A Stmt is safe for concurrent use by multiple goroutines
翻译:Stmt能被多个g并发调用
When the Stmt needs to execute on a new underlying connection, it will prepare itself on the new connection automatically
翻译:当Stmt需要被新的连接使用时,需要在新连接上预编译
sql.Stmt有一个css
slice,存放预编译了该Stmt的连接
type Stmt struct {
// 持有driverStmt
cgds *driverStmt
// ...
// 预编译了该Stmt的连接
css []connStmt
}
type connStmt struct {
dc *driverConn
ds *driverStmt
}
基于Stmt执行Query, Exec
操作时,会检查当前使用的连接在不在Stmt.css
里面,如果在就能立即执行,否则需要先预编译该sql模板才能执行
这里以Stmt.ExecQuery为例,看看如何获取连接并执行
func (s *Stmt) ExecContext(ctx context.Context, args ...any) (Result, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var res Result
err := s.db.retry(func(strategy connReuseStrategy) error {
// 获取连接
dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
if err != nil {
return err
}
// 执行exec
res, err = resultFromStatement(ctx, dc.ci, ds, args...)
releaseConn(err)
return err
})
return res, err
}
重点在获取连接:
- 从连接池获取一个连接
- 检查该连接是否在Stmt.css里面,如果是,直接返回
- 否则需要先用该连接预编译sql模板
func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
// ...
// 从连接池获取一个连接
dc, err = s.db.conn(ctx, strategy)
if err != nil {
return nil, nil, nil, err
}
s.mu.Lock()
// 检查该连接是否在Stmt.css里面,如果是,直接返回
for _, v := range s.css {
if v.dc == dc {
s.mu.Unlock()
return dc, dc.releaseConn, v.ds, nil
}
}
s.mu.Unlock()
// 否则需要先用该连接预编译sql模板
withLock(dc, func() {
ds, err = s.prepareOnConnLocked(ctx, dc)
})
if err != nil {
dc.releaseConn(err)
return nil, nil, nil, err
}
return dc, dc.releaseConn, ds, nil
}
prepareOnConnLocked:预编译完成后,将连接加入stmt.css
func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
si, err := dc.prepareLocked(ctx, s.cg, s.query)
if err != nil {
return nil, err
}
cs := connStmt{dc, si}
s.mu.Lock()
// 预编译完成后,将当前连接加入stmt.css
s.css = append(s.css, cs)
s.mu.Unlock()
return cs.ds, nil
}
事务结合预编译
如果同时有事务和预编译,那么在执行Exec/Query时,稍微有点不一样
回到事务的Begin方法:
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
// ...
switch beginner := tx.Statement.ConnPool.(type) {
// ...
// 预编译模式下:
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
default:
err = ErrInvalidTransaction
}
if err != nil {
tx.AddError(err)
}
return tx
}
如果是预编译模式,进入PreparedStmtDB.BeginTx,返回PreparedStmtTX实例
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
if beginner, ok := db.ConnPool.(TxBeginner); ok {
tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
}
// ...
}
接下来看看基于PreparedStmtTX执行增删改查有何特别的地方
这里以PreparedStmtTX.ExecContext为例:
- 调
PreparedStmtDB.prepare
从缓存拿或新建一个Stmt - 调
sql.Tx.StmtContext
先处理步骤1返回的Stmt,再执行Exec,重点在这里
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
}
}
return result, err
}
sql.Tx.StmtContext方法
- 拿到和事务tx绑定的连接dc
- 看dc是否预编译过该stmt,如果没有执行预编译操作
- 将事务tx放到
Stmt.cg
字段,后面执行Exec获取连接时,优先从该字段获取- 也就是保证执行整个事务都要用同一个连接
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
dc, release, err := tx.grabConn(ctx)
if err != nil {
return &Stmt{stickyErr: err}
}
defer release(nil)
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
var si driver.Stmt
var parentStmt *Stmt
stmt.mu.Lock()
if stmt.closed || stmt.cg != nil {
// ...
} else {
stmt.removeClosedStmtLocked()
// See if the statement has already been prepared on this connection,
// and reuse it if possible. for _, v := range stmt.css {
if v.dc == dc {
si = v.ds.si
break
}
}
stmt.mu.Unlock()
// 没预编译过,执行预编译
if si == nil {
var ds *driverStmt
withLock(dc, func() {
ds, err = stmt.prepareOnConnLocked(ctx, dc)
})
if err != nil {
return &Stmt{stickyErr: err}
}
si = ds.si
}
parentStmt = stmt
}
txs := &Stmt{
db: tx.db,
// 重点在这,将tx放到Stmt.cg中。后面执行Exec获取连接时,优先从该字段获取
cg: tx,
cgds: &driverStmt{
Locker: dc,
si: si,
},
parentStmt: parentStmt,
query: stmt.query,
}
if parentStmt != nil {
tx.db.addDep(parentStmt, txs)
}
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, txs)
tx.stmts.Unlock()
return txs
}
总结
至此,gorm的源码分析告一段落了,下一篇文章会介绍一些工程上使用gorm的最佳实践