当前位置: 首页 > article >正文

TensorFlow与Pytorch的转换——1简单线性回归

import numpy as np

# 生成随机数据
# 生成随机数据
x_train = np.random.rand(100000).astype(np.float32)
y_train = 0.5 * x_train + 2 

import tensorflow as tf

# 定义模型
W = tf.Variable(tf.random.normal([1]))
b = tf.Variable(tf.zeros([1]))
y = W * x_train + b
# 定义损失函数
loss = tf.reduce_mean(tf.square(y - y_train))
# 定义优化器
optimizer = tf.optimizers.SGD(0.5)
# 训练模型
for i in range(100):
    with tf.GradientTape() as tape:
        y = W * x_train + b
        loss = tf.reduce_mean(tf.square(y - y_train))
    gradients = tape.gradient(loss, [W, b])
    optimizer.apply_gradients(zip(gradients, [W, b]))

    if (i+1) % 50 == 0:
        print("Epoch [{}/{}], loss: {:.3f}, W: {:.3f}, b: {:.3f}".format(i+1, 1000, loss.numpy(), W.numpy()[0], b.numpy()[0]))

# 预测新数据
x_test = np.array([0.1, 0.2, 0.3], dtype=np.float32)
y_pred = W * x_test + b
print("Predictions:", y_pred.numpy())
import matplotlib.pyplot as plt

# 绘制结果
plt.scatter(x_train, y_train)
plt.plot(x_train, W * x_train + b, c='r')
plt.show()


Pytorch

import torch
import numpy as np
import matplotlib.pyplot as plt

# 生成随机数据
x_train = torch.from_numpy(np.random.rand(100000).astype(np.float32))
y_train = 0.5 * x_train + 2

# 定义模型参数
W = torch.randn(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)

# 定义损失函数
loss_fn = torch.nn.MSELoss()

# 定义优化器
optimizer = torch.optim.SGD([W, b], lr=0.5)

# 训练模型
for i in range(100):
    y = W * x_train + b
    loss = loss_fn(y, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (i + 1) % 50 == 0:
        print(f"Epoch [{i + 1}/{100}], loss: {loss.item():.3f}, W: {W.item():.3f}, b: {b.item():.3f}")

# 预测新数据
x_test = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32)
y_pred = W * x_test + b
print("Predictions:", y_pred.detach().numpy())

# 绘制结果
plt.scatter(x_train.numpy(), y_train.numpy())
plt.plot(x_train.numpy(), (W * x_train + b).detach().numpy(), c='r')
plt.show()

http://www.kler.cn/news/343119.html

相关文章:

  • C++Linux项目推荐-Web多人聊天+MySQL+Redis+Websocket+Json,可以写简历的C++项目
  • CompletionFormer 点云补全 学习笔记
  • 易泊:精准与高效的车牌识别解决方案
  • 黑马点评(更新中)
  • 网站设计公司怎么评估?2024网站定制公司哪家好
  • 大模型在问答领域的探索和实践
  • zabbix7.0配置中文界面
  • EXCELWPS工作表批量重命名(按照sheet1中A列内容)
  • Python 使用函数归纳判断回文质数
  • React父子组件,父组件状态更新,子组件的渲染状况
  • 浙江省发规院产业发展研究所调研组莅临迪捷软件考察调研
  • GR-ConvNet论文 学习笔记
  • 有什么方法可以保护ppt文件不被随意修改呢?
  • 从容应对DDoS攻击:小网站的防守之战
  • 【大数据】大数据治理的全面解析
  • Python | Leetcode Python题解之第463题岛屿的周长
  • JSON 格式化工具:快速便捷地格式化和查看 JSON 数据
  • 简单理解Python代码的重构
  • 重新学习Mysql数据库3:Mysql存储引擎与数据存储原理
  • 音频响度归一化 - python 实现