当前位置: 首页 > article >正文

机器学习与数据挖掘_使用梯度下降法训练线性回归模型

目录

实验内容

实验步骤

1. 导入必要的库

2. 加载数据并绘制散点图

3. 设置模型的超参数

4. 实现梯度下降算法

5. 打印训练后的参数和损失值

6. 绘制损失函数随迭代次数的变化图

7. 绘制线性回归拟合曲线

8. 基于训练好的模型进行新样本预测

实验代码

实验结果

实验总结


实验内容

(1)编写代码实现基于梯度下降的单变量线性回归算法,包括梯度的计算与验证;

(2)绘制数据散点图,以及得到的直线;

(3)绘制梯度下降过程中损失的变化图;

(4)基于训练得到的参数,输入新的样本数据,输出预测值。


实验步骤

1. 导入必要的库

使用 `numpy` 进行科学计算,并使用 `matplotlib` 来生成图形。为了保证图形中的中文正常显示,设置 `matplotlib` 的字体为黑体,并解决负号显示问题。

2. 加载数据并绘制散点图

使用 `numpy` 的 `genfromtxt` 函数从文件中加载数据,数据以逗号作为分隔符。分别提取第一列数据为 `x` 值,第二列数据为 `y` 值,展示数据点的分布情况。使用 `scatter` 函数绘制散点图,并使用 `show` 函数显示图形。

3. 设置模型的超参数

初始化线性回归模型的参数:学习率 `alpha` 设置为 `0.0001`。权重 `w` 和偏置 `b` 初始化为 `0`。设置梯度下降的迭代次数为 `1000`。获取数据样本数量 `m`。

4. 实现梯度下降算法

定义一个列表 `MSE` 用来存储每次迭代的均方误差。在每次迭代中,分别计算损失函数和模型参数的梯度:对每一个样本点,计算当前的预测值和真实值的误差,进而计算平方误差并累积。计算梯度,分别对权重 `w` 和偏置 `b` 进行更新。更新后的参数 `w` 和 `b` 基于学习率和当前梯度值来进行调整。

5. 打印训练后的参数和损失值

在训练结束后,打印出模型的最终参数 `w` 和 `b`。使用最后一次迭代的均方误差来表示最终的损失函数值。

6. 绘制损失函数随迭代次数的变化图

使用 `plot` 函数绘制损失函数随迭代次数变化的曲线,`x` 轴为迭代次数,`y` 轴为损失值。图形展示了梯度下降过程中损失函数值的变化趋势,验证模型的收敛情况。

7. 绘制线性回归拟合曲线

再次绘制原始数据的散点图,并基于训练得到的参数计算每个数据点的预测值。使用 `plot` 函数绘制线性回归拟合的曲线,并用红色标出拟合的直线。

8. 基于训练好的模型进行新样本预测

输入新的样本数据 `new_sample`,基于训练得到的参数 `w` 和 `b` 计算新的 `y` 值。打印出新样本数据及其对应的预测值。


实验代码

# 导入必要的库
import numpy as np  # 导入科学计算库
import matplotlib.pyplot as plt  # 导入绘图库
from matplotlib import rcParams  # 导入设置绘图样式的参数

# 设置字体,防止中文乱码
rcParams['font.sans-serif'] = ['SimHei']  # 使用黑体
rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

# 1. 加载数据并画出散点图
points = np.genfromtxt('data1.txt', delimiter=',')  # 从文件中加载数据,数据以逗号分隔
x = points[:, 0]  # 获取第一列数据作为x
y = points[:, 1]  # 获取第二列数据作为y
plt.scatter(x, y)  # 绘制散点图,展示数据的分布情况
plt.show()

# 2. 设置模型的超参数
alpha = 0.0001  # 学习率
w = 0  # 初始化权重w
b = 0  # 初始化偏置b
num_iter = 1000  # 梯度下降的迭代次数
m = len(points)  # 样本数量

# 3. 梯度下降算法
MSE = []  # 用于保存每次迭代的均方误差
for iteration in range(num_iter):
    # 初始化梯度的和
    sum_grad_w = 0  # 用于累加w的梯度
    sum_grad_b = 0  # 用于累加b的梯度
    total_cost = 0  # 每次迭代的总损失初始化为0

    # 遍历所有数据点,计算偏导数并更新梯度
    for i in range(m):
        x_i = points[i, 0]  # 当前数据点的x值
        y_i = points[i, 1]  # 当前数据点的y值

        # 计算当前点的预测值
        pred_y_i = w * x_i + b

        # 计算损失函数(平方误差)
        total_cost += (y_i - pred_y_i) ** 2

        # 计算梯度
        sum_grad_w += (pred_y_i - y_i) * x_i  # 对w的偏导数
        sum_grad_b += (pred_y_i - y_i)  # 对b的偏导数

    # 计算当前迭代的均方误差
    total_cost /= m
    MSE.append(total_cost)  # 保存每次迭代的损失值

    # 计算偏导数的平均值
    grad_w = 2 / m * sum_grad_w
    grad_b = 2 / m * sum_grad_b

    # 更新w和b,基于学习率和梯度
    w -= alpha * grad_w
    b -= alpha * grad_b

