当前位置: 首页 > article >正文

Go database/sql包源码分析

文章目录

    • 前言
    • 使用教程
    • 数据结构
      • 接口
      • 结构体
        • DB
        • driverConn
        • Tx
    • sql.DB 初始化
    • Query
      • 整体流程
      • 参数校验
      • 结果集处理
      • 释放连接
    • Exec
    • 预编译
      • PrepareContext
      • Stmt.ExecContext
      • Stmt.QueryContext
      • 直接调db.QueryContext,会预编译吗
    • 事务
      • Begin
      • Exec
      • Commit
      • Rollback
    • 连接池
      • 获取连接
      • 归还连接
      • 清理过期连接
    • 总结
    • 参考

前言

本系列会分为3个部分介绍go操作mysql:

  1. database/sql包
  2. mysql适配go sql的驱动
  3. 工程上使用的orm框架:gorm

本文是第一部分,将分析database/sql包的架构设计和实现原理

本文阅读go源码版本:go1.23.3


使用教程

假设有一张简单的user表,有3个字段,idnamenumber

推荐使用https://sqlpub.com/ 创建一个免费的db,用于学习调试

CREATE TABLE `user` (
  `id` bigint NOT NULL AUTO_INCREMENT,
  `name` varchar(100) NOT NULL DEFAULT '',
  `number` int NOT NULL DEFAULT '0',
  PRIMARY KEY (`id`)
) ENGINE=InnoDBDEFAULT CHARSET=utf8mb4

表里有一行数据:

1tom12

要用go原生的sql包查一行记录,有以下流程:

  1. 注册db驱动:如果用mysql,匿名import github.com/go-sql-driver/mysql 即可
    1. 该包的init方法会完成mysql驱动的注册
  2. 创建sql.DB实例
  3. 执行sql:调db.QueryContext执行查询操作
  4. 读取查询结果:调rows.Next和rows.Scan,将结果设置操传给rows.Scan的指针中
import (
	"context"
	"database/sql"
	"fmt"
    // 注册mysql驱动
	_ "github.com/go-sql-driver/mysql"
	"testing"
)

func TestDB(t *testing.T) {
    // 初始化sql.DB
	db, err := sql.Open("mysql", "账号:密码@(连接地址)/数据库名称")
	if err != nil {
		panic(err)
	}

	ctx := context.Background()
    // 执行查询
	rows, err := db.QueryContext(ctx, "select * from user where id = ?", 1)
	id := 0
	name := ""
	number := 0

    // 接收查询结果
	for rows.Next() {
		rows.Scan(&id, &name, &number)
	}

    // 输出:1 tom 12
	fmt.Println(id, name, number)
}

数据结构

标准库database/sql定义了一套用于sql查询的标准流程,将其关键部分,每种db不同的部分抽象成接口,由各个db的驱动实现

sql.DB维护了一个连接池,以及统一处理和DB无关的操作逻辑,并在适当的时机调Driver接口,完成数据库的操作


接口

在这里插入图片描述

  • Connector:抽象的连接器,可以创建连接,返回驱动Driver

  • Driver:抽象的数据库驱动,唯一的功能是创建一个连接

    • 但实际上sql包并没有用这个接口去创建连接,都是用Connector取创建连接
  • Conn:抽象的连接,具备预编译sql,以及开启事务的能力

  • Statement:抽象的预编译状态,传入参数能执行增删改查操作,不用再传sql语句

  • Tx:抽象的事物,可以执行CommitRollback操作

  • Rows:抽象的查询结果,用于知道查询结果有哪些列,以及读取查询结果

  • Result:抽象的增删改结果,用于知道本次执行的LastInsertIdRowsAffected

以上这些接口都在driver包下,每一种db的驱动都需要实现这些接口
例如:mysql的驱动实现在github.com/go-sql-driver/mysql
下面详细介绍这些接口

Connector:抽象的连接器

type Connector interface {
	// 创建一个链接
	Connect(context.Context) (Conn, error)

	// 主要用于sql.DB的Driver方法能返回Driver
	Driver() Driver
}
  • connector内部不用对Conn池化,因为sql.DB自带了连接池

sql.DB的Driver方法能返回Driver:

func (db *DB) Driver() driver.Driver {
	return db.connector.Driver()
}

Driver:抽象的驱动。只有新建连接这个功能

type Driver interface {
    // 新建一个连接
	Open(name string) (Conn, error)
}

Conn:抽象的db连接

type Conn interface {
    // 预编译
	Prepare(query string) (Stmt, error)
	Close() error
    // 开启事务
	Begin() (Tx, error)
}

Stmt:抽象的预编译状态

type Stmt interface {
	
	Close() error

	// 预编译sql有多少参数
	NumInput() int
  
    // 执行增删改操作	
	Exec(args []Value) (Result, error)

    // 查询
	Query(args []Value) (Rows, error)
}

Tx:抽象的事务

type Tx interface {
    // 事务提交
	Commit() error
    // 事务回滚
	Rollback() error
}

Rows:抽象的查询结果

type Rows interface {
	// 本次查询出来有哪些列
	Columns() []string

	Close() error

	// 将查询结果设置到dest中
	Next(dest []Value) error
}

Result:抽象的增删改执行结果

type Result interface {
	LastInsertId() (int64, error)
	RowsAffected() (int64, error)
}

结构体

除了给驱动实现的接口外,sql包定义了一些结构体,作为db client,以及封装对驱动的操作。包括了:

DB

DB:代表一个数据库实例,包含以下部分:

  • 各种统计信息,这里忽略

  • 关于连接池的配置,包括:

    • maxIdleCount:最大空闲连接数。如果连接池中已经有这么多空闲链接了,下次归还连接时直接释放,而不是再放回池中
    • maxOpen:最多可以打开的连接数
    • maxLifetime:一个连接最多能存在多久
    • maxIdleTime:一个连接最多能空闲多久
  • 连接池:

    • freeConn:空闲连接slice
    • connRequests:所有等待获取连接g的channel
    • numOpen:已经打开了多少连接
