当前位置: 首页 > article >正文

gorm源码解析(四):事务,预编译

文章目录

    • 前言
    • 事务
      • 自己控制事务
      • 用 Transaction方法包装事务
    • 预编译
    • 事务结合预编译
    • 总结

前言

前几篇文章介绍gorm的整体设计,增删改查的具体实现流程。本文将聚焦与事务和预编译部分

事务

自己控制事务

用gorm框架,可以自己控制事务的Begin,Commit和Rollback,如下所示:

// 开始事务
tx := db.Begin()

// 在事务中执行一些 db 操作(从这里开始,您应该使用 'tx' 而不是 'db')
tx.Create(...) 

// ...  

// 遇到错误时回滚事务
tx.Rollback()

// 否则,提交事务
tx.Commit()

下面看看每个api的源码实现

Beigin:

  1. 新建一个db实例,作为本次事务操作的会话
  2. 非预编译模式下:调sql.DB的BeginTx方法,让tx.Statement.ConnPool持有其返回的sql.Tx
    1. 之后的增删改查操作,都用这个sql.Tx执行,会用同一个db连接(也就是调begin的那个连接)

预编译模式本文后面再介绍,这里只关注非预编译模式

func (db *DB) Begin(opts ...*sql.TxOptions) *DB {  
    var (  
       // clone statement  
       // 新建的tx.clone = 1  
       tx  = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})  
       opt *sql.TxOptions  
       err error  
    )  
  
    if len(opts) > 0 {  
       opt = opts[0]  
    }  
  
    switch beginner := tx.Statement.ConnPool.(type) {  
    // 非预编译模式下:  
    case TxBeginner:  
       tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)  
    // 预编译模式下:  
    case ConnPoolBeginner:  
       tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)  
    default:  
       err = ErrInvalidTransaction  
    }  
  
    if err != nil {  
       tx.AddError(err)  
    }  
  
    return tx  
}

Commit和Rollback:调sql.Tx的Commit和Rollback方法,内部会调mysql驱动mysqlTx的Commit和Rollback方法,完成事务的提交和回滚操作

func (db *DB) Commit() *DB {  
    if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {  
       db.AddError(committer.Commit())  
    } else {  
       db.AddError(ErrInvalidTransaction)  
    }  
    return db  
}  
  

func (db *DB) Rollback() *DB {  
    if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {  
       if !reflect.ValueOf(committer).IsNil() {  
          db.AddError(committer.Rollback())  
       }  
    } else {  
       db.AddError(ErrInvalidTransaction)  
    }  
    return db  
}

用 Transaction方法包装事务

然而,自己控制事务有以下问题:

  1. 很多重复代码,样板代码
  2. 自己控制事务生命周期容易出问题,例如可能忘记commit或rollback,或者说发生panic但没有recover,导致没有触发commit或rollback

于是gorm提供了Transaction方法,帮我们调了Begin,事务执行成功后Commit,事务执行失败或发生panic时执行Rollback操作,也就是帮我们控制事务的生命周期,我们只用关注业务逻辑即可

使用方法为将业务逻辑func传进去,例如:

db.Transaction(func(tx *gorm.DB) error {  
    // 在事务中执行一些 db 操作(从这里开始,您应该使用 'tx' 而不是 'db')  
    if err := tx.Create(&Animal{Name: "Giraffe"}).Error; err != nil {  
       // 返回任何错误都会回滚事务  
       return err  
    }  
  
    if err := tx.Create(&Animal{Name: "Lion"}).Error; err != nil {  
       return err  
    }  
  
    // 返回 nil 提交事务  
    return nil  
})

在这里插入图片描述


Transaction方法流程如下:

  1. db.Begin开启事务
  2. 执行fc,把事务tx传进去
  3. 如果fc执行成功,执行tx.Commit提交事务
  4. 如果fc执行出错或发生panic,调tx.Rollback回滚事务
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
	panicked := true

	if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
		// ...
	} else {
		// 开启事务
		tx := db.Begin(opts...)
		if tx.Error != nil {
			return tx.Error
		}

		defer func() {
			// 如果发生panic,或者执行出错,回滚
			if panicked || err != nil {
				tx.Rollback()
			}
		}()

		if err = fc(tx); err == nil {
			panicked = false
			return tx.Commit().Error
		}
	}

	panicked = false
	return
}

预编译

