当前位置: 首页 > article >正文

机器学习之线性回归

线性回归是机器学习中比较简单的模型,给定 X 和 Y 的值拟合出一条直线,数据点离线的距离越近想过越好。如果一个特征,最终会呈现为一条直线,如果是多参数,输入就是一个矩阵,通过超平面进行分割。线性回归的损失函数使用均方差,这个好理解是方差和越小越好。我们是用 Sklearn 来实现线性回归。

安装相关类型

pip install numpy
pip install pandas
pip install 

加载训练数据

加载数据,442 Batch,10 个特征。

from sklearn.datasets import load_diabetes
diabetes = load_diabetes()
data = diabetes.data
target = diabetes.target 
print(data.shape)
print(target.shape)
print(data[:5])
print(target[:5])

在这里插入图片描述
制作训练集和测试集

# 导入sklearn diabetes数据接口
from sklearn.datasets import load_diabetes
# 导入sklearn打乱数据函数
from sklearn.utils import shuffle
# 获取diabetes数据集
diabetes = load_diabetes()
# 获取输入和标签
data, target = diabetes.data, diabetes.target 
# 打乱数据集
X, y = shuffle(data, target, random_state=13)
# 按照8/2划分训练集和测试集
offset = int(X.shape[0] * 0.8)
# 训练集
X_train, y_train = X[:offset], y[:offset]
# 测试集
X_test, y_test = X[offset:], y[offset:]
# 将训练集改为列向量的形式
y_train = y_train.reshape((-1,1))
# 将验证集改为列向量的形式
y_test = y_test.reshape((-1,1))
# 打印训练集和测试集维度
print("X_train's shape: ", X_train.shape)
print("X_test's shape: ", X_test.shape)
print("y_train's shape: ", y_train.shape)
print("y_test's shape: ", y_test.shape)


训练模型

直接使用 sklearn LinearRegression 方法

### sklearn版本为1.0.2
# 导入线性回归模块
from sklearn import linear_model
from sklearn.metrics import mean_squared_error, r2_score
# 创建模型实例
regr = linear_model.LinearRegression()
# 模型拟合
regr.fit(X_train, y_train)
# 模型预测
y_pred = regr.predict(X_test)
# 打印模型均方误差
print("Mean squared error: %.2f" % mean_squared_error(y_test, y_pred))
# 打印R2
print('R2 score: %.2f' % r2_score(y_test, y_pred))

在这里插入图片描述

总结

线性回归是机器学习非常简单的模型,通过 Sklearn 可以方便的训练模型。


http://www.kler.cn/a/372927.html

相关文章:

  • ElasticSearch备考 -- Index rollover
  • 前端组件化
  • SQLI LABS | Less-20 POST-Cookie Injections-Uagent field-error based
  • 电科金仓(人大金仓)更新授权文件(致命错误: XX000: License file expired.)
  • vue添加省市区
  • 国标GB28181软件EasyGBS国标GB28181网页直播平台在邮政快递场景中的应用
  • 二、k8s快速入门之docker+Kubernetes平台搭建
  • 提升网站速度与性能优化的有效策略与实践
  • ShellScript脚本编程(函数与正则表达式)
  • 软考:中间件
  • leetcode 303.区域和检索-数组不可变
  • 1.5 新特性 C++面试常见问题
  • 【Linux】-常见指令(1)
  • MS01SF1 精准测距UWB模组助力露天采矿中的人车定位安全和作业效率提升
  • 62.不同路径 63.不同路径ii
  • 我的电脑问题
  • C++设计模式创建型模式———单例模式
  • 计算机网络(Ⅵ)应用层原理
  • HTML入门教程20:HTML头部
  • 代码随想录第十五天
  • oracle和mysql的区别常用的sql语句
  • 模块化CSS
  • 汽车零部件展|2025 第十二届广州国际汽车零部件加工技术及汽车模具展览会邀您共赏汽车行业盛会
  • 使用 Git 命令将本地项目上传到 GitLab
  • JVM 复习1
  • 修改IP分组头部内容的场景