type DB struct {
    // 连接器
	connector driver.Connector
	
	mu           sync.Mutex
    // 空闲链接
	freeConn     []*driverConn
    // 所有等待获取连接g的channel
	connRequests connRequestSet
  
    // 已经打开的连接数
	numOpen      int    
	
	openerCh          chan struct{}
	closed            bool
	dep               map[finalCloser]depSet

    // 最大空闲连接数。如果连接池中已经有这么多空闲链接了,下次归还连接时直接释放,而不是再放回池中
	maxIdleCount      int
    // 最多可以打开的连接数
	maxOpen           int                    
	// 一个连接最多能存在多久
    maxLifetime       time.Duration          
    // 一个连接最多能空闲多久
	maxIdleTime       time.Duration         
	cleanerCh         chan struct{}
    // 多少g在阻塞等待连接
	waitCount         int64 	
}

其中connRequestSet结构如下:每个等待获取连接的请求会阻塞在connRequestAndIndex.req上

type connRequestSet struct {
    // 所有阻塞等待的请求
	s []connRequestAndIndex
}

type connRequestAndIndex struct {
    // 通过channel传递connRequest,里面有连接
	req chan connRequest
	curIdx *int
}

type connRequest struct {
    // 连接
	conn *driverConn
	err  error
}

driverConn

driverConn中持有db驱动实现的连接driver.Conn,以及用于连接池管理相关的创建时间,放回池的时间

type driverConn struct {
	db        *DB
    // 连接创建时间
	createdAt time.Time

	sync.Mutex
    // 持有db驱动实现的连接
	ci          driver.Conn
	needReset   bool 
	closed      bool
	finalClosed bool
    // 该连接下所有的stmt
	openStmt    map[*driverStmt]bool

	inUse      bool
    // 被放回连接池的时间,用于计算idleTime
	returnedAt time.Time 
    // 被放回连接池的回调函数
	onPut      []func()  
	dbmuClosed bool      
}

Tx

Tx持有驱动实现的driver.Tx,同时持有属于哪个连接,DB,事务是否完成等信息

type Tx struct {
    // 属于哪个DB
	db *DB


	closemu sync.RWMutex

	// 属于哪个连接
	dc  *driverConn
    // 驱动实现的Tx
	txi driver.Tx

	
	releaseConn func(error)

	// 事务是否已完成
	done atomic.Bool

	
	keepConnOnRollback bool

	// 当前事务的所有stmt
	stmts struct {
		sync.Mutex
		v []*Stmt
	}

	cancel func()
	ctx context.Context
}

sql.DB 初始化

在这里插入图片描述

  1. 校验driver是否已经注册
  2. 创建sql.DB实例,主要包含具体db的连接器connector,连接池
func Open(driverName, dataSourceName string) (*DB, error) {
    // 校验driver是否已注册
	driversMu.RLock()
	driveri, ok := drivers[driverName]
	driversMu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
	}

	if driverCtx, ok := driveri.(driver.DriverContext); ok {
		connector, err := driverCtx.OpenConnector(dataSourceName)
		if err != nil {
			return nil, err
		}
		return OpenDB(connector), nil
	}

    // 调OpenDB
	return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
}

OpenDB

  1. 创建db实例,持有connector。以后当db需要创建连接时,就调connector.Connect
  2. 启动connectionOpener协程,用于在连接池连接不足时补充连接
func OpenDB(c driver.Connector) *DB {
	ctx, cancel := context.WithCancel(context.Background())
	db := &DB{
		connector:    c,
		openerCh:     make(chan struct{}, connectionRequestQueueSize),
		lastPut:      make(map[*driverConn]string),
		connRequests: make(map[uint64]chan connRequest),
		stop:         cancel,
	}

	go db.connectionOpener(ctx)

	return db
}

connectionOpener:收到来自db.openerCh的信号后,调openNewConnection创建连接

func (db *DB) connectionOpener(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case <-db.openerCh:
			db.openNewConnection(ctx)
		}
	}
}

openNewConnection

  1. 调db驱动创建一个连接driver.Conn
  2. 将driver.Conn封装成driverConn
  3. 将driverConn放入连接池
func (db *DB) openNewConnection(ctx context.Context) {
    // 调db驱动创建一个连接
    ci, err := db.connector.Connect(ctx)
	db.mu.Lock()
	defer db.mu.Unlock()
	if db.closed {
		if err == nil {
			ci.Close()
		}
		db.numOpen--
		return
	}
  
	if err != nil {
		db.numOpen--
		db.putConnDBLocked(nil, err)
		db.maybeOpenNewConnections()
		return
	}

    // 封装成driverConn
	dc := &driverConn{
		db:         db,
		createdAt:  nowFunc(),
		returnedAt: nowFunc(),
		ci:         ci,
	}

    // 添加到连接池中
	if db.putConnDBLocked(dc, err) {
		db.addDepLocked(dc, dc)
	} else {
		db.numOpen--
		ci.Close()
	}
}

Query

整体流程

假设要执行如下查询:select * from user where id = 1

rows, err := db.QueryContext(ctx, "select * from user where id = ?", 1)
id := 0
name := ""
number := 0

// 接收查询结果
for rows.Next() {
    rows.Scan(&id, &name, &number)
}

db.QueryContext中用retry函数包装query方法

func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
	var rows *Rows
	var err error
    // 重试执行func
	err = db.retry(func(strategy connReuseStrategy) error {
		rows, err = db.query(ctx, query, args, strategy)
		return err
	})

	return rows, err
}

在遇到连接过期时,下层方法会返回ErrBadConn,此时可通过重试来提高成功率
retry逻辑如下:如果遇到返回err=driver.ErrBadConn,最多尝试3次。前两次可以从连接池获取或新建连接,第三次一定新建连接