只要用database/sql配合mysql驱动,执行sql时一定走了预编译。要么是客户端预编译,要么是mysql服务端预编译

在gorm层面,PrepareStmt的目的是提高性能,而不是执行sql时从非预编译变成预编译
默认gorm层面的PrepareStmt为false,这里假设使用服务端预编译(连接mysql的dsn中interpolateParams为false),执行sql模板 + 参数类型的操作时有如下流程:

  1. 往mysql发送sql模板,获得stmtId
  2. 往msqyl发送stmtId + 参数,得到执行结果
  3. 往mysql发送释放stmt命令

如果使用gorm层面的PrepareStmt,会对sql.Stmt进行缓存,如果当前连接预编译过该Stmt,就能直接用
接下来看看gorm层面怎么处理预编译的

初始化db时(gorm.Open),如果config.PrepareStmt为true,使用预编译模式:

func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
	// ...

	// 预编译模式
	if config.PrepareStmt {
		preparedStmt := NewPreparedStmtDB(db.ConnPool)
		db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
		db.ConnPool = preparedStmt
	}

	db.Statement = &Statement{
		DB:       db,
		ConnPool: db.ConnPool,
		Context:  context.Background(),
		Clauses:  map[string]clause.Clause{},
	}

	// ...

	return
}

ConnPool被替换成PreparedStmtDB,其结构如下:

type PreparedStmtDB struct {
    // key: sql模板,value:Stmt
	Stmts map[string]*Stmt
	Mux   *sync.RWMutex
    // 内置的 ConnPool 字段通常为 database/sql 中的 *DB
	ConnPool
}

初始化PreparedStmtDB,内部持有sql.DB

func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
	return &PreparedStmtDB{
	    // 持有sql.DB
		ConnPool: connPool,
		// sql到Stmt的映射
		Stmts:    make(map[string]*Stmt),
		Mux:      &sync.RWMutex{},
	}
}

可以看出PreparedStmtDB拥有sql模板到Stmt的缓存,那么遇到相同sql时,如果之前已经预编译过,就能用该Stmt执行db操作

后续基于这个db执行任何操作时,分为两个步骤:

  • 通过PreparedStmtDB.prepare(...) 操作创建/复用 stmt,后续相同 sql 模板可以复用此 stmt
  • 通过 stmt.Query(...)/Exec(...) 执行 sql
    在这里插入图片描述

例如在执行PreparedStmtDB.QueryContext时:

  1. 先调PreparedStmtDB.prepare看有没有可复用的sql.Stmt
  2. 再调sql.Stmt执行QueryContext操作
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
	// 先获取stmt
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
	if err == nil {
		// 再用stmt执行sql
		rows, err = stmt.QueryContext(ctx, args...)
		if errors.Is(err, driver.ErrBadConn) {
			db.Mux.Lock()
			defer db.Mux.Unlock()

			go stmt.Close()
			delete(db.Stmts, query)
		}
	}
	return rows, err
}

PreparedStmtDB.prepare:

  1. 尝试从缓存Stmts中,根据query模板找sql.Stmt,如果有就返回
  2. 否则调sql.DB,根据query模板生成一个sql.Stmt,加入缓存中
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
	db.Mux.RLock()
	// 以sql为模板,先查有没有可复用的stmt
	// 如果stmt.Transaction为false,可以复用
	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
		db.Mux.RUnlock()
		// wait for other goroutines prepared
		<-stmt.prepared
		if stmt.prepareErr != nil {
			return Stmt{}, stmt.prepareErr
		}

		return *stmt, nil
	}
	db.Mux.RUnlock()

	db.Mux.Lock()
	// double check
	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
		db.Mux.Unlock()
		// wait for other goroutines prepared
		<-stmt.prepared
		if stmt.prepareErr != nil {
			return Stmt{}, stmt.prepareErr
		}

		return *stmt, nil
	}
	
	// ...
	
	// 到这里没有可复用的模板
	// 创建stmt实例,加到map中
	cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
	db.Stmts[query] = &cacheStmt
	db.Mux.Unlock()

	// prepare completed
	defer close(cacheStmt.prepared)

	// 调sql.DB的prepareContext方法,创建sql.stmt
	stmt, err := conn.PrepareContext(ctx, query)
	if err != nil {
		cacheStmt.prepareErr = err
		db.Mux.Lock()
		delete(db.Stmts, query)
		db.Mux.Unlock()
		return Stmt{}, err
	}

	db.Mux.Lock()
	cacheStmt.Stmt = stmt
	db.Mux.Unlock()

	return cacheStmt, nil
}

