当前位置: 首页 > article >正文

【机器学习】多项式回归

多项式回归是回归分析的一种扩展形式,通过增加多项式特征,可以模拟输入特征与输出之间的非线性关系。与线性回归不同,线性回归仅适用于直线拟合,而多项式回归则可以用曲线拟合复杂数据。本教程将系统讲解多项式回归模型,包括模型的推导、数据转换、模型训练、损失函数的定义和梯度下降,并使用numpy实现。最后,我们会通过sklearn实现多项式回归模型。

多项式回归模型简介

多项式回归模型的核心在于扩展输入特征。对于一个特征 (x),我们可以将它转化为多项式特征,例如将输入特征 (x) 转换为二次特征(或更高次),模型表达式如下:

y = w 0 + w 1 x + w 2 x 2 + ⋯ + w n x n + b y = w_0 + w_1 x + w_2 x^2 + \cdots + w_n x^n + b y=w0+w1x+w2x2++wnxn+b
其中:

  • ( y ) 是预测值,
  • ( x ) 是输入特征,
  • ( w_0, w_1, \ldots, w_n ) 是模型的权重,
  • ( b ) 是偏置项。

这个模型是输入特征 (x) 的 n次多项式。多项式回归通过增加特征的次幂,允许模型更好地拟合非线性数据。

数据转换:构建多项式特征

在构建多项式回归模型前,需要将输入特征转换为多项式特征。例如,给定一个特征 ( x ),我们将它扩展为多项式特征 ( [1, x, x^2, \ldots, x^n] )。

我们定义一个函数 poly_features 来生成多项式特征。

import numpy as np

def poly_features(X, degree):
    """
    将输入特征 X 扩展为多项式特征矩阵。
    X : 原始特征 (n_samples, 1)
    degree : 多项式的最高次数
    """
    X_poly = np.ones((X.shape[0], degree + 1))  # 初始化为 1
    for i in range(1, degree + 1):
        X_poly[:, i] = X[:, 0] ** i
    return X_poly

损失函数:均方误差 (Mean Squared Error, MSE)

为了评估模型预测值与真实值的误差,我们使用均方误差作为损失函数:
M S E = 1 n ∑ i = 1 n ( y i − y i ^ ) 2 MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y_i})^2 MSE=n1i=1n(yiyi^)2
其中:

  • ( y_i ) 是第 ( i ) 个样本的真实值,
  • ( \hat{y_i} ) 是模型的预测值,
  • ( n ) 是样本总数。

模型训练:梯度下降优化

多项式回归模型的优化可以通过梯度下降实现。以下是计算梯度的实现:

def compute_gradients(X, y, w, b):
    n = len(y)
    y_pred = X.dot(w) + b
    dw = (2/n) * X.T.dot(y_pred - y)
    db = (2/n) * np.sum(y_pred - y)
    return dw, db

使用梯度下降训练模型

接下来,我们使用梯度下降来优化模型参数:

def gradient_descent(X, y, w, b, learning_rate, iterations):
    for i in range(iterations):
        dw, db = compute_gradients(X, y, w, b)
        w -= learning_rate * dw
        b -= learning_rate * db
        
        if i % 100 == 0:
            y_pred = X.dot(w) + b
            loss = mse_loss(y, y_pred)
            print(f"Iteration {i}: Loss = {loss}")
    return w, b

代码实现:多项式回归模型

我们使用numpy从头实现一个多项式回归模型。

数据准备

我们生成一个非线性数据集,用来训练多项式回归模型。

import matplotlib.pyplot as plt

# 生成非线性数据
np.random.seed(42)
X = 6 * np.random.rand(100, 1) - 3  # X 范围在 [-3, 3]
y = 0.5 * X**2 + X + 2 + np.random.randn(100, 1)  # y = 0.5x^2 + x + 2 + 噪声

# 可视化数据
plt.scatter(X, y)
plt.xlabel("X")
plt.ylabel("y")
plt.title("Generated Non-linear Data")
plt.show()

转换多项式特征