# 4. 打印训练后的参数和损失值
print("参数w = ", w)
print("参数b = ", b)
# 使用 MSE[-1] 来表示最后一次迭代的损失函数值
print("最后的损失函数 = ", MSE[-1])

# 5. 绘制损失函数随迭代次数的变化图
plt.plot(MSE)
plt.xlabel('迭代次数')
plt.ylabel('损失值')
plt.title('梯度下降过程中的损失函数变化')
plt.show()

# 6. 画出拟合曲线
plt.scatter(x, y)  # 原始数据的散点图
pred_y = w * x + b  # 基于最终的w和b计算所有数据点的预测值
plt.plot(x, pred_y, color='red')  # 绘制线性回归拟合的直线,颜色为红色
plt.title('线性回归拟合曲线')
plt.show()

# 7. 基于训练得到的参数进行新样本预测
new_sample = np.array([5, 10, 15])  # 新的输入数据
predicted_y = w * new_sample + b  # 计算新样本的预测值
print("输入的新样本数据: ", new_sample)
print("预测的y值: ", predicted_y)

实验结果

1. 数据散点图及其线性回归拟合曲线

数据散点图及其线性回归拟合曲线

2. 梯度下降过程中损失函数变化图

梯度下降过程中损失函数变化图

3. 相关参数展示及新样本数据和其预测值

相关参数展示及新样本数据和其预测值


实验总结

本次实验通过使用梯度下降法训练线性回归模型,实现了单变量线性回归的训练与预测。实验中,我们成功编写了基于梯度下降算法的代码,并通过图形展示了数据的分布情况及模型的拟合效果。

在实验过程中,模型的权重参数和偏置参数通过多次迭代逐步更新,梯度下降法有效地减少了损失函数值。最终,模型收敛到了一个较好的参数组合,使得拟合曲线能够较好地反映数据的趋势。此外,通过绘制损失函数的变化图,我们直观地看到了随着迭代次数的增加,损失值不断下降的过程,验证了梯度下降算法的收敛性。

实验结果表明,使用梯度下降法能够有效训练线性回归模型,并且在小数据集上可以获得较为理想的拟合效果。同时,通过该实验,进一步加深了对线性回归和梯度下降算法的理解和掌握。

总体而言,实验达到了预期的目标,完成了线性回归模型的训练、损失函数的可视化及新样本的预测任务。


http://www.kler.cn/a/384154.html

相关文章:

  • 大模型应用编排工具Dify二开之工具和模型页面改造
  • 【WebRTC】视频采集模块流程的简单分析
  • 大数据集群中实用的三个脚本文件解析与应用
  • kelp protocol
  • k8s图形化显示(KRM)
  • HTML 标签属性——<a>、<img>、<form>、<input>、<table> 标签属性详解
  • 有什么办法换网络ip动态
  • 算法每日双题精讲——双指针(移动零,复写零)
  • Windows系统服务器怎么设置远程连接?详细步骤
  • Windows下QT调用MinGW编译的OpenCV
  • SIwave:释放 EMI 扫描仪/探测器的强大功能
  • 【CSS】“flex: 1“有什么用?
  • 如何在Linux环境中的Qt项目中使用ActiveMQ-CPP
  • 简单又便宜的实现电脑远程开机唤醒方法
  • 前端 | MYTED单篇TED词汇学习功能优化
  • leetcode 622.设计循环队列
  • DeBiFormer实战:使用DeBiFormer实现图像分类任务(二)
  • 高级 SQL 技巧详解
  • MDC(重要)
  • 物联网核心安全系列——物联网安全需求
  • 100种算法【Python版】第37篇—— Jarvis March算法
  • 快速上手vue3+js+Node.js
  • 实践出真知:MVEL表达式empty的坑
  • vue中html如何转成pdf下载,pdf转base64,忽略某个元素渲染在pdf中,方法封装
  • 【Python爬虫实战】DrissionPage 与 ChromiumPage:高效网页自动化与数据抓取的双利器
  • 【AI】【提高认知】卷积神经网络:从LeNet到AI淘金热的深度学习之旅