当前位置: 首页 > article >正文

使用 PyTorch 实现线性回归:从零开始的完整指南

在机器学习中,线性回归是最基础且广泛使用的算法之一。它通过拟合数据点之间的线性关系,帮助我们理解和预测变量之间的关系。本文将通过一个简单的例子,展示如何使用 PyTorch 框架实现线性回归,并对自定义数据集进行拟合。

1. 线性回归简介

线性回归的目标是找到一个线性方程 y=wx+b,其中 w 是斜率,b 是截距,使得该方程能够尽可能地拟合给定的数据点。在实际应用中,我们通常使用最小二乘法来最小化预测值与真实值之间的误差。

2. 准备数据

首先,我们需要准备一个简单的数据集。在这个例子中,我们将使用一个包含 10 个数据点的自定义数据集:

data = [
    [-0.5, 7.7],
    [1.8, 98.5],
    [0.9, 57.8],
    [0.4, 39.2],
    [-1.4, -15.7],
    [-1.4, -37.3],
    [-1.8, -49.1],
    [1.5, 75.6],
    [0.4, 34.0],
    [0.8, 62.3]
]

这些数据点表示输入特征 x 和目标变量 y 之间的关系。我们将使用 PyTorch 的张量(Tensor)来存储和处理这些数据。

3. 构建线性回归模型

接下来,我们需要定义一个线性回归模型。在 PyTorch 中,可以通过继承 nn.Module 来定义一个自定义模型。我们将使用一个简单的线性层来实现这个模型:

class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(1, 1)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

这个模型包含一个线性层,其输入维度为 1,输出维度也为 1,正好符合我们的问题需求。

4. 定义损失函数和优化器

为了训练模型,我们需要定义一个损失函数和一个优化器。在这里,我们使用均方误差(MSE)作为损失函数,使用随机梯度下降(SGD)作为优化器:

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

5. 训练模型

现在,我们可以开始训练模型了。我们将数据集输入模型,计算损失,并通过反向传播更新模型参数。以下是完整的训练代码:

epochs = 500
for n in range(1, epochs + 1):
    y_pred = model(x_train.unsqueeze(1))
    loss = criterion(y_pred.squeeze(1), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if n % 10 == 0 or n == 1:
        print(f"Epoch: {n}, Loss: {loss.item():.4f}")

在每个 epoch 中,我们计算模型的预测值,计算损失,并通过 loss.backward() 计算梯度,最后通过 optimizer.step() 更新模型参数。

6. 可视化结果

训练完成后,我们可以通过绘制原始数据点和拟合的直线来直观地展示模型的效果。以下是完整的可视化代码:

plt.rcParams['font.sans-serif'] = ['SimHei']  # 指定中文字体为黑体
plt.rcParams['axes.unicode_minus'] = False  # 正确显示负号


# 绘制原始数据点
plt.scatter(x_data, y_data, color='blue', label='原始数据')

# 绘制拟合的直线
slope = model.layers[0].weight.item()
intercept = model.layers[0].bias.item()
x_fit = np.linspace(x_data.min(), x_data.max(), 100)
y_fit = slope * x_fit + intercept
plt.plot(x_fit, y_fit, color='red', label='拟合直线')

# 添加图例和标签
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.title('线性回归拟合结果')
plt.show()

运行上述代码后,你将看到如下图像:

从图中可以看出,拟合的直线能够较好地反映数据点之间的线性关系。

7. 总结

通过本文的介绍,你已经学会了如何使用 PyTorch 实现线性回归,并对自定义数据集进行拟合。线性回归虽然简单,但在许多实际问题中都非常有效。希望这篇文章能够帮助你更好地理解和应用线性回归模型。


代码完整版

以下是完整的代码,供你参考和使用:

import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt

# 设置 matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']  # 指定中文字体为黑体
plt.rcParams['axes.unicode_minus'] = False  # 正确显示负号

# 定义输入数据
data = [
    [-0.5, 7.7],
    [1.8, 98.5],
    [0.9, 57.8],
    [0.4, 39.2],
    [-1.4, -15.7],
    [-1.4, -37.3],
    [-1.8, -49.1],
    [1.5, 75.6],
    [0.4, 34.0],
    [0.8, 62.3]
]

# 转换为 NumPy 数组
data = np.array(data)
# 提取 x_data 和 y_data
x_data = data[:, 0]
y_data = data[:, 1]

# 将 x_data 和 y_data 转化成 tensor
x_train = torch.tensor(x_data, dtype=torch.float32)
y_train = torch.tensor(y_data, dtype=torch.float32)

# 定义损失函数
criterion = nn.MSELoss()

# 定义线性回归模型
class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(1, 1)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = LinearModel()

# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练模型
epochs = 500
for n in range(1, epochs + 1):
    y_pred = model(x_train.unsqueeze(1))
    loss = criterion(y_pred.squeeze(1), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if n % 10 == 0 or n == 1:
        print(f"Epoch: {n}, Loss: {loss.item():.4f}")

# 绘制图像
# 绘制原始数据点
plt.scatter(x_data, y_data, color='blue', label='原始数据')

# 绘制拟合的直线
slope = model.layers[0].weight.item()
intercept = model.layers[0].bias.item()
x_fit = np.linspace(x_data.min(), x_data.max(), 100)
y_fit = slope * x_fit + intercept
plt.plot(x_fit, y_fit, color='red', label='拟合直线')

# 添加图例和标签
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.title('线性回归拟合结果')
plt.show()


http://www.kler.cn/a/523369.html

相关文章:

  • 图漾相机——C++语言属性设置
  • csapp2.4节——浮点数
  • [MySQL]事务的理论、属性与常见操作
  • 云计算技术深度解析与代码使用案例
  • 把markdown转换为pdf的方法
  • stack 和 queue容器的介绍和使用
  • Ubuntu 18.04安装Emacs 26.2问题解决
  • 大一计算机的自学总结:位运算的应用及位图
  • 在做题中学习(82):最小覆盖子串
  • Vue 响应式渲染 - 待办事项简单实现
  • 案例研究丨浪潮云洲通过DataEase推进多维度数据可视化建设
  • 图神经网络驱动的节点分类:从理论到实践
  • 神经网络和深度学习
  • DeepSeek-R1本地部署笔记
  • Zookeeper(31)Zookeeper的事务ID(zxid)是什么?
  • 集群建模、空地协同,无人机高效救灾技术详解
  • 【Elasticsearch】_rollover API详解
  • Linux 阻塞IO
  • Spring Security(maven项目) 3.0.2.9版本
  • 【Rust自学】16.2. 使用消息传递来跨线程传递数据
  • 苹果AI最新动态:Siri改造和AI模型优化成2025年首要任务
  • 记录 | 基于Docker Desktop的MaxKB安装
  • 从 Web3 游戏融资热看行业未来发展趋势
  • C语言实现统计数组正负元素相关数据
  • Leecode刷题C语言之跳跃游戏②
  • 【信息系统项目管理师-选择真题】2008上半年综合知识答案和详解