40分钟学 Go 语言高并发:Context包与并发控制
Context包与并发控制
学习目标
知识点 | 掌握程度 | 应用场景 |
---|---|---|
context原理 | 深入理解实现机制 | 并发控制和请求链路追踪 |
超时控制 | 掌握超时设置和处理 | API请求超时、任务限时控制 |
取消信号传播 | 理解取消机制和传播链 | 优雅退出、资源释放 |
context最佳实践 | 掌握使用规范和技巧 | 工程实践中的常见场景 |
1. Context原理
1.1 Context基本结构和实现
让我们先看一个完整的Context使用示例:
package main
import (
"context"
"fmt"
"log"
"time"
)
// 请求追踪信息
type RequestInfo struct {
TraceID string
SessionID string
StartTime time.Time
}
// 服务接口
type Service interface {
HandleRequest(ctx context.Context, req string) (string, error)
}
// 业务服务实现
type BusinessService struct {
name string
}
func NewBusinessService(name string) *BusinessService {
return &BusinessService{name: name}
}
// 处理请求
func (s *BusinessService) HandleRequest(ctx context.Context, req string) (string, error) {
// 获取请求追踪信息
info, ok := ctx.Value("request-info").(*RequestInfo)
if !ok {
return "", fmt.Errorf("request info not found in context")
}
log.Printf("[%s] Processing request: %s, TraceID: %s, Session: %s\n",
s.name, req, info.TraceID, info.SessionID)
// 模拟处理过程
select {
case <-time.After(2 * time.Second):
return fmt.Sprintf("Result for %s", req), nil
case <-ctx.Done():
return "", ctx.Err()
}
}
// 请求中间件
func requestMiddleware(next Service) Service {
return &middlewareService{next: next}
}
type middlewareService struct {
next Service
}
func (m *middlewareService) HandleRequest(ctx context.Context, req string) (string, error) {
// 开始时间
startTime := time.Now()
// 添加请求信息到context
info := &RequestInfo{
TraceID: fmt.Sprintf("trace-%d", time.Now().UnixNano()),
SessionID: fmt.Sprintf("session-%d", time.Now().Unix()),
StartTime: startTime,
}
ctx = context.WithValue(ctx, "request-info", info)
// 调用下一个处理器
result, err := m.next.HandleRequest(ctx, req)
// 记录处理时间
duration := time.Since(startTime)
log.Printf("Request completed in %v, TraceID: %s\n", duration, info.TraceID)
return result, err
}
func main() {
// 创建服务
service := requestMiddleware(NewBusinessService("UserService"))
// 创建基础context
ctx := context.Background()
// 添加超时控制
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
// 处理请求
result, err := service.HandleRequest(ctx, "get user profile")
if err != nil {
log.Printf("Request failed: %v\n", err)
return
}
log.Printf("Request succeeded: %s\n", result)
// 模拟超时场景
ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
result, err = service.HandleRequest(ctx, "get user settings")
if err != nil {
log.Printf("Request failed: %v\n", err)
return
}
}
2. 超时控制
让我们实现一个带有超时控制的HTTP服务:
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"time"
)
// 响应结构
type Response struct {
Data interface{} `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// 服务配置
type ServiceConfig struct {
Timeout time.Duration
MaxConcurrent int
RetryAttempts int
RetryDelay time.Duration
}
// HTTP客户端包装器
type HTTPClient struct {
client *http.Client
config ServiceConfig
limiter chan struct{} // 并发限制器
}
// 创建新的HTTP客户端
func NewHTTPClient(config ServiceConfig) *HTTPClient {
return &HTTPClient{
client: &http.Client{
Timeout: config.Timeout,
},
config: config,
limiter: make(chan struct{}, config.MaxConcurrent),
}
}
// 发送HTTP请求
func (c *HTTPClient) DoRequest(ctx context.Context, method, url string) (*Response, error) {
var lastErr error
for attempt := 0; attempt <= c.config.RetryAttempts; attempt++ {
select {
case <-ctx.Done():
return nil, ctx.Err()
case c.limiter <- struct{}{}: // 获取并发许可
}
// 确保释放并发许可
defer func() {
<-c.limiter
}()
// 创建请求
req, err := http.NewRequestWithContext(ctx, method, url, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
// 设置请求超时
reqCtx, cancel := context.WithTimeout(ctx, c.config.Timeout)
defer cancel()
// 执行请求
resp, err := c.client.Do(req.WithContext(reqCtx))
if err != nil {
lastErr = err
log.Printf("Request failed (attempt %d): %v\n", attempt+1, err)
// 如果不是最后一次尝试,等待后重试
if attempt < c.config.RetryAttempts {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(c.config.RetryDelay):
continue
}
}
continue
}
defer resp.Body.Close()
// 解析响应
var result Response
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("decode response failed: %w", err)
}
return &result, nil
}
return nil, fmt.Errorf("all retry attempts failed, last error: %v", lastErr)
}
// 处理HTTP请求的处理器
func handleRequest(w http.ResponseWriter, r *http.Request) {
// 创建context
ctx := r.Context()
// 模拟长时间处理
select {
case <-time.After(2 * time.Second):
response := Response{
Data: "Request processed successfully",
}
json.NewEncoder(w).Encode(response)
case <-ctx.Done():
response := Response{
Error: "Request timeout",
}
w.WriteHeader(http.StatusGatewayTimeout)
json.NewEncoder(w).Encode(response)
}
}
func main() {
// 配置HTTP客户端
config := ServiceConfig{
Timeout: 5 * time.Second,
MaxConcurrent: 10,
RetryAttempts: 3,
RetryDelay: time.Second,
}
client := NewHTTPClient(config)
// 创建HTTP服务器
http.HandleFunc("/api", handleRequest)
// 启动服务器
go func() {
log.Println("Server starting on :8080")
if err := http.ListenAndServe(":8080", nil); err != nil {
log.Fatal(err)
}
}()
// 等待服务器启动
time.Sleep(time.Second)
// 测试请求
ctx := context.Background()
// 测试正常请求
resp, err := client.DoRequest(ctx, "GET", "http://localhost:8080/api")
if err != nil {
log.Printf("Request failed: %v\n", err)
} else {
log.Printf("Response: %+v\n", resp)
}
// 测试超时请求
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
resp, err = client.DoRequest(ctx, "GET", "http://localhost:8080/api")
if err != nil {
log.Printf("Request failed (expected): %v\n", err)
} else {
log.Printf("Response: %+v\n", resp)
}
// 保持主程序运行
select {}
}
package main
import (
"context"
"fmt"
"math/rand"
"sync"
"time"
)
// 请求处理器
type RequestHandler struct {
requests chan Request
responses chan Response
done chan struct{}
wg sync.WaitGroup
}
// 请求结构
type Request struct {
ID int
Timeout time.Duration
Data string
}
// 响应结构
type Response struct {
RequestID int
Result string
Error error
}
// 创建新的请求处理器
func NewRequestHandler() *RequestHandler {
return &RequestHandler{
requests: make(chan Request, 100),
responses: make(chan Response, 100),
done: make(chan struct{}),
}
}
// 启动处理器
func (h *RequestHandler) Start(workers int) {
for i := 0; i < workers; i++ {
h.wg.Add(1)
go h.worker(i)
}
}
// 工作协程
func (h *RequestHandler) worker(id int) {
defer h.wg.Done()
for {
select {
case req, ok := <-h.requests:
if !ok {
fmt.Printf("Worker %d: request channel closed\n", id)
return
}
// 创建context用于超时控制
ctx, cancel := context.WithTimeout(context.Background(), req.Timeout)
// 处理请求
response := h.processRequest(ctx, req)
// 发送响应
select {
case h.responses <- response:
fmt.Printf("Worker %d: sent response for request %d\n",
id, req.ID)
case <-h.done:
cancel()
return
}
cancel() // 清理context
case <-h.done:
fmt.Printf("Worker %d: received stop signal\n", id)
return
}
}
}
// 处理单个请求
func (h *RequestHandler) processRequest(ctx context.Context, req Request) Response {
// 模拟处理时间
processTime := time.Duration(rand.Intn(int(req.Timeout))) + req.Timeout/2
select {
case <-time.After(processTime):
return Response{
RequestID: req.ID,
Result: fmt.Sprintf("Processed: %s", req.Data),
}
case <-ctx.Done():
return Response{
RequestID: req.ID,
Error: ctx.Err(),
}
}
}
// 提交请求
func (h *RequestHandler) SubmitRequest(req Request) error {
select {
case h.requests <- req:
return nil
case <-h.done:
return fmt.Errorf("handler is stopped")
}
}
// 获取响应
func (h *RequestHandler) GetResponse() (Response, error) {
select {
case resp := <-h.responses:
return resp, nil
case <-h.done:
return Response{}, fmt.Errorf("handler is stopped")
}
}
// 停止处理器
func (h *RequestHandler) Stop() {
close(h.done)
h.wg.Wait()
close(h.requests)
close(h.responses)
}
func main() {
// 创建请求处理器
handler := NewRequestHandler()
handler.Start(3)
// 发送一些测试请求
requests := []Request{
{ID: 1, Timeout: time.Second, Data: "Fast request"},
{ID: 2, Timeout: time.Second * 2, Data: "Normal request"},
{ID: 3, Timeout: time.Millisecond * 500, Data: "Quick request"},
{ID: 4, Timeout: time.Second * 3, Data: "Slow request"},
}
// 提交请求
for _, req := range requests {
if err := handler.SubmitRequest(req); err != nil {
fmt.Printf("Failed to submit request %d: %v\n", req.ID, err)
continue
}
fmt.Printf("Submitted request %d\n", req.ID)
}
// 收集响应
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < len(requests); i++ {
resp, err := handler.GetResponse()
if err != nil {
fmt.Printf("Failed to get response: %v\n", err)
continue
}
if resp.Error != nil {
fmt.Printf("Request %d failed: %v\n", resp.RequestID, resp.Error)
} else {
fmt.Printf("Request %d succeeded: %s\n", resp.RequestID, resp.Result)
}
}
}()
// 等待所有响应处理完成
wg.Wait()
// 停止处理器
handler.Stop()
fmt.Println("Main: processing completed")
}
3. 取消信号传播
让我们继续完成取消信号传播的示例代码:
package main
import (
"context"
"fmt"
"log"
"sync"
"time"
)
// 任务定义
type Task struct {
ID string
Process func(ctx context.Context) (interface{}, error)
}
// 工作池
type WorkerPool struct {
workers int
tasks chan Task
results chan interface{}
errors chan error
done chan struct{}
wg sync.WaitGroup
}
// 创建新的工作池
func NewWorkerPool(workers int) *WorkerPool {
return &WorkerPool{
workers: workers,
tasks: make(chan Task, workers*2),
results: make(chan interface{}, workers*2),
errors: make(chan error, workers*2),
done: make(chan struct{}),
}
}
// 启动工作池
func (p *WorkerPool) Start(ctx context.Context) {
// 启动workers
for i := 0; i < p.workers; i++ {
p.wg.Add(1)
go p.worker(ctx, i)
}
// 等待所有worker完成
go func() {
p.wg.Wait()
close(p.done)
close(p.results)
close(p.errors)
}()
}
// worker处理任务
func (p *WorkerPool) worker(ctx context.Context, id int) {
defer p.wg.Done()
log.Printf("Worker %d started\n", id)
for {
select {
case <-ctx.Done():
log.Printf("Worker %d stopped: %v\n", id, ctx.Err())
return
case task, ok := <-p.tasks:
if !ok {
log.Printf("Worker %d: task channel closed\n", id)
return
}
log.Printf("Worker %d processing task %s\n", id, task.ID)
// 创建任务专用的context
taskCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
// 执行任务
result, err := task.Process(taskCtx)
cancel() // 释放任务context资源
if err != nil {
select {
case p.errors <- fmt.Errorf("task %s failed: %w", task.ID, err):
case <-ctx.Done():
return
}
} else {
select {
case p.results <- result:
case <-ctx.Done():
return
}
}
}
}
}
// 提交任务
func (p *WorkerPool) Submit(task Task) error {
select {
case p.tasks <- task:
return nil
case <-p.done:
return fmt.Errorf("worker pool is closed")
}
}
// 关闭工作池
func (p *WorkerPool) Close() {
close(p.tasks)
}
// 获取结果通道
func (p *WorkerPool) Results() <-chan interface{} {
return p.results
}
// 获取错误通道
func (p *WorkerPool) Errors() <-chan error {
return p.errors
}
func main() {
// 创建根context
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// 创建工作池
pool := NewWorkerPool(3)
pool.Start(ctx)
// 创建模拟任务
tasks := []Task{
{
ID: "task-1",
Process: func(ctx context.Context) (interface{}, error) {
select {
case <-time.After(2 * time.Second):
return "Task 1 completed", nil
case <-ctx.Done():
return nil, ctx.Err()
}
},
},
{
ID: "task-2",
Process: func(ctx context.Context) (interface{}, error) {
select {
case <-time.After(3 * time.Second):
return "Task 2 completed", nil
case <-ctx.Done():
return nil, ctx.Err()
}
},
},
{
ID: "task-3",
Process: func(ctx context.Context) (interface{}, error) {
select {
case <-time.After(1 * time.Second):
return nil, fmt.Errorf("task 3 failed")
case <-ctx.Done():
return nil, ctx.Err()
}
},
},
}
// 提交任务
for _, task := range tasks {
if err := pool.Submit(task); err != nil {
log.Printf("Failed to submit task %s: %v\n", task.ID, err)
}
}
// 等待3秒后取消所有任务
go func() {
time.Sleep(3 * time.Second)
log.Println("Cancelling all tasks...")
cancel()
}()
// 收集结果和错误
completed := 0
expected := len(tasks)
for completed < expected {
select {
case result, ok := <-pool.Results():
if !ok {
continue
}
log.Printf("Got result: %v\n", result)
completed++
case err, ok := <-pool.Errors():
if !ok {
continue
}
log.Printf("Got error: %v\n", err)
completed++
case <-ctx.Done():
log.Printf("Main: context cancelled: %v\n", ctx.Err())
completed = expected // 强制退出循环
}
}
// 关闭工作池
pool.Close()
// 等待工作池完全关闭
<-pool.done
log.Println("All workers stopped")
}
3.1 取消信号传播流程图
4. Context最佳实践
4.1 Context使用规范
- 函数调用链传递
// 推荐
func HandleRequest(ctx context.Context, req *Request) error
// 不推荐
func HandleRequest(timeout time.Duration, req *Request) error
- Context应作为第一个参数
// 推荐
func ProcessTask(ctx context.Context, task *Task) error
// 不推荐
func ProcessTask(task *Task, ctx context.Context) error
- 不要储存Context在结构体中
// 不推荐
type Service struct {
ctx context.Context
}
// 推荐
type Service struct {
// 其他字段
}
func (s *Service) DoWork(ctx context.Context) error
4.2 Context使用注意事项
- 不要将nil传递给context参数
// 推荐
ctx := context.Background()
ProcessTask(ctx, task)
// 不推荐
ProcessTask(nil, task)
- context.Value应该只用于请求作用域数据
// 推荐
ctx = context.WithValue(ctx, "request-id", requestID)
// 不推荐 - 配置信息应该通过其他方式传递
ctx = context.WithValue(ctx, "db-config", dbConfig)
- 正确处理取消信号
select {
case <-ctx.Done():
return ctx.Err()
default:
// 继续处理
}
4.3 实践建议
- 超时控制
- 设置合理的超时时间
- 在不同层级使用不同的超时时间
- 确保资源正确释放
- 错误处理
- 区分超时和取消错误
- 传递有意义的错误信息
- 实现优雅降级
- 性能优化
- 避免创建过多的context
- 合理使用context.Value
- 及时取消不需要的操作
- 日志追踪
- 记录关键操作的耗时
- 追踪请求的完整链路
- 记录取消原因
总结
关键点回顾
- Context原理
- 继承关系
- 值传递机制
- 生命周期管理
- 超时控制
- 设置超时时间
- 处理超时信号
- 资源清理
- 取消信号传播
- 信号传递机制
- 取消处理流程
- 资源释放
- 最佳实践
- 使用规范
- 注意事项
- 优化建议
实践建议
- 代码规范
- 遵循命名约定
- 合理组织代码结构
- 添加必要的注释
- 错误处理
- 使用有意义的错误信息
- 实现错误恢复机制
- 记录错误日志
- 性能优化
- 减少不必要的context创建
- 避免context.Value滥用
- 及时释放资源
怎么样今天的内容还满意吗?再次感谢观众老爷的观看,关注GZH:凡人的AI工具箱,回复666,送您价值199的AI大礼包。最后,祝您早日实现财务自由,还请给个赞,谢谢!