当前位置: 首页 > article >正文

pytorch线性回归模型预测房价例子

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 1. 创建线性回归模型类
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(1, 1)  # 1个输入特征,1个输出

    def forward(self, x):
        return self.linear(x)

# 2. 生成训练数据
area = np.array([1000, 1500, 1800, 2400, 3000], dtype=np.float32).reshape(-1, 1)  # 房屋面积(平方英尺)
price = np.array([250000, 300000, 350000, 500000, 600000], dtype=np.float32).reshape(-1, 1)  # 房价

# 标准化房屋面积
area = area / 3000  # 假设最大面积为3000平方英尺

# 转换为 PyTorch 张量
x_train = torch.from_numpy(area)
y_train = torch.from_numpy(price)

# 3. 实例化模型、定义损失函数和优化器
model = LinearRegressionModel()
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.SGD(model.parameters(), lr=0.001)  # 学习率调低

# 4. 训练模型
epochs = 1000
for epoch in range(epochs):
    # 前向传播
    outputs = model(x_train)
    loss = criterion(outputs, y_train)

    # 反向传播
    optimizer.zero_grad()  # 清零梯度
    loss.backward()  # 计算梯度
    optimizer.step()  # 更新权重

    # 每100次输出一次损失值
    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

# 5. 保存训练好的模型
torch.save(model.state_dict(), 'linear_regression_model.pth')
print("模型已保存!")

# 6. 加载模型并进行预测
loaded_model = LinearRegressionModel()
loaded_model.load_state_dict(torch.load('linear_regression_model.pth'))
loaded_model.eval()  # 设置为评估模式

# 进行预测
new_area = torch.tensor([[2500 / 3000]], dtype=torch.float32)  # 假设新房屋面积为2500平方英尺,标准化处理
predicted_price = loaded_model(new_area)
print(f"Predicted price for area 2500 sq.ft: ${predicted_price.item():,.2f}")
  • 创建模型LinearRegressionModel 是一个简单的线性回归模型,只有一个线性层 (nn.Linear)。
  • 数据准备:我们生成了一个简单的示例数据集,包含房屋面积和房价数据。数据被转换为 PyTorch 张量格式。
  • 模型训练:使用均方误差损失函数 (MSELoss) 和随机梯度下降优化器 (SGD) 来训练模型。模型在1000个迭代中进行训练,并在每100次迭代后输出损失值。
  • 保存模型:训练完成后,使用 torch.save 保存模型的参数。
  • 加载模型并进行预测:使用 torch.load 加载模型参数,并将模型设置为评估模式 (eval)。然后,我们通过模型对一个新的房屋面积值进行预测。

http://www.kler.cn/a/525281.html

相关文章:

  • 【视频+图文详解】HTML基础3-html常用标签
  • 网关登录校验
  • mysql.sock.lock 导致mysql重启失败
  • nvm安装详细教程(安装nvm、node、npm、cnpm、yarn及环境变量配置)
  • 数字化转型-工具变量(2024.1更新)-社科数据
  • AI大模型开发原理篇-2:语言模型雏形之词袋模型
  • 乐优商城项目总结
  • AI大模型开发原理篇-3:词向量和词嵌入
  • Ubuntu 16.04安装Lua
  • Yii框架中的正则表达式:如何实现高效的文本操作
  • 【Unity教程】零基础带你从小白到超神part3
  • 观察者模式和订阅发布模式的关系
  • 03链表+栈+队列(D1_链表(D1_基础学习))
  • hdfs之读写流程
  • AI学习指南Ollama篇-使用Ollama构建自己的私有化知识库
  • 【单细胞-第三节 多样本数据分析】
  • 大模型(LLM)工程师实战之路(含学习路线图、书籍、课程等免费资料推荐)
  • 为AI聊天工具添加一个知识系统 之78 详细设计之19 正则表达式 之6
  • 租赁系统为企业资产管理提供高效解决方案促进业务增长与创新
  • premierePro 2022创建序列方式
  • 为AI聊天工具添加一个知识系统 之77 详细设计之18 正则表达式 之5
  • 高级同步工具解析
  • 认识小程序页面,小程序的宿主环境
  • Python 类型注解
  • 新手项目管理的实用工具推荐
  • 《探秘人工智能:从基础到未来变革》