当前位置: 首页 > article >正文

Lucas带你手撕机器学习——线性回归

什么是线性回归

线性回归是机器学习中的基础算法之一,用于预测一个连续的输出值。它假设输入特征与输出值之间的关系是线性关系,即目标变量是输入变量的线性组合。我们可以从代码实现的角度来学习线性回归,包括如何使用 Python 进行简单的线性回归模型构建、训练、和预测。

线性回归的直观理解

你可以把线性回归理解成“画一条线来预测未来”。假设你有一张散点图,每个点代表某个物品的重量和它的价格。你的目标是找到一条直线,能够尽可能准确地描述这些点之间的关系。

线性回归的工作原理

假设我们有一些数据点,每个点都有一个输入(如重量)和一个输出(如价格)。线性回归就是在这些点之间找到一条直线,使得这条线能够“最好”地描述这些数据点。

这条直线的公式是:

在这里插入图片描述

其中:

  • y:输出,即我们想要预测的值(例如,物品的价格)
  • x:输入特征(例如,物品的重量)
  • w:线的斜率,表示重量对价格的影响有多大
  • b:截距,表示当重量为 0 时,预测的价格是多少

线性回归的基本原理

线性回归的数学公式为:

在这里插入图片描述

其中:

  • y 是预测值(目标变量)
  • x1,x2,…,xn 是输入特征
  • w1,w2,…,wn 是特征对应的权重(回归系数)
  • b 是偏置项(截距)

如何找到“最好的”直线?

“最好的”直线是指那些经过这条直线的点尽可能接近数据点。为了衡量直线的好坏,我们需要一个方法来计算直线与数据点之间的差距。

误差的概念
  • 对于每个数据点,我们可以计算它的实际价格(真实值)和用这条直线预测出来的价格之间的差距,称为“误差”。
  • 比如说,某个物品的真实价格是 10 元,但通过直线预测出来的价格是 9 元,那么这个点的误差就是 10−9=1。
均方误差(Mean Squared Error,MSE)

为了让误差的计算更稳定,我们通常不直接使用误差,而是使用“均方误差”来衡量模型的好坏:

在这里插入图片描述

其中:

  • yi:第 i 个样本的真实值
  • yi^:第 i 个样本通过模型预测的值
  • N:样本数量

均方误差的作用就是将所有数据点的误差平方后取平均值,这样可以确保误差不会因为正负抵消。我们的目标是让这个均方误差尽可能小,意味着直线与数据点之间的差距最小。

训练模型

在实际训练过程中,我们会不断调整直线的斜率 w 和截距 b,直到找到使均方误差最小的那一组 w 和 b。这就意味着找到了“最好的”直线。

代码实现

使用 Scikit-Learn 实现****线性回归

我们可以使用 Scikit-Learn 库,它提供了非常简洁的接口来进行线性回归。下面是一个完整的示例代码:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

# 生成一些模拟数据
np.random.seed(42)
X = 2 * np.random.rand(100, 1)  # 输入特征,100 个样本,1 个特征
y = 4 + 3 * X + np.random.randn(100, 1)  # 线性关系 y = 4 + 3x + 噪声

# 拆分数据为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建线性回归模型并进行训练
model = LinearRegression()
model.fit(X_train, y_train)

# 输出模型的系数和截距
print(f'权重(w): {model.coef_[0][0]}')
print(f'截距(b): {model.intercept_[0]}')

# 预测并计算均方误差
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f'测试集上的 MSE: {mse}')

# 可视化结果
plt.scatter(X_test, y_test, color='blue', label='真实值')
plt.plot(X_test, y_pred, color='red', label='预测值', linewidth=2)
plt.xlabel('X')
plt.ylabel('y')
plt.legend()
plt.title('线性回归拟合结果')
plt.show()

在这里插入图片描述

  1. 代码解释
  • 生成模拟数据: 生成了一些随机数据点 X和 y,其中 y=4 + 3X + 噪声,这样我们就有一个线性关系的示例数据。
  • 数据集拆分: 使用 train_test_split 将数据集拆分成训练集和测试集,80% 用于训练,20% 用于测试。
  • 训练模型: 使用 LinearRegression 类创建模型,并用训练集数据拟合模型。
  • 预测和评估: 使用测试集进行预测,计算预测值与真实值之间的均方误差(MSE)。
  • 结果可视化: 将真实值和预测结果在图中可视化,可以清楚地看到线性回归的拟合效果。

PyTorch 实现线性回归

为了更好地理解线性回归的原理,我们也可以使用 PyTorch 从头实现一个简单的线性回归模型:

import torch
import torch.nn as nn
import torch.optim as optim

# 生成模拟数据
torch.manual_seed(42)
X = torch.randn(100, 1) * 2
y = 4 + 3 * X + torch.randn(100, 1)

# 定义线性模型
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(1, 1)  # 输入 1 维,输出 1 维

    def forward(self, x):
        return self.linear(x)

# 创建模型、损失函数和优化器
model = LinearRegressionModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
epochs = 1000
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}')

# 输出训练好的模型参数
[w, b] = model.parameters()
print(f'权重(w): {w.item()}')
print(f'截距(b): {b.item()}')

代码解释

  • 定义模型: 使用 nn.Module 定义了一个简单的线性模型,只包含一个线性层。
  • 定义损失函数和优化器: 选择均方误差作为损失函数(nn.MSELoss()),使用随机梯度下降(optim.SGD)优化模型。
  • 模型训练: 通过前向传播计算损失,通过反向传播计算梯度并更新模型参数。

总结

以上两种方法分别使用 Scikit-Learn 和 PyTorch 实现了线性回归模型。Scikit-Learn 的方式适合快速建模和测试,而 PyTorch 版本则更灵活,更适合理解深度学习模型的训练过程。掌握这些方法后,可以将它们应用于更复杂的模型和任务中。

感谢阅读!!我是正在澳洲深造的Lucas!!
在这里插入图片描述


http://www.kler.cn/news/357030.html

相关文章:

  • 记录Visio导出图片的文字与latex中文字大小一致的问题,和visio导出适用于论文的高清图片问题
  • Java项目-基于Springboot的应急救援物资管理系统项目(源码+说明).zip
  • 虾​皮​一​面​-​2
  • 数学归纳法——第一数学归纳法、第二数学归纳法步骤和示例
  • SpringBoot中的RedisTemplate对象中的setIfAbsent()方法有什么作用?
  • Mapbox GL 加载GeoServer底图服务器的WMS source
  • 开源的存储引擎--cantian
  • js 字符串与数组的操作
  • python【装饰器】
  • python中_init_.py 到底有啥用?
  • nvm安装,node多版本管理
  • 多级缓存-案例导入说明
  • 自定义多级联动选择器指南(uni-app)
  • Spring Boot实现的电影评论系统开发
  • 开发工具(上)
  • 【数据结构与算法】第2课—数据结构之顺序表
  • 对于从vscode ssh到virtualBox的timeout记录
  • 【JavaScript】LeetCode:76-80
  • 【RestTemplate】重试机制详解
  • 生成式人工智能如何帮助我们更有效地传达信息