这里是ORM层面对Stmt做了缓存
而在go标准库database/sql层面,stmt要能执行的前提是,在当前连接预编过

我们看sql.Stmt源码,注释上有这么一段话:

A Stmt is safe for concurrent use by multiple goroutines
翻译:Stmt能被多个g并发调用


When the Stmt needs to execute on a new underlying connection, it will prepare itself on the new connection automatically
翻译:当Stmt需要被新的连接使用时,需要在新连接上预编译

sql.Stmt有一个css slice,存放预编译了该Stmt的连接

type Stmt struct {  
   
    // 持有driverStmt
    cgds *driverStmt  
	// ...

    // 预编译了该Stmt的连接
    css []connStmt  
  
}

type connStmt struct {  
    dc *driverConn  
    ds *driverStmt  
}

基于Stmt执行Query, Exec操作时,会检查当前使用的连接在不在Stmt.css里面,如果在就能立即执行,否则需要先预编译该sql模板才能执行

这里以Stmt.ExecQuery为例,看看如何获取连接并执行

func (s *Stmt) ExecContext(ctx context.Context, args ...any) (Result, error) {  
    s.closemu.RLock()  
    defer s.closemu.RUnlock()  
  
    var res Result  
    err := s.db.retry(func(strategy connReuseStrategy) error {
       // 获取连接  
       dc, releaseConn, ds, err := s.connStmt(ctx, strategy)  
       if err != nil {  
          return err  
       }  
       // 执行exec
       res, err = resultFromStatement(ctx, dc.ci, ds, args...)  
       releaseConn(err)  
       return err  
    })  
  
    return res, err  
}

重点在获取连接:

  1. 从连接池获取一个连接
  2. 检查该连接是否在Stmt.css里面,如果是,直接返回
  3. 否则需要先用该连接预编译sql模板
func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {  
    // ...
    
    // 从连接池获取一个连接
    dc, err = s.db.conn(ctx, strategy)  
    if err != nil {  
       return nil, nil, nil, err  
    }  
  
    s.mu.Lock()  
    // 检查该连接是否在Stmt.css里面,如果是,直接返回
    for _, v := range s.css {  
       if v.dc == dc {  
          s.mu.Unlock()  
          return dc, dc.releaseConn, v.ds, nil  
       }  
    }  
    s.mu.Unlock()  
  
    // 否则需要先用该连接预编译sql模板
    withLock(dc, func() {  
       ds, err = s.prepareOnConnLocked(ctx, dc)  
    })  
    if err != nil {  
       dc.releaseConn(err)  
       return nil, nil, nil, err  
    }  
  
    return dc, dc.releaseConn, ds, nil  
}

prepareOnConnLocked:预编译完成后,将连接加入stmt.css

func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {  
    si, err := dc.prepareLocked(ctx, s.cg, s.query)  
    if err != nil {  
       return nil, err  
    }  
    cs := connStmt{dc, si}  
    s.mu.Lock() 
	    // 预编译完成后,将当前连接加入stmt.css
    s.css = append(s.css, cs)  
    s.mu.Unlock()  
    return cs.ds, nil  
}

事务结合预编译

如果同时有事务和预编译,那么在执行Exec/Query时,稍微有点不一样

回到事务的Begin方法:

func (db *DB) Begin(opts ...*sql.TxOptions) *DB {  
    // ...
  
    switch beginner := tx.Statement.ConnPool.(type) {  
    // ...  
    // 预编译模式下:  
    case ConnPoolBeginner:  
       tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)  
    default:  
       err = ErrInvalidTransaction  
    }  
  
    if err != nil {  
       tx.AddError(err)  
    }  
  
    return tx  
}

如果是预编译模式,进入PreparedStmtDB.BeginTx,返回PreparedStmtTX实例

func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {  
    if beginner, ok := db.ConnPool.(TxBeginner); ok {  
       tx, err := beginner.BeginTx(ctx, opt)  
       return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err  
    }  
  
    // ... 
}

