当前位置: 首页 > article >正文

动手学深度学习-3.2 线性回归的从0开始

以下是代码的逐段解析及其实际作用:


1. 环境设置与库导入

%matplotlib inline
import random
import torch
from d2l import torch as d2l
  • 作用
    • %matplotlib inline:在 Jupyter Notebook 中内嵌显示 matplotlib 图形。
    • random:生成随机索引用于数据打乱。
    • torch:PyTorch 深度学习框架。
    • d2l:《动手学深度学习》提供的工具函数库(如绘图工具)。

2. 生成合成数据

假设真实权重向量为 w true ∈ R n \mathbf{w}_{\text{true}} \in \mathbb{R}^n wtrueRn,偏置为 b true b_{\text{true}} btrue,噪声为高斯分布 ϵ ∼ N ( 0 , σ 2 ) \epsilon \sim \mathcal{N}(0, \sigma^2) ϵN(0,σ2),则合成数据生成公式为:
y = X w true + b true + ϵ \mathbf{y} = \mathbf{X} \mathbf{w}_{\text{true}} + b_{\text{true}} + \epsilon y=Xwtrue+btrue+ϵ
其中:

  • X ∈ R m × n \mathbf{X} \in \mathbb{R}^{m \times n} XRm×n:输入特征矩阵( m m m 个样本, n n n 个特征)。
  • w true ∈ R n \mathbf{w}_{\text{true}} \in \mathbb{R}^n wtrueRn:真实权重向量。
  • ϵ ∈ R m \epsilon \in \mathbb{R}^m ϵRm:噪声向量。
def synthetic_data(w, b, num_examples):  #@save
    """生成y=Xw+b+噪声"""
    X = torch.normal(0, 1, (num_examples, len(w)))  # 生成标准正态分布的输入特征 num_examples行,len(w)列
    y = torch.matmul(X, w) + b                      # 计算线性输出 y = Xw + b
    y += torch.normal(0, 0.01, y.shape)             # 添加高斯噪声
    return X, y.reshape((-1, 1))                    # y行数不定(值为-1,列数为1)

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

生成的函数是一个二维线性回归模型,其数学表达式为:

y = w 1 x 1 + w 2 x 2 + b + ϵ y = w_1 x_1 + w_2 x_2 + b + \epsilon y=w1x1+w2x2+b+ϵ

其中:

  • 权重 w = [ w 1 , w 2 ] = [ 2 , − 3.4 ] \mathbf{w} = [w_1, w_2] = [2, -3.4] w=[w1,w2]=[2,3.4],由 true_w 定义。
  • 偏置 b = 4.2 b = 4.2 b=4.2,由 true_b 定义。
  • 噪声 ϵ ∼ N ( 0 , 0.0 1 2 ) \epsilon \sim \mathcal{N}(0, 0.01^2) ϵN(0,0.012),即均值为 0、标准差为 0.01 的高斯噪声。

展开为标量形式:
y i = 2 ⋅ x i 1 − 3.4 ⋅ x i 2 + 4.2 + ϵ i ( i = 1 , 2 , … , 1000 ) y_i = 2 \cdot x_{i1} - 3.4 \cdot x_{i2} + 4.2 + \epsilon_i \quad (i = 1, 2, \dots, 1000) yi=2xi13.4xi2+4.2+ϵi(i=1,2,,1000)


3. 数据可视化

d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1);
  • 绘制第二个特征(features[:,1] => n行第1列)与标签 labels 的散点图。

4. 定义数据迭代器

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)  # 打乱索引顺序
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]  # 生成小批量数据
  • 作用
    • 将数据集按 batch_size 划分为小批量,并随机打乱顺序。
    • 使用生成器 (yield) 逐批返回数据,避免一次性加载全部数据到内存。

5. 初始化模型参数

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
  • 初始化w和b的值
    • w:从均值为 0、标准差为 0.01 的正态分布中初始化权重,启用梯度追踪。
    • b:初始化为 0 的偏置,启用梯度追踪。
    • 参数需梯度追踪以支持反向传播。

6. 定义模型、损失函数和优化器

def linreg(X, w, b):  #@save
    """线性回归模型"""
    return torch.matmul(X, w) + b

def squared_loss(y_hat, y):  #@save
    """均方损失"""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2  # 除以2便于梯度计算

