使用python实现线性回归
一、概述
本代码主要演示了如何使用 Python 的 numpy
、matplotlib
和 sklearn
库进行简单线性回归分析。通过生成模拟数据,训练线性回归模型,对模型进行评估,并将结果可视化,帮助用户理解线性回归的基本原理和操作流程。
二、依赖库
numpy
:用于数值计算和数组操作,如生成随机数和处理数组数据。matplotlib.pyplot
:用于数据可视化,绘制散点图和回归线。sklearn.linear_model.LinearRegression
:用于创建和训练线性回归模型。sklearn.metrics.mean_squared_error
和sklearn.metrics.r2_score
:分别用于计算均方误差(MSE)和 \(R^2\) 分数,评估模型的性能。
三、代码详细解释
1. 导入必要的库
收起
python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
- 导入
numpy
库并将其别名为np
,方便后续使用。 - 导入
matplotlib.pyplot
库并将其别名为plt
,用于绘图。 - 从
sklearn.linear_model
模块中导入LinearRegression
类,用于创建线性回归模型。 - 从
sklearn.metrics
模块中导入mean_squared_error
和r2_score
函数,用于评估模型性能。
2. 设置中文字体
收起
python
plt.rcParams['font.family'] = 'SimSun'
- 设置
matplotlib
的字体为宋体,确保在绘图时可以正常显示中文。
3. 生成模拟数据
收起
python
np.random.seed(42) # 固定随机种子
X = np.random.rand(100, 1) * 10 # 生成 100 个 0~10 之间的特征值
y = 3 * X + 5 + np.random.randn(100, 1) * 2 # y = 3x + 5 + 噪声
np.random.seed(42)
:固定随机种子,确保每次运行代码时生成的随机数相同,方便结果的复现。X = np.random.rand(100, 1) * 10
:使用np.random.rand
函数生成一个形状为(100, 1)
的数组,数组中的元素是 0 到 1 之间的随机数,然后将其乘以 10,得到 100 个 0 到 10 之间的特征值。y = 3 * X + 5 + np.random.randn(100, 1) * 2
:根据真实方程 \(y = 3x + 5\) 生成目标值,并添加高斯噪声(使用np.random.randn
函数生成),模拟真实世界中的数据。
4. 创建和训练线性回归模型
收起
python
# 创建线性回归模型
model = LinearRegression()
# 训练模型
model.fit(X, y)
model = LinearRegression()
:创建一个LinearRegression
类的实例,即一个线性回归模型。model.fit(X, y)
:使用生成的特征值X
和目标值y
对模型进行训练,让模型学习X
和y
之间的线性关系。
5. 模型预测
收起
python
# 预测
y_pred = model.predict(X)
y_pred = model.predict(X)
:使用训练好的模型对特征值X
进行预测,得到预测的目标值y_pred
。
6. 获取模型参数
收起
python
# 获取模型参数
slope = model.coef_[0][0] # 斜率
intercept = model.intercept_[0] # 截距
slope = model.coef_[0][0]
:从模型的系数(斜率)数组中获取斜率值。intercept = model.intercept_[0]
:从模型的截距数组中获取截距值。
7. 打印模型参数和评估指标
收起
python
# 打印模型参数和评估指标
print(f"真实方程: y = 3x + 5")
print(f"学习到的方程: y = {slope:.2f}x + {intercept:.2f}")
print(f"均方误差 (MSE): {mean_squared_error(y, y_pred):.2f}")
print(f"R² 分数: {r2_score(y, y_pred):.2f}")
- 打印真实方程和模型学习到的方程,方便对比。
- 使用
mean_squared_error
函数计算均方误差(MSE),衡量模型预测值与真实值之间的平均误差。 - 使用
r2_score
函数计算 \(R^2\) 分数,评估模型对数据的拟合程度,\(R^2\) 分数越接近 1 表示模型拟合效果越好。
8. 可视化结果
收起
python
# 可视化
plt.figure(figsize=(10, 6))
plt.scatter(X, y, color='blue', label='原始数据', alpha=0.6)
plt.plot(X, y_pred, color='red', linewidth=2, label='回归线')
plt.plot(X, 3*X+5, color='green', linestyle='--', label='真实关系')
plt.xlabel('X')
plt.ylabel('y')
plt.title('线性回归示例')
plt.legend()
plt.grid(True)
plt.show()
plt.figure(figsize=(10, 6))
:创建一个大小为(10, 6)
的图形窗口。plt.scatter(X, y, color='blue', label='原始数据', alpha=0.6)
:绘制原始数据的散点图,颜色为蓝色,设置透明度为 0.6。plt.plot(X, y_pred, color='red', linewidth=2, label='回归线')
:绘制模型的回归线,颜色为红色,线宽为 2。plt.plot(X, 3*X+5, color='green', linestyle='--', label='真实关系')
:绘制真实的线性关系,颜色为绿色,线型为虚线。plt.xlabel('X')
和plt.ylabel('y')
:设置 x 轴和 y 轴的标签。plt.title('线性回归示例')
:设置图形的标题。plt.legend()
:显示图例,方便区分不同的图形元素。plt.grid(True)
:显示网格线,增强图形的可读性。plt.show()
:显示绘制好的图形。
四、总结
通过本代码示例,我们可以看到如何使用 sklearn
库进行简单线性回归分析,包括数据生成、模型训练、预测、评估和可视化。用户可以根据需要修改代码中的参数,如随机种子、数据规模、噪声水平等,进一步探索线性回归的特性。
完整代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
plt.rcParams['font.family'] = 'SimSun'
# 生成模拟数据
np.random.seed(42) # 固定随机种子
X = np.random.rand(100, 1) * 10 # 生成 100 个 0~10 之间的特征值
y = 3 * X + 5 + np.random.randn(100, 1) * 2 # y = 3x + 5 + 噪声
# 创建线性回归模型
model = LinearRegression()
# 训练模型
model.fit(X, y)
# 预测
y_pred = model.predict(X)
# 获取模型参数
slope = model.coef_[0][0] # 斜率
intercept = model.intercept_[0] # 截距
# 打印模型参数和评估指标
print(f"真实方程: y = 3x + 5")
print(f"学习到的方程: y = {slope:.2f}x + {intercept:.2f}")
print(f"均方误差 (MSE): {mean_squared_error(y, y_pred):.2f}")
print(f"R² 分数: {r2_score(y, y_pred):.2f}")
# 可视化
plt.figure(figsize=(10, 6))
plt.scatter(X, y, color='blue', label='原始数据', alpha=0.6)
plt.plot(X, y_pred, color='red', linewidth=2, label='回归线')
plt.plot(X, 3*X+5, color='green', linestyle='--', label='真实关系')
plt.xlabel('X')
plt.ylabel('y')
plt.title('线性回归示例')
plt.legend()
plt.grid(True)
plt.show()