假设我们想训练一个二次多项式回归模型。

# 转换成二次多项式特征
degree = 2
X_poly = poly_features(X, degree)

初始化模型参数并定义损失函数

我们初始化权重和偏置,并定义损失函数:

# 初始化参数
w = np.random.randn(degree + 1, 1)
b = np.random.randn(1)

# 定义均方误差损失函数
def mse_loss(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)

训练模型

设置学习率和迭代次数,进行模型训练。

learning_rate = 0.01
iterations = 1000
w_trained, b_trained = gradient_descent(X_poly, y, w, b, learning_rate, iterations)
print(f"Trained weights: {w_trained}, Trained bias: {b_trained}")

可视化拟合曲线

我们将模型拟合的曲线与原始数据进行对比。

# 生成预测值
X_fit = np.linspace(-3, 3, 100).reshape(100, 1)  # 用于绘制拟合曲线
X_fit_poly = poly_features(X_fit, degree)
y_fit = X_fit_poly.dot(w_trained) + b_trained

# 绘制结果
plt.scatter(X, y, label="Original Data")
plt.plot(X_fit, y_fit, color='red', label="Polynomial Fit")
plt.xlabel("X")
plt.ylabel("y")
plt.title("Polynomial Regression Fit")
plt.legend()
plt.show()

使用sklearn实现多项式回归

最后,我们使用sklearn快速实现多项式回归模型。

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

# 生成多项式特征
poly = PolynomialFeatures(degree=2)
X_poly_sklearn = poly.fit_transform(X)

# 训练模型
lin_reg = LinearRegression()
lin_reg.fit(X_poly_sklearn, y)

# 可视化拟合曲线
y_sklearn_fit = lin_reg.predict(poly.fit_transform(X_fit))
plt.scatter(X, y, label="Original Data")
plt.plot(X_fit, y_sklearn_fit, color='red', label="Sklearn Polynomial Fit")
plt.xlabel("X")
plt.ylabel("y")
plt.title("Polynomial Regression with Sklearn")
plt.legend()
plt.show()

总结

本文通过逐步推演实现了多项式回归模型,深入理解了多项式特征转换、损失函数和梯度下降优化过程。最后,我们通过sklearn验证了模型的实现并进行了可视化展示,希望这篇教程帮助你掌握多项式回归的基本原理与实现。


http://www.kler.cn/a/370649.html

相关文章:

  • 蓝桥杯刷题第二天——背包问题
  • 算法与数据结构——复杂度
  • PDF文件提取开源工具调研总结
  • RPC 源码解析~Apache Dubbo
  • excel仅复制可见单元格,仅复制筛选后内容
  • QT 如何禁止QComboBox鼠标滚轮
  • 实战OpenCV之深度学习
  • <大厂实战场景> ~ flutter鸿蒙next处理后端返回来的数据的转义问题
  • 大数据-186 Elasticsearch - ELK 家族 Logstash Input插件 JDBC syslog
  • SSRF服务端请求伪造
  • Pandas 数据分析基础操作:从创建到统计的实用指南
  • 人工智能与机器学习相关算法介绍
  • 掌握机器学习中的偏差与方差:模型性能的关键
  • DAPT: Distribution-Aware Prompt Tuning for Vision-Language Models
  • 实现梦想:Spring Boot驱动的摄影工作室网站
  • GeoWebCache1.26调用ArcGIS切片
  • 【数据集】2015-2100年8种情景(SSPs-RCP)下中国土地利用数据
  • 命令模式(C++)三分钟读懂
  • 企业如何用WordPress站群布局多个行业站点,轻松覆盖关键词
  • Linux之nfs服务器和dns服务器
  • node升级package.json中的版本
  • pip 和 pipx 的主要区别?
  • Vue笔记-element ui中关于table的前端分页
  • CSS 样式 box-sizing: border-box; 用于控制元素的盒模型如何计算宽度和高度
  • 解决minio跨域问题
  • 【数据结构和算法】三、动态规划原理讲解与实战演练