const maxBadConnRetries = 2

func (db *DB) retry(fn func(strategy connReuseStrategy) error) error {
	for i := int64(0); i < maxBadConnRetries; i++ {
		err := fn(cachedOrNewConn)
        // err == nil: 返回
        // err != driver.ErrBadConn时,不重试
		if err == nil || !errors.Is(err, driver.ErrBadConn) {
			return err
		}
	}

	return fn(alwaysNewConn)
}

query方法先根据strategy获取连接,再调queryDC方法执行查询

关于怎么从连接池获取连接,怎么往连接池归还连接,在下文连接池模块详细分析

func (db *DB) query(ctx context.Context, query string, args []any, strategy connReuseStrategy) (*Rows, error) {
	dc, err := db.conn(ctx, strategy)
	if err != nil {
		return nil, err
	}

	return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args)
}

queryDC:
在这里插入图片描述

  1. 如果driver.Conn实例实现了driver.QueryerContextdriver.Queryer接口,转化成对应接口调查询。此时

    1. 例如:mysql的驱动就实现了driver.QueryerContext接口
  2. 如果驱动的QueryerContext返回driver.ErrSkip 错误,执行接下里的操作

  3. 将sql语句 query预处理成stmt

  4. 将driver.Stmt 包装成driverStmt

  5. 执行sql

  6. 将响应结果rows(类型为driver.Rows)包装成sql.Rows类型返回

func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []any) (*Rows, error) {
	queryerCtx, ok := dc.ci.(driver.QueryerContext)
	var queryer driver.Queryer
	if !ok {
		queryer, ok = dc.ci.(driver.Queryer)
	}
    // 如果驱动的连接实现了driver.QueryerContext或driver.Queryer接口
    // 转化为这两个接口执行查询
	if ok {
		var nvdargs []driver.NamedValue
		var rowsi driver.Rows
		var err error
		withLock(dc, func() {
			nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
			if err != nil {
				return
			}
			rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
		})
		if err != driver.ErrSkip {
			if err != nil {
				releaseConn(err)
				return nil, err
			}
			rows := &Rows{
				dc:          dc,
				releaseConn: releaseConn,
				rowsi:       rowsi,
			}
			rows.initContextClose(ctx, txctx)
			return rows, nil
		}
	}

    // 驱动没实现那两个接口,或者返回了driver.ErrSkip
    // 将query预处理成stmt
	var si driver.Stmt
	var err error
	withLock(dc, func() {
		si, err = ctxDriverPrepare(ctx, dc.ci, query)
	})
	if err != nil {
		releaseConn(err)
		return nil, err
	}

    // 包装成driverStmt
	ds := &driverStmt{Locker: dc, si: si}
    // 执行sql
	rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
	if err != nil {
		ds.Close()
		releaseConn(err)
		return nil, err
	}

	// 将响应结果rows(类型driver.Rows)包装成sql.Rows类型返回
	rows := &Rows{
		dc:          dc,
		releaseConn: releaseConn,
		rowsi:       rowsi,
		closeStmt:   ds,
	}
	rows.initContextClose(ctx, txctx)
	return rows, nil
}

如果驱动的driver.Conn实现了driver.QueryerContext接口,就会调该接口的方法执行查询
否则往下走,执行预编译再查询的流程
ctxDriverPrepare:将query预处理成stmt,调驱动实现

func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) {
    // 如果连接实现了driver.ConnPrepareContext接口
    // 转换成该接口执行PrepareContext
	if ciCtx, is := ci.(driver.ConnPrepareContext); is {
		return ciCtx.PrepareContext(ctx, query)
	}

    // 调抽象Conn的Prepare方法,返抽象的Stmt
	si, err := ci.Prepare(query)
	if err == nil {
		select {
		default:
		case <-ctx.Done():
			si.Close()
			return nil, ctx.Err()
		}
	}
	return si, err
}

rowsiFromStatement:

  1. 处理参数,将参数转换成 driver.NamedValue,对每个参数并执行一些检查:

    1. 要么只能是基础类型
    2. 要么实现了Valuer接口,调Valuer.Value方法获得的返回值,也只能是基础类型
  2. 调driver.Stmt的Query方法,执行真正的sql

func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (driver.Rows, error) {
   ds.Lock()
   defer ds.Unlock()
   dargs, err := driverArgsConnLocked(ci, ds, args)
   if err != nil {
      return nil, err
   }
   return ctxDriverStmtQuery(ctx, ds.si, dargs)
}

ctxDriverStmtQuery内部没啥好说的,交给驱动执行查询

参数校验

driverArgsConnLocked校验args是否是基础类型

  1. 对于每个参数,调checker方法检查是否合法
  2. 如果是预编译stmt,校验占位符个数和args个数是否相等
func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []any) ([]driver.NamedValue, error) {
	nvargs := make([]driver.NamedValue, len(args))

	want := -1

	var si driver.Stmt
	var cc ccChecker
	if ds != nil {
		si = ds.si
		want = ds.si.NumInput()
		cc.want = want
	}

	nvc, ok := si.(driver.NamedValueChecker)
	if !ok {
		nvc, _ = ci.(driver.NamedValueChecker)
	}
	cci, ok := si.(driver.ColumnConverter)
	if ok {
		cc.cci = cci
	}

	
	var err error
	var n int
    // 遍历每个参数
	for _, arg := range args {
		nv := &nvargs[n]
		// ...
		nv.Ordinal = n + 1
		nv.Value = arg

		
		checker := defaultCheckNamedValue
		nextCC := false
		switch {
		case nvc != nil:
			nextCC = cci != nil
			checker = nvc.CheckNamedValue
		case cci != nil:
			checker = cc.CheckNamedValue
		}

	nextCheck:
        // 用checker校验每个参数
		err = checker(nv)
		switch err {
		case nil:
			n++
			continue
		// ...
		}
	}

    //  如果是预编译stmt,校验占位符个数和args个数是否相等
	if want != -1 && len(nvargs) != want {
		return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
	}

	return nvargs, nil
}