def sgd(params, lr, batch_size):  #@save
    """小批量随机梯度下降"""
    with torch.no_grad():  # 禁用梯度计算
        for param in params:
            param -= lr * param.grad / batch_size  # 参数更新
            param.grad.zero_()                     # 梯度清零
  • linreg:模型预测值 y ^ \hat{\mathbf{y}} y^ 的矩阵形式为:
    y ^ = X w + b \hat{\mathbf{y}} = \mathbf{X} \mathbf{w} + b y^=Xw+b
    其中:

    • w ∈ R n \mathbf{w} \in \mathbb{R}^n wRn:待学习的权重向量。
    • b ∈ R b \in \mathbb{R} bR:待学习的偏置。
  • squared_loss:损失函数的矩阵形式为:
    L = 1 2 ∥ y ^ − y ∥ 2 L = \frac{1}{2} \| \hat{\mathbf{y}} - \mathbf{y} \|^2 L=21y^y2

    L ( w , b ) = 1 2 m ∥ X w + b − y ∥ 2 L(\mathbf{w}, b) = \frac{1}{2m} \| \mathbf{X} \mathbf{w} + b - \mathbf{y} \|^2 L(w,b)=2m1Xw+by2
    展开后:
    L ( w , b ) = 1 2 m ( X w + b 1 − y ) ⊤ ( X w + b 1 − y ) L(\mathbf{w}, b) = \frac{1}{2m} (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y})^\top (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y}) L(w,b)=2m1(Xw+b1y)(Xw+b1y)

  • sgd:小批量随机梯度下降优化器,

    • 对权重 w \mathbf{w} w 的梯度
      ∇ w L = 1 m X ⊤ ( X w + b 1 − y ) \nabla_{\mathbf{w}} L = \frac{1}{m} \mathbf{X}^\top (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y}) wL=m1X(Xw+b1y)

    • 对偏置 b b b 的梯度
      ∇ b L = 1 m 1 ⊤ ( X w + b 1 − y ) , 1 为单位列向量 \nabla_{b} L = \frac{1}{m} \mathbf{1}^\top (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y}),\mathbf{1} 为单位列向量 bL=m11(Xw+b1y)1为单位列向量

    • 使用学习率 η \eta η,参数更新公式为:
      w ← w − η ∇ w L b ← b − η ∇ b L \mathbf{w} \leftarrow \mathbf{w} - \eta \nabla_{\mathbf{w}} L\\ b \leftarrow b - \eta \nabla_{b} L wwηwLbbηbL


7. 训练循环

lr = 0.03
num_epochs = 3
batch_size = 10  # 需补充定义(原代码未显式定义)

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)  # 计算小批量损失
        l.sum().backward()         # 反向传播计算梯度
        sgd([w, b], lr, batch_size) # 更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
  • 作用

    • 外层循环:遍历训练轮次 (num_epochs)。
    • 内层循环:按小批量遍历数据,计算损失并反向传播。
    • l.sum().backward():将小批量损失求和后反向传播,计算梯度。
    • sgd:根据梯度更新参数,梯度需除以 batch_size 以保持学习率一致性。
    • 每个 epoch 结束后,计算并打印整体训练损失。
    • mean()函数计算平均值
  • 梯度下降

  l.sum().backward()  # 反向传播计算梯度
  sgd([w, b], lr, batch_size)  # 更新参数
  • 小批量梯度计算公式:
    ∇ w L batch = 1 batch_size X batch ⊤ ( X batch w + b − y batch ) \nabla_{\mathbf{w}} L_{\text{batch}} = \frac{1}{\text{batch\_size}} \mathbf{X}_{\text{batch}}^\top (\mathbf{X}_{\text{batch}} \mathbf{w} + b - \mathbf{y}_{\text{batch}}) wLbatch=batch_size1Xbatch(Xbatchw+bybatch)
    ∇ b L batch = 1 batch_size 1 ⊤ ( X batch w + b − y batch ) \nabla_{b} L_{\text{batch}} = \frac{1}{\text{batch\_size}} \mathbf{1}^\top (\mathbf{X}_{\text{batch}} \mathbf{w} + b - \mathbf{y}_{\text{batch}}) bLbatch=batch_size11(Xbatchw+bybatch)

http://www.kler.cn/a/530932.html

相关文章:

  • 寒假刷题Day20
  • 基于STM32景区环境监测系统的设计与实现(论文+源码)
  • CSS Fonts(字体)
  • Ubuntu 下 nginx-1.24.0 源码分析 main函数 — ngx_cdecl 宏
  • 使用朴素贝叶斯对自定义数据集进行分类
  • 贪吃蛇实现
  • 【数据结构-Trie树】力扣648. 单词替换
  • Kafka流式计算架构
  • Linux——进程间通信之SystemV共享内存
  • (回溯递归dfs 电话号码的字母组合 remake)leetcode 17
  • OpenCV4.8 开发实战系列专栏之 30 - OpenCV中的自定义滤波器
  • html中的列表元素
  • 全域旅游景区导览系统小程序独立部署
  • 使用冒泡排序模拟实现qsort函数
  • NeetCode刷题第20天(2025.2.1)
  • STM32 DMA+AD多通道
  • 音标-- 02-- 重音 音节 变音
  • 2024美团秋招硬件开发笔试真题及答案解析
  • Day33【AI思考】-分层递进式结构 对数学数系的 终极系统分类
  • C++:结构体和类
  • 刷题记录 动态规划-5: 62. 不同路径
  • python的pre-commit库的使用
  • leetcode——从前序与中序遍历序列构造二叉树(java)
  • stm32小白成长为高手的学习步骤和方法
  • NOTEPAD++编写abap
  • 国土安全保障利器,高速巡飞无人机技术详解