当前位置: 首页 > article >正文

机器学习day3

自定义数据集使用框架的线性回归方法对其进行拟合

import matplotlib.pyplot as plt
import torch
import numpy as np
# 1.散点输入
# 1、散点输入
# 定义输入数据
data = [[-0.5, 7.7], [1.8, 98.5], [0.9, 57.8], [0.4, 39.2], [-1.4, -15.7], [-1.4, -37.3], [-1.8, -49.1], [1.5, 75.6], [0.4, 34.0], [0.8, 62.3]]
# 转换为 NumPy 数组
data = np.array(data)
# 提取 x_data 和 y_data
x_data = data[:, 0]
y_data = data[:, 1]

#将x_data 和y_data 转化成tensor
x_train=torch.tensor(x_data,dtype=torch.float32)
y_train=torch.tensor(y_data,dtype=torch.float32)
print(x_train)
# 2.定义前向模型
import torch.nn as nn
#定义损失
criterion=nn.MSELoss()

# 方案4
# 最常用的网络结构
# 直接重写继承nn.Module
class LinearModel(nn.Module):
    #初始化
    def __init__(self):
        super(LinearModel,self).__init__()
        #定义一个nn.ModuleList
        self.layers=nn.Linear(1,1)
    #前向传播
    def forward(self, x):
        x=self.layers(x)
        return x
#初始化一下模型,返回模型对象
model=LinearModel()

#优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

#4.开始迭代
epoches =500
for n in range(1,epoches+1):
    #现在x_train 相当于10个样本,但是现在维度,添加一个维度
    #10x1   变成样本 x 维度形式
    y_prd=model(x_train.unsqueeze(1))
    #计算损失
    #y_prd在前面,y_true 是后面
    loss=criterion(y_prd.squeeze(1),y_train)
    #梯度更新
    #清空之前存储在优化器中的梯度
    optimizer.zero_grad()
    #损失函数对模型参数的梯度
    loss.backward()
    #根据优化算法更新参数
    optimizer.step()
    # 5、显示频率设置

    if n % 10 == 0 or n == 1:
        print(f"epoches:{n},loss:{loss}")

x_np = x_train.numpy()
y_np = y_train.numpy()
y_pred_np = model(x_train.unsqueeze(1)).detach().numpy()

# 绘制数据点和拟合直线
plt.scatter(x_np, y_np, color='blue', label='Data Points')  # 原始数据点
plt.plot(x_np, y_pred_np, color='red', label='Fitted Line')  # 拟合直线
plt.xlabel('x')
plt.ylabel('y')
plt.title('Linear Regression Fit (PyTorch)')
plt.legend()
plt.show()

结果展示


http://www.kler.cn/a/518972.html

相关文章:

  • 【数据结构】深入解析:构建父子节点树形数据结构并返回前端
  • 基于微信小程序的英语学习交流平台设计与实现(LW+源码+讲解)
  • 跟我学C++中级篇——容器的连接
  • Spring整合Mybatis、junit纯注解
  • Android中Service在新进程中的启动流程
  • PSD是什么图像格式?如何把PSD转为JPG格式?
  • 本地Ubuntu轻松部署高效性能监控平台SigNoz与远程使用教程
  • 71.在 Vue 3 中使用 OpenLayers 实现按住 Shift 拖拽、旋转和缩放效果
  • Linux MySQL离线安装
  • 《深入解析:DOS检测的技术原理与方法》
  • 9.business english-agreement
  • WPS添加文本超简单,批量操作不费劲-Excel易用宝
  • 基于Flask框架和Hive数仓的农业数据分析系统
  • 【论文阅读】RT-SKETCH: GOAL-CONDITIONED IMITATION LEARNING FROM HAND-DRAWN SKETCHES
  • 设计模式的艺术-中介者模式
  • 深度学习 Pytorch 单层神经网络
  • React Native 0.77 发布:更强的样式支持与性能优化
  • 图像处理算法研究的程序框架
  • 将本地项目上传到 GitLab/GitHub
  • 基于 Bash 脚本的系统信息定时收集方案
  • 机器学习-使用梯度下降最小化均方误差
  • 数据库的JOIN连接查询算法
  • BUUCTF_Web( XSS COURSE 1)xss
  • golang 编程规范 - Effective Go 中文
  • 【C】memory 详解
  • 了解网络编程