当前位置: 首页 > article >正文

L2线性回归模型

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

鸢尾花数据集的单变量与多变量预测

在这周学习如何使用 机器学习 模型对鸢尾花(Iris)数据集进行单变量与多变量预测。我们将使用鸢尾花数据集中不同的特征进行预测,分别使用单变量和多变量的回归模型。

0.学习时长与学习成绩
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

dataset = pd.read_csv('data/studentscores.csv')
X = dataset.iloc[ : , :1].values
Y = dataset.iloc[ : ,1].values

from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, 
                                                    test_size=1/4, 
                                                    random_state=0)
from sklearn.linear_model import LinearRegression

regressor = LinearRegression()
regressor = regressor.fit(X_train, Y_train)   
Y_pred = regressor.predict(X_test)
plt.scatter(X_train, Y_train, color='red')
plt.plot(X_train, regressor.predict(X_train), color='blue')

plt.show()       
plt.scatter(X_test, Y_test, color='red')
plt.plot(X_test, regressor.predict(X_test), color='blue')
plt.show()                                          

在这里插入图片描述
在这里插入图片描述

1. 鸢尾花数据集简介

鸢尾花数据集是一个经典的机器学习数据集,包含150条记录,每条记录有四个特征以及一个类别标签。这四个特征分别是:

  • 花萼长度 (sepal length)
  • 花萼宽度 (sepal width)
  • 花瓣长度 (petal length)
  • 花瓣宽度 (petal width)

我们可以使用这些特征来构建预测模型。在单变量预测中,我们将使用单一特征来预测目标变量,而在多变量预测中,将使用多个特征来提高预测准确性。

2. 数据准备

首先,我们加载鸢尾花数据集并做简单的数据清理和预处理。这里我们选择了三个特征来预测花瓣长度。

import pandas as pd

# 数据集URL
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
names = ['花萼-length', '花萼-width', '花瓣-length', '花瓣-width', 'class']

# 读取数据
dataset = pd.read_csv(url, names=names)

# 查看前5行数据
print(dataset.head())
3. 单变量预测:基于花萼宽度预测花瓣长度
3.1 数据可视化

在开始模型训练之前,我们首先对数据进行可视化,观察特征之间的关系。

import matplotlib.pyplot as plt
import seaborn as sns

# 选择特征
x = dataset[['花萼-width']]
y = dataset['花瓣-length']

# 绘制散点图
sns.scatterplot(x=x['花萼-width'], y=y)
plt.xlabel('花萼宽度')
plt.ylabel('花瓣长度')
plt.title('花萼宽度与花瓣长度的关系')
plt.show()
3.2 构建线性回归模型

接下来,我们使用 sklearn 中的 线性回归 模型,基于单一变量(花萼宽度)来预测花瓣长度。

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

# 划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

# 构建线性回归模型
model = LinearRegression()
model.fit(x_train, y_train)

# 预测
y_pred = model.predict(x_test)

# 评估模型
print(f"均方误差: {mean_squared_error(y_test, y_pred)}")
print(f"R²得分: {r2_score(y_test, y_pred)}")
3.3 结果可视化
# 绘制预测值与真实值的对比
plt.figure(figsize=(6, 4))
plt.plot(range(len(y_test)), y_test, label='真实值')
plt.plot(range(len(y_pred)), y_pred, label='预测值')
plt.title('单变量线性回归预测')
plt.legend()
plt.show()

在这里插入图片描述

4. 多变量预测:使用多个特征预测花瓣长度

在多变量预测中,我们将选择多个特征来提高模型的预测效果。这里我们选择了 花萼宽度花瓣宽度花瓣长度 三个特征来预测 花瓣长度

4.1 数据可视化
# 选择特征
x = dataset[['花萼-width', '花瓣-width', '花萼-length']]
y = dataset['花瓣-length']

# 绘制相关关系图
sns.pairplot(dataset, x_vars=['花萼-width', '花瓣-width', '花萼-length'], y_vars='花瓣-length', kind='reg')
plt.show()

在这里插入图片描述

4.2 构建多变量线性回归模型
# 划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

# 构建线性回归模型
model = LinearRegression()
model.fit(x_train, y_train)

# 预测
y_pred = model.predict(x_test)

# 评估模型
print(f"均方误差: {mean_squared_error(y_test, y_pred)}")
print(f"R²得分: {r2_score(y_test, y_pred)}")
4.3 结果可视化
# 绘制预测值与真实值的对比
plt.figure(figsize=(6, 4))
plt.plot(range(len(y_test)), y_test, label='真实值')
plt.plot(range(len(y_pred)), y_pred, label='预测值')
plt.title('多变量线性回归预测')
plt.legend()
plt.show()

6. 总结

这周学习了如何使用机器学习模型对鸢尾花数据集进行单变量和多变量预测。通过特征的选择和模型的构建,可以看到多变量模型可以更好地捕捉数据特征间的关系,从而提高预测的准确性。


http://www.kler.cn/a/299917.html

相关文章:

  • vscode下poetry管理项目的debug配置
  • 023:到底什么是感受野?
  • Ubuntu24.04初始化MySQL报错 error while loading shared libraries libaio.so.1
  • 把网站程序数据上传到服务器的方法和注意事项
  • OpenStack基础架构
  • go-zero框架基本配置和错误码封装
  • Vue跨域问题、Vue配置开发环境代理服务、集成Axios发送Ajax请求、集成vue-resource发送Ajax请求
  • MySQL系统库——mysql库
  • 一些硬件知识(二十二)
  • MDK编译过程、文件及_attribute__关键字
  • 常见的ROM(只读存储器)及其区别(超详细)
  • 探索深度学习的奥秘:从理论到实践的奇幻之旅
  • NPU是什么?特点及应用
  • 系统分析师--企业信息化战略与实施
  • LeetCode 61. 旋转链表
  • openeuler-无法dnf安装包问题
  • electron: 将网址打包成exe桌面应用
  • Android Dialog:Dialog和DialogFragment的区别?DialogFragment如何使用?源码解析
  • 解锁10款超棒的图表制作软件,让数据可视化不再困难
  • leetcode 994.腐烂的橘子
  • Django-Celery-Flower实现异步和定时爬虫及其监控邮件告警
  • 【HCIA-Datacom】数据通信网络基础
  • CSS“多列布局”(补充)——WEB开发系列35
  • 网络层 VII(IP多播、移动IP)【★★★★★★】
  • 【C++】——string
  • 揭开面纱--机器学习