《机器学习》线性回归模型实现
目录
一、一元线性回归模型
1、数据
2、代码
3、结果
二、多元线性回归模型
1、数据
2、代码
3、结果
一、一元线性回归模型
1、数据
2、代码
# 导入所需的库
import pandas as pd # 用于数据处理和分析
from matplotlib import pyplot as plt # 用于数据可视化
from sklearn.linear_model import LinearRegression # 用于线性回归建模
# 从CSV文件中读取数据
data = pd.read_csv('./data/data.csv') # 读取名为 'data.csv' 的文件,并将其存储在 DataFrame 中
# 绘制散点图,查看广告投入与销售额之间的关系
plt.scatter(data.广告投入, data.销售额) # 以广告投入为横轴,销售额为纵轴绘制散点图
plt.show() # 显示图形
# 计算数据集中各列之间的相关系数
corr = data.corr() # 生成相关系数矩阵,用于分析变量之间的线性关系
# 创建线性回归模型对象
lr = LinearRegression() # 初始化一个线性回归模型
# 准备自变量(特征)和因变量(目标)
x = data[['广告投入']] # 将 '广告投入' 列作为自变量(特征)
y = data[['销售额']] # 将 '销售额' 列作为因变量(目标)
# 训练线性回归模型
lr.fit(x, y) # 使用数据拟合线性回归模型,找到最佳拟合参数
# 使用训练好的模型进行预测
result = lr.predict(x) # 对输入的自变量 x 进行预测,得到对应的销售额预测值
# 计算模型的拟合优度(R² 分数)
score = lr.score(x, y) # 计算模型在训练数据上的 R² 分数,表示模型的解释能力
# 获取线性回归模型的系数和截距
a = lr.coef_ # 获取回归系数(斜率),表示广告投入对销售额的影响程度
b = lr.intercept_ # 获取截距,表示当广告投入为 0 时的销售额
# 输出线性回归模型的方程
print('线性回归模型是 y = %.2fX1 + %.2f' % (a[0][0], b)) # 打印线性回归方程,格式为 y = aX + b
3、结果
二、多元线性回归模型
1、数据
2、代码
# 导入所需的库
import pandas as pd # 用于数据处理和分析
from sklearn.linear_model import LinearRegression # 用于线性回归建模
# 从CSV文件中读取数据,指定编码格式为 'gbk',并使用 'python' 引擎
data = pd.read_csv('./data/多元线性回归.csv', encoding='gbk', engine='python') # 读取名为 '多元线性回归.csv' 的文件,解决中文编码问题
# 计算 '体重'、'年龄' 和 '血压收缩' 列之间的相关系数
corr = data[['体重', '年龄', '血压收缩']].corr() # 生成相关系数矩阵,分析变量之间的线性关系
# 创建线性回归模型对象
lr_model = LinearRegression() # 初始化一个线性回归模型
# 准备自变量(特征)和因变量(目标)
x = data[['体重', '年龄']] # 将 '体重' 和 '年龄' 列作为自变量(特征)
y = data[['血压收缩']] # 将 '血压收缩' 列作为因变量(目标)
# 训练线性回归模型
lr_model.fit(x, y) # 使用数据拟合线性回归模型,找到最佳拟合参数
# 计算模型的拟合优度(R² 分数)
score = lr_model.score(x, y) # 计算模型在训练数据上的 R² 分数,表示模型的解释能力
# 使用训练好的模型进行预测
print(lr_model.predict([[80, 60]])) # 对体重为 80、年龄为 60 的样本进行预测,输出血压收缩的预测值
print(lr_model.predict([[80, 60], [70, 20]])) # 对多组样本进行预测,输出对应的血压收缩预测值
# 获取线性回归模型的系数和截距
a = lr_model.coef_ # 获取回归系数(斜率),表示体重和年龄对血压收缩的影响程度
b = lr_model.intercept_ # 获取截距,表示当体重和年龄均为 0 时的血压收缩值
# 输出线性回归模型的方程
print('线性回归模型是 y = %.2fX1 + %.2fX2 + %.2f' % (a[0][0], a[0][1], b)) # 打印多元线性回归方程,格式为 y = a1X1 + a2X2 + b