接下来看看基于PreparedStmtTX执行增删改查有何特别的地方
这里以PreparedStmtTX.ExecContext为例:

  1. PreparedStmtDB.prepare从缓存拿或新建一个Stmt
  2. sql.Tx.StmtContext先处理步骤1返回的Stmt,再执行Exec,重点在这里
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {  
    stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)  
    if err == nil {  
       result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)  
       if errors.Is(err, driver.ErrBadConn) {  
          tx.PreparedStmtDB.Mux.Lock()  
          defer tx.PreparedStmtDB.Mux.Unlock()  
  
          go stmt.Close()  
          delete(tx.PreparedStmtDB.Stmts, query)  
       }  
    }  
    return result, err  
}

sql.Tx.StmtContext方法

  1. 拿到和事务tx绑定的连接dc
  2. 看dc是否预编译过该stmt,如果没有执行预编译操作
  3. 将事务tx放到Stmt.cg字段,后面执行Exec获取连接时,优先从该字段获取
    1. 也就是保证执行整个事务都要用同一个连接
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {  
    dc, release, err := tx.grabConn(ctx)  
    if err != nil {  
       return &Stmt{stickyErr: err}  
    }  
    defer release(nil)  
  
    if tx.db != stmt.db {  
       return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}  
    }  
    var si driver.Stmt  
    var parentStmt *Stmt  
    stmt.mu.Lock()  
    if stmt.closed || stmt.cg != nil {  
       // ...
    } else {  
       stmt.removeClosedStmtLocked()  
       // See if the statement has already been prepared on this connection,  
       // and reuse it if possible.       for _, v := range stmt.css {  
          if v.dc == dc {  
             si = v.ds.si  
             break  
          }  
       }  
  
       stmt.mu.Unlock()  
       // 没预编译过,执行预编译
       if si == nil {  
          var ds *driverStmt  
          withLock(dc, func() {  
             ds, err = stmt.prepareOnConnLocked(ctx, dc)  
          })  
          if err != nil {  
             return &Stmt{stickyErr: err}  
          }  
          si = ds.si  
       }  
       parentStmt = stmt  
    }  
  
    txs := &Stmt{  
       db: tx.db,  
       // 重点在这,将tx放到Stmt.cg中。后面执行Exec获取连接时,优先从该字段获取
       cg: tx,  
       cgds: &driverStmt{  
          Locker: dc,  
          si:     si,  
       },  
       parentStmt: parentStmt,  
       query:      stmt.query,  
    }  
    if parentStmt != nil {  
       tx.db.addDep(parentStmt, txs)  
    }  
    tx.stmts.Lock()  
    tx.stmts.v = append(tx.stmts.v, txs)  
    tx.stmts.Unlock()  
    return txs  
}

总结

至此,gorm的源码分析告一段落了,下一篇文章会介绍一些工程上使用gorm的最佳实践


http://www.kler.cn/a/442030.html

相关文章:

  • 31、【OS】【Nuttx】OSTest分析(1):stdio测试(一)
  • [Qt] Box Model | 控件样式 | 实现log_in界面
  • Netty的相关组件之间的关系
  • owasp SQL 注入-03 (原理)
  • 软件测试——期末复习
  • 网络安全---CMS指纹信息实战
  • Java基础知识(四) -- 面向对象(中)
  • 量化交易实操入门
  • SQLite 安装与使用
  • vue的elementUI 给输入框绑定enter事件失效
  • 【C语言】指针数组和数组指针
  • 25上半年软考《电子商务设计师》,备考大纲已出!
  • 为什么 Teams 中搜索不到 Power Automate
  • 电脑开机提示error loading operating system怎么修复?
  • 新手谷歌浏览器的使用(使用国内的搜索引擎)
  • lc238除自身以外数组的乘积——动态规划前缀积
  • Java全栈项目 - 智能小区物业管理平台开发实践
  • 新知DAC维修,换牛,
  • Rust操作符和符号全解析
  • Java对集合的操作方法
  • 面试小札:闪电五连鞭_7
  • opencv # Sobel算子、Laplacian算子、Canny边缘检测、findContours、drawContours绘制轮廓、外接矩形
  • Sentry日志管理thinkphp8 tp8 sentry9 sentry8 php8.x配置步骤, tp8自定义异常处理类使用方法
  • NSDT 3DConvert:高效实现大模型文件在线预览与转换
  • 关于llama2:从原始llama-2-7b到llama-2-7b-hf的权重转换教程
  • cesium 与 threejs 对比