当前位置: 首页 > article >正文

《动手深度学习》线性回归简洁实现实例

本篇文章是《动手深度学习》线性回归简洁实现实例的实现和分析,主要对代码进行详细讲解,有问题欢迎在评论区讨论交流。

代码实现

实现代码如下所示。

import torch
from torch.utils import data
# d2l包是李沐老师等人开发的动手深度学习配套的包,
# 里面封装了很多有关与数据集定义,数据预处理,优化损失函数的包
from d2l import torch as d2l
# nn 是神经网络 Neural Network 的缩写,提供了一系列的模块和类,实现创建、训练、保存、恢复神经网络
from torch import nn
 
'''
1. 生成数据集,共 1000 条
true_w 和 true_b 是临时变量用于生成数据集
生成 X, y :满足关系 y = Xw + b + noise
'''
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)
 
'''
2. 构造循环读取数据集的迭代器
'''
def load_array(data_arrays, batch_size, is_train=True):  #@save
    # 构造一个 PyTorch 数据迭代器,对 tensor 进行打包,包装成 dataset。
    dataset = data.TensorDataset(*data_arrays)
    # 根据数据集构造一个迭代器
    return data.DataLoader(dataset, batch_size, shuffle=is_train)
 
# 小批量数据
batch_size = 10
# 设置了一个数据读取的迭代器,每次读取 batch_size(10) 条
data_iter = load_array((features, labels), batch_size)
 
'''
3. 设置全连接层
'''
'''
# nn.Linear(in_features, out_features, bias=True)
# in_features : 输入向量的列数
# out_features : 输出向量的列数
# bias = True 是否包含偏置
执行线性变换:Yn*o = Xn*i Wi*o + b
其中:W 和 b 模型需要学习的参数
在本例中:n = 10,i = 2, o = 1
'''
net = nn.Sequential(nn.Linear(2, 1))
# 设置权重 w 和 偏置 b
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)
 
'''
4. 定义损失函数
'''
# 均方误差,是预测值与真实值之差的平方和的平均值
loss = nn.MSELoss()
# lr 学习率 learning rate
trainer = torch.optim.SGD(net.parameters(), lr=0.03)
 
'''
4. 训练数据
'''
# 超参数 设置批次
num_epochs = 3
for epoch in range(num_epochs): # 进行 num_epochs 个迭代周期
    for X, y in data_iter:
        l = loss(net(X) ,y) # 计算损失,net(X) 计算预测值 y1,loss(y1, y) 计算预测值和真实值之间的差距
        trainer.zero_grad() # 将所有模型参数的梯度置为 0
        l.backward() # 求梯度,不使用从零实现中 l.sum.backward 的原因是损失计算中使用了平均的 gard
        trainer.step() # 优化参数 w 和 b
    l = loss(net(features), labels)
    print(f'epoch {epoch + 1}, loss {l:f}')
 
w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)


http://www.kler.cn/a/313352.html

相关文章:

  • 国家网络安全法律法规
  • 鸿蒙HarmonyOS 地图不显示解决方案
  • HTTP 客户端怎么向 Spring Cloud Sleuth 传输跟踪 ID
  • SHELL脚本(Linux)
  • [ 网络安全介绍 5 ] 为什么要学习网络安全?
  • mongoDB的安装及使用
  • 【Webpack--013】SourceMap源码映射设置
  • windows环境下配置MySQL主从启动失败 查看data文件夹中.err发现报错unknown variable ‘log‐bin=mysql‐bin‘
  • 使用vite+react+ts+Ant Design开发后台管理项目(二)
  • SpringBoot:关于Redis的配置失效(版本问题)
  • 6. Python 输出长方形,直角三角形,等腰三角形
  • 【Linux基础IO】深入Linux文件描述符与重定向:解锁高效IO操作的秘密
  • 解决“Windows系统中以管理员身份运行程序时无法访问映射的网络磁盘”的问题
  • C# WPF如何实现数据共享
  • C#使用实体类Entity Framework Core操作mysql入门:从数据库反向生成模型2 处理连接字符串
  • 2024年上海小学生古诗文大会倒计时一个月:做2024官方模拟题
  • 人家90年代就尝试过的模式:我们所热衷的“数科公司”
  • 基于spring的ssm整合
  • 航空航司reese84逆向
  • linux文件同步、传输
  • 数据结构不再难懂:带你轻松搞定图
  • linux-L6 linux管理服务的启动、重启、停止、重载、查看状态命令
  • EmguCV学习笔记 VB.Net 12.3 OCR
  • OpenAI GPT o1技术报告阅读(4)- 填字游戏推理
  • 【Git 操作】Git 的基本操作
  • Elasticsearch:检索增强生成背后的重要思想