当前位置: 首页 > article >正文

【机器学习】过拟合与欠拟合

在机器学习模型的训练过程中,过拟合欠拟合是两个常见问题。过拟合会导致模型在训练集上表现优异,但在新数据上表现不佳;而欠拟合则使得模型在训练集上也无法充分拟合数据,表现较差。本教程中,我们将从基础概念出发,通过梯度下降算法手写模型拟合过程,逐步实现正则化以防止过拟合,并使用 numpy 进行实现,最后展示 sklearn 的简便实现。

过拟合与欠拟合的基本概念

  1. 欠拟合:当模型太简单,难以捕捉数据的模式和趋势,表现为模型在训练集和测试集上都不能很好地拟合数据。

  2. 过拟合:当模型过于复杂,以至于模型将训练数据的噪声也学习到了,导致模型虽然在训练集上表现很好,但在测试集上性能变差。

数据准备

为直观展现过拟合与欠拟合,我们生成一个带有噪声的非线性数据集,并拟合不同复杂度的多项式模型,观察误差的变化。

import numpy as np
import matplotlib.pyplot as plt

# 生成非线性数据
np.random.seed(0)
X = 2 * np.random.rand(100, 1) - 1  # X 在 [-1, 1]
y = 3 * X**3 + X**2 + 0.5 * X + np.random.randn(100, 1) * 0.1  # y = 3x^3 + x^2 + 0.5x + noise

# 数据可视化
plt.scatter(X, y, color="blue", label="Data")
plt.xlabel("X")
plt.ylabel("y")
plt.title("Non-linear Data for Demonstration")
plt.legend()
plt.show()

欠拟合与过拟合的多项式模型拟合

在这里,我们从简单的一次多项式(线性模型)开始,逐渐增加模型复杂度,并使用梯度下降进行参数拟合。

欠拟合示例

我们用线性模型拟合上述数据,由于数据具有非线性趋势,线性模型无法有效拟合数据,表现为欠拟合。

# 构建特征矩阵(一次多项式)
X_b = np.hstack([np.ones((X.shape[0], 1)), X])  # 增加偏置项

梯度下降算法实现

定义一个梯度下降函数,用于通过不断迭代找到最佳的模型参数。

# 定义梯度下降
def gradient_descent(X, y, lr=0.1, epochs=1000):
    m = X.shape[0]
    theta = np.random.randn(X.shape[1], 1)  # 初始化参数
    for epoch in range(epochs):
        gradients = 2 / m * X.T.dot(X.dot(theta) - y)
        theta = theta - lr * gradients
    return theta

使用梯度下降进行拟合

# 使用梯度下降进行欠拟合(线性模型)
theta_underfit = gradient_descent(X_b, y, lr=0.1, epochs=1000)
y_pred_underfit = X_b.dot(theta_underfit)

# 可视化
plt.scatter(X, y, color="blue", label="Data")
plt.plot(X, y_pred_underfit, color="red", label="Underfitting Model (Degree=1)")
plt.xlabel("X")
plt.ylabel("y")
plt.legend()
plt.title("Underfitting with Linear Model (Degree=1)")
plt.show()

过拟合示例

为了展示过拟合现象,我们将使用10次多项式,计算复杂模型在数据上的表现。

# 生成10次多项式的特征矩阵
X_poly = np.hstack([X**i for i in range(11)])

# 使用梯度下降拟合10次多项式模型
theta_overfit = gradient_descent(X_poly, y, lr=0.01, epochs=10000)
y_pred_overfit = X_poly.dot(theta_overfit)

# 可视化过拟合情况
plt.scatter(X, y, color="blue", label="Data")
plt.plot(np.sort(X, axis=0), np.sort(y_pred_overfit, axis=0), color="red", label="Overfitting Model (Degree=10)")
plt.xlabel("X")
plt.ylabel("y")
plt.legend()
plt.title("Overfitting with Polynomial Model (Degree=10)")
plt.show()

正则化:L2 正则化(Ridge 回归)

L2 正则化(Ridge 回归):L2 正则化会在损失函数中增加一个权重平方和的惩罚项,从而降低模型的复杂度,使模型在训练数据中的表现更加平滑。
L o s s = MSE + λ ∑ i = 1 n w i 2 Loss = \text{MSE} + \lambda \sum_{i=1}^n w_i^2 Loss=MSE+λi=1nwi2
其中 (\lambda) 是正则化强度,越大则约束越强。

带 L2 正则化的梯度下降实现

# 带 L2 正则化的梯度下降
def gradient_descent_ridge(X, y, lr=0.01, epochs=1000, alpha=0.1):
    m = X.shape[0]
    theta = np.random.randn(X.shape[1], 1)
    for epoch in range(epochs):
        gradients = 2 / m * X.T.dot(X.dot(theta) - y) + 2 * alpha * theta
        theta = theta - lr * gradients
    return theta

使用 L2 正则化进行拟合

# 使用正则化进行模型拟合
theta_ridge = gradient_descent_ridge(X_poly, y, lr=0.01, epochs=10000, alpha=0.1)
y_pred_ridge = X_poly.dot(theta_ridge)

# 可视化
plt.scatter(X, y, color="blue", label="Data")
plt.plot(np.sort(X, axis=0), np.sort(y_pred_ridge, axis=0), color="red", label="Ridge Regularization")
plt.xlabel("X")
plt.ylabel("y")
plt.legend()
plt.title("Regularization with Ridge (L2)")
plt.show()

正则化:L1 正则化(Lasso 回归)