mysql驱动mysqlConn 实现了driver.NamedValueChecker结构,总体来说也是判断是否是基础类型,或实现了driver.Valuer接口

func (c converter) ConvertValue(v any) (driver.Value, error) {
	if driver.IsValue(v) {
		return v, nil
	}

    // 如果实现了driver.Valuer接口
	if vr, ok := v.(driver.Valuer); ok {
		sv, err := callValuerValue(vr)
		if err != nil {
			return nil, err
		}
		if driver.IsValue(sv) {
			return sv, nil
		}
	
		if u, ok := sv.(uint64); ok {
			return u, nil
		}
		return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
	}

    // 否则校验v是否是基础类型
	rv := reflect.ValueOf(v)
	switch rv.Kind() {
	case reflect.Ptr:
		// indirect pointers
		if rv.IsNil() {
			return nil, nil
		} else {
			return c.ConvertValue(rv.Elem().Interface())
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return rv.Int(), nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return rv.Uint(), nil
	case reflect.Float32, reflect.Float64:
		return rv.Float(), nil
	case reflect.Bool:
		return rv.Bool(), nil
	case reflect.Slice:
		switch t := rv.Type(); {
		case t == jsonType:
			return v, nil
		case t.Elem().Kind() == reflect.Uint8:
			return rv.Bytes(), nil
		default:
			return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind())
		}
	case reflect.String:
		return rv.String(), nil
	}
	return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
}

结果集处理

当调db.QueryContext执行查询时,需要用下面这种方式读取查询结果:

// 还有下一行时
for rows.Next() {
    // 读取当前行的数据到id,name,number中
    rows.Scan(&id, &name, &number)
}

Next:判断还有没有下一行

func (rs *Rows) Next() bool {
	rs.closemuRUnlockIfHeldByScan()
	if rs.contextDone.Load() != nil {
		return false
	}

	var doClose, ok bool
	withLock(rs.closemu.RLocker(), func() {
        // 调rs.nextLocked
		doClose, ok = rs.nextLocked()
	})
	if doClose {
		rs.Close()
	}
	if doClose && !ok {
		rs.hitEOF = true
	}
	return ok
}

nextLocked流程如下:

  1. 分配lastcols slice,长度和rowsi.Columns()相同,也就是和驱动返回的列长度一样
  2. 调驱动driver.Rows的Next方法,将数据都到lastcols中
  3. 如果驱动的Next返回io.EOF,说明没有下一行了
func (rs *Rows) nextLocked() (doClose, ok bool) {
	if rs.closed {
		return false, false
	}

	rs.dc.Lock()
	defer rs.dc.Unlock()

	if rs.lastcols == nil {
        // 分配len(Columns)个长度
		rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
	}

    // 调驱动的Next方法,将一行数据写入rs.lastcols中
	rs.lasterr = rs.rowsi.Next(rs.lastcols)
    // 如果没有下一行,驱动的Next方法会返回io.EO
	if rs.lasterr != nil {
		if rs.lasterr != io.EOF {
			return true, false
		}
		nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
		if !ok {
			return true, false
		}
		.
		if !nextResultSet.HasNextResultSet() {
			doClose = true
		}
		return doClose, false
	}
	return false, true
}

Scan:将Next已经从驱动都到的数据(暂存到lastcols中的),赋值到用户传入·的dest中

func (rs *Rows) Scan(dest ...any) error {
	// ...

    // 校验dest的长度必须要和rs.lastcols相同
	if len(dest) != len(rs.lastcols) {
		rs.closemuRUnlockIfHeldByScan()
		return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
	}

	for i, sv := range rs.lastcols {
        // 将sv设置到dest[i]中
		err := convertAssignRows(dest[i], sv, rs)
		if err != nil {
			rs.closemuRUnlockIfHeldByScan()
			return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
		}
	}
	return nil
}

其中在convertAssignRows中,会根据dest的实际类型,和某一列驱动返回数据src的实际类型做转化匹配,然后把驱动返回的数据设置到dest中

如果dest实现了sql.Scanner接口,会调其Scan方法从src读取数据

释放连接

查询完毕后,调Rows.Close释放连接

func (rs *Rows) Close() error {
    // ...
	return rs.close(nil)
}
func (rs *Rows) close(err error) error {
	rs.closemu.Lock()
	defer rs.closemu.Unlock()

	if rs.closed {
		return nil
	}
	rs.closed = true

	if rs.lasterr == nil {
		rs.lasterr = err
	}

	withLock(rs.dc, func() {
		err = rs.rowsi.Close()
	})
	// ...

	if rs.closeStmt != nil {
		rs.closeStmt.Close()
	}
    // 释放连接
	rs.releaseConn(err)

	rs.lasterr = rs.lasterrOrErrLocked(err)
	return err
}

Exec

当需要执行增删改操作时,调db.ExecContent方法

例如:要把user表中,id=1行的name更新为jerry

import (
	"context"
	"database/sql"
	"fmt"
	_ "github.com/go-sql-driver/mysql"
	"testing"
)

func TestDB(t *testing.T) {
	db, err := sql.Open("mysql", "")
	if err != nil {
		panic(err)
	}

	ctx := context.Background()
	res, err := db.ExecContext(ctx, "update user set name = ? where id = ?", "jerry", 1)
	fmt.Println(err)

	affected, err := res.RowsAffected()
	lastInsertId, err := res.LastInsertId()
	fmt.Println(affected)   // 结果:1
	fmt.Println(lastInsertId)  // 结果:0
}

其整体流程和QueryContext非常类似

