代码如下:
"""01
PyTorch 线性回归实现
"""
import torch
import numpy as np
from matplotlib import pyplot as plt
# 1.散点输入
# 1、散点输入
# 定义输入数据
data = [[-0.5, 7.7], [1.8, 98.5], [0.9, 57.8], [0.4, 39.2], [-1.4, -15.7], [-1.4, -37.3], [-1.8, -49.1], [1.5, 75.6], [0.4, 34.0], [0.8, 62.3]]
# 转换为 NumPy 数组
data = np.array(data)
# 提取 x_data 和 y_data
x_data = data[:, 0]
y_data = data[:, 1]
#将x_data 和y_data 转化成tensor
x_train=torch.tensor(x_data,dtype=torch.float32)
y_train=torch.tensor(y_data,dtype=torch.float32)
#导入类库 Dataloader TensorDataset
from torch.utils.data import DataLoader,TensorDataset
# TensorDataset 主要用于封装张量,将输入张量和目标张量组成一个数据集
#返回值能够按照索引获得数据和标签,例如(x_train[i],y_train[i])
dataset=TensorDataset(x_train,y_train)
#DataLoader,数据加载器结合了数据集和采样器,并为给定的数据集提供可迭代性
#返回值dataloader是可迭代的对象,每次迭代生成一个批次的数据,由输入张量和目标张量组成的元组组成
#batch_size 表示一次加载到内存多少个数据,
# shuffle 就表示打乱
dataloader=DataLoader(dataset,batch_size=2,shuffle=True)
#如果batch_size=3 长度为10 怎么划分一下? 111 111 111 1
# 2.定义前向模型
import torch.nn as nn
#定义损失
criterion=nn.MSELoss()
# 方案4
# 最常用的网络结构
# 直接重写继承nn.Module
class LinearModel(nn.Module):
#初始化
def __init__(self):
super(LinearModel,self).__init__()
#定义一个nn.ModuleList
self.layers=nn.Linear(1,1)
#前向传播
def forward(self, x):
x=self.layers(x)
return x
#初始化一下模型,返回模型对象
model=LinearModel()
#优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
#4.开始迭代
epoches =500
for n in range(1,epoches+1):
epoch_loss=0
#以前都是所有数据一块训练,现在是按照批次进行训练
#现在x_train 相当于10个样本,但是现在维度,添加一个维度
#10x1 变成样本 x 维度形式
y_prd=model(x_train.unsqueeze(1))
#计算损失
#y_prd在前面,y_true 是后面
batch_loss=criterion(y_prd.squeeze(1),y_train)
#梯度更新
#清空之前存储在优化器中的梯度
optimizer.zero_grad()
#损失函数对模型参数的梯度
batch_loss.backward()
#根据优化算法更新参数
optimizer.step()
#计算一下epoch的损失
epoch_loss=epoch_loss+batch_loss
# 5、显示频率设置
#计算一下epoch的平均损失
avg_loss=epoch_loss/(len(dataloader))
# 不先画图
if n % 10 == 0 or n == 1:
print(f"epoches:{n},loss:{avg_loss}")
plt.clf()
# 绘制原始数据点,使用蓝色表示
plt.scatter(x_data, y_data, color='blue')
# 绘制当前预测的回归线,使用红色表示
plt.plot(x_data, y_prd.detach().numpy(), color='red') # detach() 是为了防止从计算图中分离,避免梯度计算
# 设置图表的 x 和 y 轴标签
plt.xlabel('X')
plt.ylabel('Y')
# 设置图表的标题,显示当前 epoch 数
plt.title(f'Epoch {n}')
# 暂停 0.1 秒,实时更新图表
plt.pause(0.1)
# 训练完成后,显示最终的图表
plt.show()
结果展示: