当前位置: 首页 > article >正文

第L2周:机器学习|线性回归模型 LinearRegression:2. 多元线性回归模型

  • 本文为365天深度学习训练营 中的学习记录博客
  • 原作者:K同学啊

任务:
●1. 学习本文的多元线形回归模型。
●2. 参考文本预测花瓣宽度的方法,选用其他三个变量来预测花瓣长度。

一、多元线性回归

简单线性回归:影响 Y 的因素唯一,只有一个。
多元线性回归:影响 Y 的因数不唯一,有多个。

与一元线性回归一样,多元线性回归自然是一个回归问题。
在这里插入图片描述

相当于我们高中学的一元一次方程,变成了 n 元一次方程。因为 y 还是那个 y。只是自变量增加了。

二、代码实现

我的环境:
●语言环境:Python3.9
●编译器:Jupyter Lab

第1步:数据预处理

  1. 导入数据集
import pandas as pd
import numpy as np

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"  
names = ['花萼-length', '花萼-width', '花瓣-length', '花瓣-width', 'class'] 

dataset = pd.read_csv(url, names=names)
dataset

代码输出:

花萼-length花萼-width花瓣-length花瓣-widthclass
05.13.51.40.2Iris-setosa
14.93.01.40.2Iris-setosa
24.73.21.30.2Iris-setosa
34.63.11.50.2Iris-setosa
45.03.61.40.2Iris-setosa
..................
1456.73.05.22.3Iris-virginica
1466.32.55.01.9Iris-virginica
1476.53.05.22.0Iris-virginica
1486.23.45.42.3Iris-virginica
1495.93.05.11.8Iris-virginica

150 rows × 5 columns

备注:
如果报下面错误:URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self signed certificate in certificate chain (_ssl.c:1129)>
在代码开头加上如下代码即可:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

  1. 数据分析
import matplotlib.pyplot as plt

plt.plot(dataset['花萼-length'], dataset['花瓣-width'], 'x', label="marker='x'")
plt.plot(dataset['花萼-width'],  dataset['花瓣-width'], 'o', label="marker='o'")
plt.plot(dataset['花瓣-length'], dataset['花瓣-width'], 'v', label="marker='v'")
    
plt.legend(numpoints=1)
plt.show()

代码输出:

在这里插入图片描述

X = dataset.iloc[ : ,[1,2]].values
Y = dataset.iloc[ : ,  3 ].values
  1. 构建训练集、测试集
from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, 
                                                    test_size=0.2, 
                                                    random_state=0)

第2步:训练多元线性回归模型

from sklearn.linear_model import LinearRegression

regressor = LinearRegression()
regressor.fit(X_train, Y_train)

第3步:在测试集上预测结果

y_pred = regressor.predict(X_test)
y_pred

代码输出:

array([1.76025586, 1.23794101, 0.29130263, 2.28334281, 0.2668048 ,
       2.18837013, 0.18945083, 1.61397124, 1.63158995, 1.28848086,
       1.95785242, 1.53661727, 1.58870131, 1.54581268, 1.59712462,
       0.24153487, 1.51134735, 1.44318879, 0.19022292, 0.22314407,
       1.67447859, 1.51977066, 0.43835934, 0.18179962, 1.63158995,
       0.06920823, 0.47205258, 1.42557008, 0.94614386, 0.30969343])

第4步:测试集预测结果可视化

plt.scatter(Y_test,y_pred, color='red')

plt.xlabel("True")
plt.ylabel("Prediction")

plt.show()

代码输出:

在这里插入图片描述


http://www.kler.cn/a/326037.html

相关文章:

  • nuget 管理全局包、缓存和临时文件夹
  • STM32保护内部FLASH
  • 【Go】-bufio库解读
  • C++深度搜索(2)
  • Gin 框架中间件详细介绍
  • Redis基础篇
  • Vulhub zico 2靶机详解
  • GS-SLAM论文阅读笔记--MM3DGS SLAM
  • A Learning-Based Approach to Static Program Slicing —— 论文笔记
  • 【Git原理与使用】分支管理
  • C++可见性
  • 关于武汉芯景科技有限公司的IIC电平转换芯片XJ9509开发指南(兼容PCa9509)
  • Matlab实现麻雀优化算法优化回声状态网络模型 (SSA-ESN)(附源码)
  • linux环境oracle11.2.0.4打补丁(p31537677_112040_Linux-x86-64.zip)
  • [M贪心] lc2207. 字符串中最多数目的子序列(模拟+贪心+一次遍历+代码细节+思维)
  • 无人机避障—— 激光雷达定高北醒TF03-UART(二)
  • 【基础算法总结】分治--快排+归并
  • YOLOv8改进,YOLOv8改进损失函数采用Powerful-IoU(2024年最新IOU),助力涨点
  • 【YOLOv10改进[SPPF]】使用 SPPFCSPC替换SPPF模块 + 含全部代码和详细修改方式
  • Linux内核 -- 读写文件系统文件之kernel_read与kernel_write
  • APISIX 联动雷池 WAF 实现 Web 安全防护
  • VLAN Bond 堆叠
  • 苍穹外卖学习笔记(十三)
  • TikTok Shop成印尼第二大电商平台,TikTok怎么快速涨粉?
  • OpenCV开发笔记(八十一):通过棋盘格使用鱼眼方式标定相机内参矩阵矫正摄像头图像
  • 关于音频噪音处理【常见问题、解决方案等】