func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
    var res Result
    var err error

    err = db.retry(func(strategy connReuseStrategy) error {
       res, err = db.exec(ctx, query, args, strategy)
       return err
    })

    return res, err
}

exec先获取连接,再执行exec

func (db *DB) exec(ctx context.Context, query string, args []any, strategy connReuseStrategy) (Result, error) {
	dc, err := db.conn(ctx, strategy)
	if err != nil {
		return nil, err
	}
	return db.execDC(ctx, dc, dc.releaseConn, query, args)
}

execDC:整体流程和queryDC类似:

  1. 如果驱动的连接实现了ExecerContextExecer接口,转换成对应接口执行exec
  2. 否则先预编译,然后调resultFromStatement执行exec
func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), query string, args []any) (res Result, err error) {
	// 在defer里释放连接
    defer func() {
		release(err)
	}()
	execerCtx, ok := dc.ci.(driver.ExecerContext)
	var execer driver.Execer
	if !ok {
		execer, ok = dc.ci.(driver.Execer)
	}
    // 如果驱动的连接实现了ExecerContext或Execer接口,转换成对应接口执行exec
	if ok {
		var nvdargs []driver.NamedValue
		var resi driver.Result
		withLock(dc, func() {
			nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
			if err != nil {
				return
			}
			resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
		})
		if err != driver.ErrSkip {
			if err != nil {
				return nil, err
			}
			return driverResult{dc, resi}, nil
		}
	}

    // 先预编译
	var si driver.Stmt
	withLock(dc, func() {
		si, err = ctxDriverPrepare(ctx, dc.ci, query)
	})
	if err != nil {
		return nil, err
	}
	ds := &driverStmt{Locker: dc, si: si}
	defer ds.Close()
    // 执行exec
	return resultFromStatement(ctx, dc.ci, ds, args...)
}

resultFromStatement:调驱动执行Exec

func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (Result, error) {
	ds.Lock()
	defer ds.Unlock()

	dargs, err := driverArgsConnLocked(ci, ds, args)
	if err != nil {
		return nil, err
	}

	resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
	if err != nil {
		return nil, err
	}
	return driverResult{ds.Locker, resi}, nil
}

预编译

PrepareContext

func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
	var stmt *Stmt
	var err error

	err = db.retry(func(strategy connReuseStrategy) error {
		stmt, err = db.prepare(ctx, query, strategy)
		return err
	})

	return stmt, err
}

内部调db.prepare:

func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
	// 获取连接
	dc, err := db.conn(ctx, strategy)
	if err != nil {
		return nil, err
	}
	return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
}

先获取连接,然后调db.prepareDC

  1. 调驱动的PrepareContext方法,将驱动返回的driver.Stmt包装成driverStmt
  2. 再包装成db.Stmt
  3. 将当前连接加入stmt.css中。之后执行Query或Exec时,只有出现再这里面的连接才能执行Exec或Query
func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
	var ds *driverStmt
	var err error
	defer func() {
		release(err)
	}()
    // 调驱动的PrepareContext方法,将返回的driver.Stmt包装成driverStmt
	withLock(dc, func() {
		ds, err = dc.prepareLocked(ctx, cg, query)
	})
	if err != nil {
		return nil, err
	}
    // 再包装成db.Stmt返回
	stmt := &Stmt{
		db:    db,
		query: query,
		cg:    cg,
		cgds:  ds,
	}

	if cg == nil {
        // 将当前连接加入stmt.css中
		stmt.css = []connStmt{{dc, ds}}
		stmt.lastNumClosed = db.numClosed.Load()
		db.addDep(stmt, stmt)
	}
	return stmt, nil
}

Stmt.ExecContext

func (s *Stmt) ExecContext(ctx context.Context, args ...any) (Result, error) {
	// 可以读并发
    s.closemu.RLock()
	defer s.closemu.RUnlock()

	var res Result
	err := s.db.retry(func(strategy connReuseStrategy) error {
        // 获取一个连接来执行exec
		dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
		if err != nil {
			return err
		}

        // 执行
		res, err = resultFromStatement(ctx, dc.ci, ds, args...)
		releaseConn(err)
		return err
	})

	return res, err
}

stmt.connStmt:

  1. 从连接池获取一个连接
  2. 如果这个连接预编译过该stmt,返回
  3. 否则需要在当前连接预编译该sql
func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
	if err = s.stickyErr; err != nil {
		return
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		err = errors.New("sql: statement is closed")
		return
	}

	
	// ...
	s.mu.Unlock()

    // 从连接池获取一个连接
	dc, err = s.db.conn(ctx, strategy)
	if err != nil {
		return nil, nil, nil, err
	}

	s.mu.Lock()
    // 如果当前连接预编译过该sql,直接用
	for _, v := range s.css {
		if v.dc == dc {
			s.mu.Unlock()
			return dc, dc.releaseConn, v.ds, nil
		}
	}
	s.mu.Unlock()

	// 需要在当前连接预编译该sql
	withLock(dc, func() {
		ds, err = s.prepareOnConnLocked(ctx, dc)
	})
	if err != nil {
		dc.releaseConn(err)
		return nil, nil, nil, err
	}

	return dc, dc.releaseConn, ds, nil
}

调驱动预编译该sql,然后把当前连接加入stmt.css中

func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
    // 预编译sql
	si, err := dc.prepareLocked(ctx, s.cg, s.query)
	if err != nil {
		return nil, err
	}
	cs := connStmt{dc, si}
	s.mu.Lock()
    // 将连接加入stmt.css中
	s.css = append(s.css, cs)
	s.mu.Unlock()
	return cs.ds, nil
}

Stmt.QueryContext

和ExecContext类似

直接调db.QueryContext,会预编译吗

根据上一节的分析可以看出,直接调db.QueryContext时,如果驱动没有实现QueryerContext,或者实现了但调用结果返回driver.ErrSkip,db.QueryContext接下来就会执行预编译