L1 正则化(Lasso 回归):L1 正则化会在损失函数中增加一个权重绝对值和的惩罚项,增加模型的稀疏性,可以有效降低特征数量。
L o s s = MSE + λ ∑ i = 1 n ∣ w i ∣ Loss = \text{MSE} + \lambda \sum_{i=1}^n |w_i| Loss=MSE+λi=1nwi

带 L1 正则化的梯度下降实现

L1 正则化具有非连续性,我们使用近似法进行梯度计算。

# 带 L1 正则化的梯度下降
def gradient_descent_lasso(X, y, lr=0.01, epochs=1000, alpha=0.1):
    m = X.shape[0]
    theta = np.random.randn(X.shape[1], 1)
    for epoch in range(epochs):
        gradients = 2 / m * X.T.dot(X.dot(theta) - y) + alpha * np.sign(theta)
        theta = theta - lr * gradients
    return theta

使用 L1 正则化进行拟合

# 使用 L1 正则化进行模型拟合
theta_lasso = gradient_descent_lasso(X_poly, y, lr=0.01, epochs=10000, alpha=0.1)
y_pred_lasso = X_poly.dot(theta_lasso)

# 可视化
plt.scatter(X, y, color="blue", label="Data")
plt.plot(np.sort(X, axis=0), np.sort(y_pred_lasso, axis=0), color="red", label="Lasso Regularization")
plt.xlabel("X")
plt.ylabel("y")
plt.legend()
plt.title("Regularization with Lasso (L1)")
plt.show()

使用 sklearn 实现 Ridge 和 Lasso 回归

通过 sklearnRidgeLasso 可以快速实现上述正则化方法。

from sklearn.linear_model import Ridge, Lasso
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline

# Ridge 正则化
ridge_model = make_pipeline(PolynomialFeatures(degree=10), Ridge(alpha=0.1))
ridge_model.fit(X, y)
y_pred_ridge_sklearn = ridge_model.predict(X)

# Lasso 正则化
lasso_model = make_pipeline(PolynomialFeatures(degree=10), Lasso(alpha=0.1, max_iter=10000))
lasso_model.fit(X, y)
y_pred_lasso_sklearn = lasso_model.predict(X)

# 可视化
plt.scatter(X, y, color="blue", label="Data")
plt.plot(np.sort(X, axis=0), np.sort(y_pred_ridge_sklearn, axis=0), color="green", label="Ridge (Sklearn)")
plt.plot(np.sort(X, axis=0), np.sort(y_pred_lasso_sklearn, axis=0), color="purple", label="Lasso (Sklearn)")
plt.xlabel("X")
plt.ylabel("y")
plt.legend()
plt.title("Regularization with Ridge and Lasso (Sklearn)")
plt.show()

总结

本文通过手写梯度下降算法,从基础的线性模型到多项式模型,详细演示了欠拟合与过拟合现象,并通过L2和L1正则化的方法有效缓解了过拟合问题。通过逐步推导和手写实现,我们深入理解了模型复杂度控制、损失函数、正则化以及梯度下降算法在模型拟合中的重要性。

  1. 欠拟合:主要由模型简单或数据复杂度较低造成。增加模型复杂度或数据量可以有效缓解欠拟合。

  2. 过拟合:通常是模型复杂度过高所致,通过正则化、减少特征数量或增加数据量可以减轻过拟合。

  3. 正则化:L2(Ridge)和L1(Lasso)正则化分别通过约束参数值的大小和稀疏化模型参数有效降低了模型的复杂性,是应对过拟合的重要手段。L2正则化倾向于缩小参数值而保留所有特征,L1正则化则倾向于筛选出对模型最有用的特征。

  4. 梯度下降:手写梯度下降不仅强化了对损失函数、参数更新过程的理解,也帮助我们深入理解了正则化在损失函数中的表现形式。

通过这些基本方法和工具,我们能够更好地控制模型在复杂数据集上的表现,确保它们在泛化能力和拟合能力之间达到良好平衡。

本教程的完整代码和推导内容为构建机器学习基础提供了扎实的理解和工具。希望这篇详细教程能为大家在处理模型拟合问题时提供深入见解和实用方法!


http://www.kler.cn/a/371667.html

相关文章:

  • 功能测试:方法、流程与工具介绍
  • 音视频入门基础:AAC专题(11)——AudioSpecificConfig简介
  • Hexo提交部署命令与Git Bash Here控制终端中按下Ctrl+C无法中断hexo s的解决办法
  • 基于SSM+小程序的智慧旅游平台登录管理系统(旅游2)
  • 代码随想录算法训练营第十一天(补) 栈与队列| 后序表达式、滑动窗口、高频元素、链表总结
  • 使用 Kafka 和 MinIO 实现人工智能数据工作流
  • 用哈希表封装unordered_map与unordered_set
  • sklearn机器学习实战
  • C++ 二叉树进阶:相关习题解析
  • C#实现与Windows服务的交互与控制
  • flinksql-Queries查询相关实战
  • 算法篇——动态规划最终篇 (js版)
  • uniapp position: fixed 兼容性不显示问题
  • Python Flask 数据库开发
  • Modbus TCP报文协议(ModbusTCP)
  • H5底部输入框点击弹起来的时候被软键盘遮挡bug
  • QT编译报错:-1: error: cannot find -lGL
  • 淘宝商品评价API的获取与应用
  • Prometheus自定义PostgreSQL监控指标
  • 直接删除Github上的文件
  • [flask] flask-mail邮件发送
  • 论区块链技术及应用
  • 网络安全领域推荐职位
  • Data+AI下的数据飞轮:如何重塑企业增长
  • SpringBoot 解析@Value注解型解析注入时机以及原理
  • GPT-4V 是什么?