当前位置: 首页 > article >正文

【机器学习】线性回归模型

线性回归是机器学习中最基础的算法之一。它主要用于回归任务,即预测一个连续的数值输出。本文将从零开始,带领你构建线性回归模型,逐步推演损失函数、梯度下降、学习率等核心概念,并使用numpy实现。最后,我们会通过sklearn快速实现线性回归模型。

线性回归模型简介

线性回归模型的核心思想是用一个直线(或超平面)拟合一组数据,找到特征和目标变量之间的线性关系。其数学表达式为:

y = w ⋅ x + b y = w \cdot x + b y=wx+b
其中:

  • ( y ) 是预测值(输出),
  • ( w ) 是权重(或斜率),
  • ( x ) 是输入变量(特征),
  • ( b ) 是偏置(截距)。

目标是找到合适的 ( w ) 和 ( b ),使得模型的预测结果尽可能接近真实值。

损失函数

为了衡量模型的预测值和真实值之间的差距,我们使用损失函数。常见的损失函数是均方误差(MSE, Mean Squared Error),其公式如下:

M S E = 1 n ∑ i = 1 n ( y i − y i ^ ) 2 MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y_i})^2 MSE=n1i=1n(yiyi^)2
其中:

  • ( y_i ) 是第 (i) 个样本的真实值,
  • ( \hat{y_i} ) 是模型的预测值,
  • ( n ) 是样本数。

损失函数越小,说明模型越准确。

梯度下降(Gradient Descent)

为了最小化损失函数,我们使用梯度下降算法。梯度下降的基本思想是从随机初始化的参数开始,逐步调整参数,使得损失函数逐渐变小,最终找到最优解。

梯度下降的更新规则
w = w − α ∂ L ∂ w b = b − α ∂ L ∂ b w = w - \alpha \frac{\partial L}{\partial w}\\ b = b - \alpha \frac{\partial L}{\partial b} w=wαwLb=bαbL
其中:

  • ( \alpha ) 是学习率(决定每次更新的步长),
  • ( \frac{\partial L}{\partial w} ) 是损失函数关于 ( w ) 的导数(梯度),
  • ( \frac{\partial L}{\partial b} ) 是损失函数关于 ( b ) 的导数。

学习率(Learning Rate)

学习率 ( \alpha ) 是梯度下降中的重要超参数。它决定了每次参数更新的步长。学习率过大,可能会错过最优解;学习率过小,训练过程会非常缓慢,甚至陷入局部最优解。

代码实现:从零开始构建线性回归模型

接下来,我们使用 numpy 从头实现一个线性回归模型。

数据准备

我们首先构造一组简单的线性数据,用来训练我们的模型。

import numpy as np
import matplotlib.pyplot as plt

# 生成数据
np.random.seed(42)
X = 2 * np.random.rand(100, 1)  # 随机生成 100 个点,范围在 [0, 2]
y = 4 + 3 * X + np.random.randn(100, 1)  # y = 4 + 3x + 噪声

# 可视化数据
plt.scatter(X, y)
plt.xlabel("X")
plt.ylabel("y")
plt.title("Generated Data")
plt.show()

损失函数实现

接下来,我们实现均方误差(MSE)损失函数。

def mse_loss(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)

梯度计算

我们需要计算损失函数对 ( w ) 和 ( b ) 的偏导数:

def compute_gradients(X, y, w, b):
    n = len(y)
    y_pred = X.dot(w) + b
    dw = (2/n) * X.T.dot(y_pred - y)
    db = (2/n) * np.sum(y_pred - y)
    return dw, db

梯度下降算法

使用梯度下降算法更新参数 ( w ) 和 ( b ):

def gradient_descent(X, y, w, b, learning_rate, iterations):
    for i in range(iterations):
        dw, db = compute_gradients(X, y, w, b)
        w -= learning_rate * dw
        b -= learning_rate * db
        
        if i % 100 == 0:
            y_pred = X.dot(w) + b
            loss = mse_loss(y, y_pred)
            print(f"Iteration {i}: Loss = {loss}")
            
    return w, b

模型训练

初始化参数并开始训练:

# 初始化参数
w = np.random.randn(1, 1)
b = np.random.randn(1)

# 超参数设置
learning_rate = 0.1
iterations = 1000

# 训练模型
w_trained, b_trained = gradient_descent(X, y, w, b, learning_rate, iterations)
print(f"Trained weights: {w_trained}, Trained bias: {b_trained}")

模型可视化

训练结束后,我们可以将拟合直线与原始数据进行对比:

# 绘制拟合直线
plt.scatter(X, y)
plt.plot(X, X.dot(w_trained) + b_trained, color='red')
plt.xlabel("X")
plt.ylabel("y")
plt.title("Linear Regression Fit")
plt.show()

使用 sklearn 实现线性回归

最后,我们使用 sklearn 库快速实现同样的线性回归模型。

from sklearn.linear_model import LinearRegression

# 训练模型
lin_reg = LinearRegression()
lin_reg.fit(X, y)

# 输出权重和偏置
print(f"Sklearn Trained weights: {lin_reg.coef_}, Sklearn Trained bias: {lin_reg.intercept_}")

# 绘制拟合直线
plt.scatter(X, y)
plt.plot(X, lin_reg.predict(X), color='red')
plt.xlabel("X")
plt.ylabel("y")
plt.title("Linear Regression with Sklearn")
plt.show()

总结

在本教程中,我们通过 numpy 实现了线性回归模型,深入理解了损失函数、梯度下降和学习率等概念。最后,我们通过 sklearn 验证了结果。希望这篇文章能帮助你打下机器学习的基础,深入理解线性回归背后的原理。


http://www.kler.cn/a/373364.html

相关文章:

  • 计算机视觉与深度学习:使用深度学习训练基于视觉的车辆检测器(MATLAB源码-Faster R-CNN)
  • 阿里云服务器扩容系统盘后宝塔面板不显示扩容后的大小
  • Unity3D实现WEBGL打开Window文件对话框打开/上传文件
  • 从源码角度分析SpringMVC执行流程
  • 如何在Jupyter中快速切换Anaconda里不同的虚拟环境
  • 道旅科技借助云消息队列 Kafka 版加速旅游大数据创新发展
  • Linux系统rpm安装MySQL详细操作步骤
  • 19 Docker容器集群网络架构:二、etcd 集群部署
  • 【Java多线程】8 Java 中的并发设计模式
  • 【K8S系列】Kubernetes 中 NodePort 类型的 Service 无法访问的问题【已解决】
  • MySQL(2)【库的操作】
  • python爬虫案例——使用aiohttp模块异步请求网站,利用协程加快爬取速度(17)
  • 数据可视化工具深入学习:Seaborn 与 Plotly 的详细教程
  • Linux驱动开发(1):环境搭建
  • 工厂方法模式与抽象工厂模式
  • 九泰智库 | 医械周刊- Vol.65 | 广州发布首批创新药械产品目录
  • libavdevice.so.58: cannot open shared object file: No such file ordirectory踩坑
  • XXE漏洞原理、修复建议及绕过方式
  • 蓝牙4.0/5.1/5.2模组UART通讯基础知识
  • 【C++动态规划】有效括号的嵌套深度
  • 【Triton 教程】矩阵乘法
  • 新闻列表以及详情页面梳理
  • DAY66WEB 攻防-Java 安全SPEL 表达式SSTI 模版注入XXEJDBCMyBatis 注入
  • Linux find 匹配文件内容
  • 无损将GPT转换为MBR的GDisk操作指南:
  • 数据结构和算法-动态规划(1)-认识动态规划