而在下一篇文章要分析的mysql驱动,其对QueryerContext接口的实现中,如果dsn参数中interpolateParam=true,就会使用客户端预编译。否则返回driver.ErrSkip,让db.QueryContext接下来执行预编译

也就是说如果用mysql驱动,db.QueryContext中要么走客户端预编译,要么走mysql服务端预编译

事务

Begin

执行db.BeginTx开启事务:

func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
	var tx *Tx
	var err error

	err = db.retry(func(strategy connReuseStrategy) error {
		tx, err = db.begin(ctx, opts, strategy)
		return err
	})

	return tx, err
}

begin:获取连接,然后调beginDC开启事务

func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) {
	dc, err := db.conn(ctx, strategy)
	if err != nil {
		return nil, err
	}
	return db.beginDC(ctx, dc, dc.releaseConn, opts)
}

func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) {
	var txi driver.Tx
	keepConnOnRollback := false
	withLock(dc, func() {
		_, hasSessionResetter := dc.ci.(driver.SessionResetter)
		_, hasConnectionValidator := dc.ci.(driver.Validator)
		keepConnOnRollback = hasSessionResetter && hasConnectionValidator
        // 内部调驱动的Begin方法开启事务
		txi, err = ctxDriverBegin(ctx, opts, dc.ci)
	})
	if err != nil {
		release(err)
		return nil, err
	}

	ctx, cancel := context.WithCancel(ctx)
    // 将driver.Tx包装成sql.Tx返回
	tx = &Tx{
		db:                 db,
        // 属于哪个连接
		dc:                 dc,
		releaseConn:        release,
        // 只有driver.Tx
		txi:                txi,
		cancel:             cancel,
		keepConnOnRollback: keepConnOnRollback,
		ctx:                ctx,
	}
	go tx.awaitDone()
	return tx, nil
}

Exec

在Tx上执行各种操作,都是用Tx绑定的连接

以exec为例:

func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
    // 获取tx绑定的连接
	dc, release, err := tx.grabConn(ctx)
	if err != nil {
		return nil, err
	}
    // 执行exec
	return tx.db.execDC(ctx, dc, release, query, args)
}

Tx.grabConn:执行一些检查后,返回tx绑定的连接

func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
	select {
	default:
	case <-ctx.Done():
		return nil, nil, ctx.Err()
	}

	
	tx.closemu.RLock()
	if tx.isDone() {
		tx.closemu.RUnlock()
		return nil, nil, ErrTxDone
	}

    // 返回tx绑定的driverconn
	return tx.dc, tx.closemuRUnlockRelease, nil
}

Commit

func (tx *Tx) Commit() error {
	select {
	default:
	case <-tx.ctx.Done():
		if tx.done.Load() {
			return ErrTxDone
		}
		return tx.ctx.Err()
	}
    // 只能rollback或commit一次
	if !tx.done.CompareAndSwap(false, true) {
		return ErrTxDone
	}

	tx.cancel()
	tx.closemu.Lock()
	tx.closemu.Unlock()

	var err error
	withLock(tx.dc, func() {
        // 调driver.Tx的Commit方法
		err = tx.txi.Commit()
	})
	if !errors.Is(err, driver.ErrBadConn) {
		tx.closePrepared()
	}
    // 执行完成后释放连接
	tx.close(err)
	return err
}

Rollback

func (tx *Tx) rollback(discardConn bool) error {
    // 只能rollback或commit一次
	if !tx.done.CompareAndSwap(false, true) {
		return ErrTxDone
	}

    // ...

	tx.cancel()
	tx.closemu.Lock()
	tx.closemu.Unlock()

	var err error
	withLock(tx.dc, func() {
        // 调 driver.Tx的Rollback方法
		err = tx.txi.Rollback()
	})
	if !errors.Is(err, driver.ErrBadConn) {
		tx.closePrepared()
	}
	if discardConn {
		err = driver.ErrBadConn
	}
    // 执行完成后释放连接
	tx.close(err)
	return err
}

连接池

什么时机获取连接?执行query,exec,prepare,begin时
在这里插入图片描述


获取连接

  1. 如果策略是能拿池中的连接,且池内还有连接,就从db.freeConn尾部拿一个连接

    1. 上文提到,retry方法中,前两次可以拿池中连接,第三次只能新建连接
  2. 如果已经达到连接池设置的最大连接了,需要等待

  3. 如果没达到上限,调db驱动新建连接driver.Conn,包装成driverConn返回

