当前位置: 首页 > article >正文

自定义数据集使用框架的线性回归方法对其进行拟合

代码如下:

"""01
PyTorch 线性回归实现
"""
import torch
import numpy as np
from matplotlib import pyplot as plt

# 1.散点输入
# 1、散点输入
# 定义输入数据
data = [[-0.5, 7.7], [1.8, 98.5], [0.9, 57.8], [0.4, 39.2], [-1.4, -15.7], [-1.4, -37.3], [-1.8, -49.1], [1.5, 75.6], [0.4, 34.0], [0.8, 62.3]]
# 转换为 NumPy 数组
data = np.array(data)
# 提取 x_data 和 y_data
x_data = data[:, 0]
y_data = data[:, 1]

#将x_data 和y_data 转化成tensor
x_train=torch.tensor(x_data,dtype=torch.float32)
y_train=torch.tensor(y_data,dtype=torch.float32)
#导入类库 Dataloader TensorDataset
from torch.utils.data import DataLoader,TensorDataset
# TensorDataset 主要用于封装张量,将输入张量和目标张量组成一个数据集
#返回值能够按照索引获得数据和标签,例如(x_train[i],y_train[i])
dataset=TensorDataset(x_train,y_train)
#DataLoader,数据加载器结合了数据集和采样器,并为给定的数据集提供可迭代性
#返回值dataloader是可迭代的对象,每次迭代生成一个批次的数据,由输入张量和目标张量组成的元组组成
#batch_size 表示一次加载到内存多少个数据,
# shuffle 就表示打乱
dataloader=DataLoader(dataset,batch_size=2,shuffle=True)
#如果batch_size=3 长度为10 怎么划分一下? 111 111 111 1


# 2.定义前向模型
import torch.nn as nn
#定义损失
criterion=nn.MSELoss()

# 方案4
# 最常用的网络结构
# 直接重写继承nn.Module
class LinearModel(nn.Module):
    #初始化
    def __init__(self):
        super(LinearModel,self).__init__()
        #定义一个nn.ModuleList
        self.layers=nn.Linear(1,1)
    #前向传播
    def forward(self, x):
        x=self.layers(x)
        return x
#初始化一下模型,返回模型对象
model=LinearModel()

#优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

#4.开始迭代
epoches =500
for n in range(1,epoches+1):
    epoch_loss=0
    #以前都是所有数据一块训练,现在是按照批次进行训练
    #现在x_train 相当于10个样本,但是现在维度,添加一个维度
    #10x1   变成样本 x 维度形式
    y_prd=model(x_train.unsqueeze(1))
    #计算损失
    #y_prd在前面,y_true 是后面
    batch_loss=criterion(y_prd.squeeze(1),y_train)
    #梯度更新
    #清空之前存储在优化器中的梯度
    optimizer.zero_grad()
    #损失函数对模型参数的梯度
    batch_loss.backward()
    #根据优化算法更新参数
    optimizer.step()
    #计算一下epoch的损失
    epoch_loss=epoch_loss+batch_loss

        # 5、显示频率设置

    #计算一下epoch的平均损失
    avg_loss=epoch_loss/(len(dataloader))
    # 不先画图
    if n % 10 == 0 or n == 1:
        print(f"epoches:{n},loss:{avg_loss}")
        plt.clf()
        # 绘制原始数据点,使用蓝色表示
        plt.scatter(x_data, y_data, color='blue')

        # 绘制当前预测的回归线,使用红色表示
        plt.plot(x_data, y_prd.detach().numpy(), color='red')  # detach() 是为了防止从计算图中分离,避免梯度计算

        # 设置图表的 x 和 y 轴标签
        plt.xlabel('X')
        plt.ylabel('Y')

        # 设置图表的标题,显示当前 epoch 数
        plt.title(f'Epoch {n}')

        # 暂停 0.1 秒,实时更新图表
        plt.pause(0.1)

    # 训练完成后,显示最终的图表
plt.show()

结果展示:


http://www.kler.cn/a/519927.html

相关文章:

  • JJJ:Linux - 高精度定时器 hrtimer
  • Go语言开发项目文件规范
  • Windows本地部署(DeepSeek-R1-Distill-Qwen-1.5B)模型
  • 定时器按键tim_key模版
  • 可以称之为“yyds”的物联网开源框架有哪几个?
  • 写一个存储“网站”的网站前的分析
  • 第30章 测试驱动开发中的设计模式解析(Python 版)
  • 三年级数学知识边界总结思考-下册
  • GSI快速收录服务:让你的网站内容“上架”谷歌
  • 从Spring请求处理到分层架构与IOC:注解详解与演进实战
  • MYSQL数据库 - 启动与连接
  • 入门 Canvas:Web 绘图的强大工具
  • C#,入门教程(05)——Visual Studio 2022源程序(源代码)自动排版的功能动画图示
  • rust学习-rust中的格式化打印
  • 深度解读:近端策略优化算法(PPO)
  • 浅谈在AI时代GIS的发展方向和建议
  • Elasticsearch 性能测试工具 Loadgen 之 004——高级用法示例
  • c语言函数(详解)
  • Vue.js 高级组件开发
  • 任务一:Android逆向
  • 泷羽Sec-Powershell3
  • 设计模式思想的元规则
  • Python从0到100(八十五):神经网络与迁移学习在猫狗分类中的应用
  • 数据结构day02
  • go安全项目汇总表
  • 神经网络|(三)线性回归基础知识