func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
	db.mu.Lock()
    // 校验db是否已关闭
	if db.closed {
		db.mu.Unlock()
		return nil, errDBClosed
	}

    // 校验ctx是否已到期或关闭
	select {
	default:
	case <-ctx.Done():
		db.mu.Unlock()
		return nil, ctx.Err()
	}
	lifetime := db.maxLifetime

	// 优先拿尾部的连接
	last := len(db.freeConn) - 1
    // 策略是能拿池中的连接,且池内还有连接
	if strategy == cachedOrNewConn && last >= 0 {
		conn := db.freeConn[last]
		db.freeConn = db.freeConn[:last]
		conn.inUse = true

        // 校验conn有没有超过最大存活时间
		if conn.expired(lifetime) {
			db.maxLifetimeClosed++
			db.mu.Unlock()
			conn.Close()
			return nil, driver.ErrBadConn
		}
		db.mu.Unlock()

		// ...
		return conn, nil
	}


    // 到这说明没有空闲连接了,或者策略就是要新建连接
  
	// 已经到连接池设置的最大连接了,需要等待
	if db.maxOpen > 0 && db.numOpen >= db.maxOpen {
        // 当前请求的等待channel
		req := make(chan connRequest, 1)
        // 把添加channel到等待队列中
		delHandle := db.connRequests.Add(req)
		db.waitCount++
		db.mu.Unlock()

		waitStart := nowFunc()

		// 阻塞等待
		select {
        // 超时了
		case <-ctx.Done():
			
			db.mu.Lock()
            // 从connRequests中删除
			deleted := db.connRequests.Delete(delHandle)
			db.mu.Unlock()

			db.waitDuration.Add(int64(time.Since(waitStart)))

            // 如果在删除前,channel已经收到连接了,将该连接放回去
            // 为啥有这个case?可能在从connRequests delete之前,就拿到连接了
			if !deleted {
				select {
				default:
				case ret, ok := <-req:
					if ok && ret.conn != nil {
						db.putConn(ret.conn, ret.err, false)
					}
				}
			}
			return nil, ctx.Err()
        // 收到连接  
		case ret, ok := <-req:
			db.waitDuration.Add(int64(time.Since(waitStart)))
			if !ok {
				return nil, errDBClosed
			}
			// 检测连接是否超过最大存活时间,如果是就关闭连接,让上层重试
            if strategy == cachedOrNewConn && ret.err == nil && ret.conn.expired(lifetime) {
				db.mu.Lock()
				db.maxLifetimeClosed++
				db.mu.Unlock()
				ret.conn.Close()
				return nil, driver.ErrBadConn
			}

            // ...
		}
	}

    // 到这说明没有空闲连接,且Open的连接没到上限,就要新建连接
	db.numOpen++ 
	db.mu.Unlock()
    // 调connector新建连接
	ci, err := db.connector.Connect(ctx)
	if err != nil {
		db.mu.Lock()
		db.numOpen-- 
		db.maybeOpenNewConnections()
		db.mu.Unlock()
		return nil, err
	}
	db.mu.Lock()
    // 包装成driverConn返回
	dc := &driverConn{
		db:         db,
		createdAt:  nowFunc(),
		returnedAt: nowFunc(),
		ci:         ci,
		inUse:      true,
	}
	db.addDepLocked(dc, dc)
	db.mu.Unlock()
	return dc, nil
}

其中driverConn.expired:校验连接有没有超过最大存活时间,根据driverConn.createdAt判断

func (dc *driverConn) expired(timeout time.Duration) bool {
	if timeout <= 0 {
		return false
	}
	return dc.createdAt.Add(timeout).Before(nowFunc())
}

connRequestSet.Add:添加一个channel到connRequestSet中,idx取添加时等待队列的长度

func (s *connRequestSet) Add(v chan connRequest) connRequestDelHandle {
	idx := len(s.s)
	
	idxPtr := &idx
	s.s = append(s.s, connRequestAndIndex{v, idxPtr})
	return connRequestDelHandle{idxPtr}
}

归还连接

什么时机归还连接?

  • query:读取完数据后,调rows.Close,内部会归还连接
  • exec:在execDC的defer中归还连接

入口函数为driverConn.releaseConn

func (dc *driverConn) releaseConn(err error) {
	dc.db.putConn(dc, err, true)
}

putConn归还连接,err是为啥要归还连接

func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
	// ...
	
    // 如果超过最大生命周期,将err设为driver.ErrBadConn
	if !errors.Is(err, driver.ErrBadConn) && dc.expired(db.maxLifetime) {
		db.maxLifetimeClosed++
		err = driver.ErrBadConn
	}

	dc.inUse = false
	dc.returnedAt = nowFunc()
    // 执行回调
	for _, fn := range dc.onPut {
		fn()
	}
	dc.onPut = nil

    // 如果是ErrBadConn,直接关闭
	if errors.Is(err, driver.ErrBadConn) {
        // 尝试补充一些连接
		db.maybeOpenNewConnections()
		db.mu.Unlock()
		dc.Close()
		return
	}
	if putConnHook != nil {
		putConnHook(db, dc)
	}

    // 放回连接池
	added := db.putConnDBLocked(dc, nil)
	db.mu.Unlock()

	if !added {
		dc.Close()
		return
	}
}

maybeOpenNewConnections:尝试补充一些连接

要补充多少个?min(等待获取连接的请求数, 池中还能装多少连接)

func (db *DB) maybeOpenNewConnections() {
	numRequests := db.connRequests.Len()
	if db.maxOpen > 0 {
		numCanOpen := db.maxOpen - db.numOpen
		if numRequests > numCanOpen {
			numRequests = numCanOpen
		}
	}
    // 要补充多少个?min(等待获取连接的请求数, 池中还能装多少连接)
	for numRequests > 0 {
		db.numOpen++ 
		numRequests--
		if db.closed {
			return
		}
        // 通知connectionOpener去补充连接
		db.openerCh <- struct{}{}
	}
}

putConnDBLocked 归还连接:

  1. 如果池已经满了,无法归还
  2. 如果有请求在等待获取连接,随机获取一个,将连接发往该请求的channel
  3. 否则,如果空闲连接数还没到maxIdle,将连接返回池中
  4. 尝试开启协程清理连接
func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
	if db.closed {
		return false
	}

    // 已经满了,无法归还
	if db.maxOpen > 0 && db.numOpen > db.maxOpen {
		return false
	}

    // 有请求在等待获取连接,随机获取一个
    if req, ok := db.connRequests.TakeRandom(); ok {
		if err == nil {
			dc.inUse = true
		}
		req <- connRequest{
			conn: dc,
			err:  err,
		}
		return true
    // 没有请求在等待获取连接,放回池中    
	} else if err == nil && !db.closed {
        // 空闲连接数还没到maxIdle,可以放回池中
		if db.maxIdleConnsLocked() > len(db.freeConn) {
			db.freeConn = append(db.freeConn, dc)
            // 尝试开启协程清理连接
			db.startCleanerLocked()
			return true
		}
		db.maxIdleClosed++
	}
	return false    
}

清理过期连接

什么连接需要被清理?

  • idleTime > maxIdleTime
  • lifeTime > maxLifeTime

什么时机进行清理?

当归还连接时,尝试开启后台协程定期清理,启动前会检查cleaner是否存在,保证全局只存在一个cleaner

当没有空闲连接时,退出清理任务

func (db *DB) startCleanerLocked() {
    // 需要清理,设置了maxLifetime或maxIdleTime 并且 有打开的连接,并且没有cleaner在清理
	if (db.maxLifetime > 0 || db.maxIdleTime > 0) && db.numOpen > 0 && db.cleanerCh == nil {
		db.cleanerCh = make(chan struct{}, 1)
		go db.connectionCleaner(db.shortestIdleTimeLocked())
	}
}

connectionCleaner中每隔一定时间扫描一下空闲连接,找到哪些需要被关闭,然后依次关闭
如果已经没有打开的连接了,就会退出协程

func (db *DB) connectionCleaner(d time.Duration) {
	const minInterval = time.Second

	if d < minInterval {
		d = minInterval
	}
	t := time.NewTimer(d)

	for {
		select {
		case <-t.C:
        case <-db.cleanerCh: // maxLifetime was changed or db was closed.
		}

		db.mu.Lock()

		d = db.shortestIdleTimeLocked()
        // 没有打开的连接,就会退出协程
		if db.closed || db.numOpen == 0 || d <= 0 {
			db.cleanerCh = nil
			db.mu.Unlock()
			return
		}

        // 拿到要close的连接
		d, closing := db.connectionCleanerRunLocked(d)
		db.mu.Unlock()
        // 挨个close
		for _, c := range closing {
			c.Close()
		}

		if d < minInterval {
			d = minInterval
		}

		if !t.Stop() {
			select {
			case <-t.C:
			default:
			}
		}
		t.Reset(d)
	}
}

connectionCleanerRunLocked:扫描空闲连接哪些需要被关闭

  1. 找到所有超过最大idleTime的conn

    1. 由于空闲连接slice采用的策略是:每次获取连接是从尾部获取,每次放回连接时往尾部放。因此从后往前遍历,只要遇到超过最大idleTime的conn就可以停止,前面的都是满足条件的
  2. 找到所有超过最大lifeTime的conn

func (db *DB) connectionCleanerRunLocked(d time.Duration) (time.Duration, []*driverConn) {
	var idleClosing int64
	var closing []*driverConn

    // 找到所有超过最大空闲时间的conn
	if db.maxIdleTime > 0 {
		// db.freeConn 已经按照returnAt从小到大排序
        // idleSince :returnAt小于这个值的都算过期(idleTime超过阈值了)
		idleSince := nowFunc().Add(-db.maxIdleTime)
		last := len(db.freeConn) - 1
        // 
		for i := last; i >= 0; i-- {
			c := db.freeConn[i]
			if c.returnedAt.Before(idleSince) {
                // 此时的i:最后一个过期的位置
                // 将i之前的删除
				i++
				closing = db.freeConn[:i:i]
				db.freeConn = db.freeConn[i:]
				idleClosing = int64(len(closing))
				db.maxIdleTimeClosed += idleClosing
				break
			}
		}

		// ...
	}

    // 找到所有超过最大lifeTime的conn
	if db.maxLifetime > 0 {
        // 小于这个值的都算过期(超过了最大时间)
		expiredSince := nowFunc().Add(-db.maxLifetime)
		for i := 0; i < len(db.freeConn); i++ {
			c := db.freeConn[i]
            // 过期了
			if c.createdAt.Before(expiredSince) {
                // 加到待close列表
				closing = append(closing, c)

				last := len(db.freeConn) - 1
				// 删除i位置的连接
				copy(db.freeConn[i:], db.freeConn[i+1:])
				db.freeConn[last] = nil
				db.freeConn = db.freeConn[:last]
				i--
            } 

            // ...
		}
		db.maxLifetimeClosed += int64(len(closing)) - idleClosing
	}

	return d, closing
}

总结

sql包将对具体db的操作委托为接口实现,自己完成与db无关的部分,包括:

  • 在driver包定义一系列抽象的接口,交给各个db的驱动实现
  • 提供连接池以及相关操作,包括获取连接,释放连接,过期连接清理

参考

Golang sql 标准库源码解析


http://www.kler.cn/a/428819.html

相关文章:

  • 无人机高速无刷动力电机核心设计技术
  • 基于Python的多元医疗知识图谱构建与应用研究(上)
  • 在 Vue 3 项目中集成和使用 vue3-video-play
  • flutter 装饰类【BoxDecoration】
  • docker 基础语法学习,K8s基础语法学习,零基础学习
  • Pytorch|YOLO
  • ShardingSphere 数据库中间件
  • k8s 为什么需要Pod?
  • 高级java每日一道面试题-2024年12月05日-JVM篇-什么是TLAB?
  • 计算机键盘的演变 | 键盘键名称及其功能 | 键盘指法
  • 软件无线电安全之GNU Radio基础(下)
  • 英文论文翻译成中文,怎样翻译更地道?
  • 【开源免费】基于Vue和SpringBoot的高校学科竞赛平台(附论文)
  • 普通算法——一维前缀和
  • k8s-Informer概要解析(2)
  • Mybatis-plus 多租户插件
  • 如何使用Apache HttpClient来执行GET、POST、PUT和DELETE请求
  • 【JAVA】Java高级:数据库监控与调优:SQL调优与执行计划的分析
  • MySQL(四)--索引
  • QNX系统的编译过程
  • 【uniapp】swiper切换时,v-for重新渲染页面导致文字在视觉上的拉扯问题
  • 40分钟学 Go 语言高并发:【实战】分布式缓存系统
  • Go学习:变量
  • 重生之我在21世纪学C++—关系、条件、逻辑操作符
  • 第三部分:进阶概念 7.数组与对象 --[JavaScript 新手村:开启编程之旅的第一步]
  • 猜数字的趣味小游戏——rand函数、srand函数、time函数的使用