当前位置: 首页 > article >正文

[paddle] 非线性拟合问题的训练

利用paddlepaddle建立神经网络,模拟有限个数据的非线性拟合

本文仍然考虑 f ( x ) = sin ⁡ ( x ) x f(x)=\frac{\sin(x)}{x} f(x)=xsin(x) 函数在区间 [-10,10] 上固定数据的拟合。

import paddle
import paddle.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# 设置随机种子以确保结果的可重复性
paddle.seed(1)

# 生成数据集
x_data = (np.random.rand(500) * 20 - 10).astype('float32')  # 生成500个随机x值,范围在-10到10之间
y_data = np.sin(x_data) / x_data  # 生成y值
y_data = y_data.reshape(-1, 1)  # 将y_data转换为二维数组

# 定义模型,一个具有2个隐藏层的多层感知器
class MyModel(nn.Layer):
    def __init__(self):
        super(MyModel, self).__init__()
        self.hidden1 = nn.Linear(in_features=1, out_features=50)
        self.bn = nn.BatchNorm1D(num_features=50)
        self.hidden2 = nn.Linear(in_features=50, out_features=1)

    def forward(self, x):
        x = paddle.tanh(self.hidden1(x))
        x = self.bn(x)
        x = self.hidden2(x)
        return x

model = MyModel()

# 定义损失函数
loss_fn = nn.MSELoss()

# 设置优化器
optimizer = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())

# 训练数据
train_data = paddle.to_tensor(x_data).unsqueeze(-1), paddle.to_tensor(y_data)

# 训练模型
epochs = 1000
for epoch in range(1, epochs + 1):
    loss = loss_fn(model(train_data[0]), train_data[1])
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()
    if epoch % 100 == 0:
        print(f'Epoch {epoch}: Loss = {loss.numpy()}')

# 使用训练好的模型进行预测
y_pred = model(train_data[0]).numpy()

# 可视化结果
plt.scatter(x_data, y_data, label='True')
plt.scatter(x_data, y_pred, label='Predicted')
plt.legend()
plt.show()

在这里插入图片描述


http://www.kler.cn/a/466785.html

相关文章:

  • 用公网服务代理到本地电脑笔记
  • 如何通过API实现淘宝商品评论数据抓取?item_review获取淘宝商品评论
  • springboot550乐乐农产品销售系统(论文+源码)_kaic
  • 如何提高软件研发效率?
  • LLM大模型RAG内容安全合规检查
  • 【C++】P2550 [AHOI2001] 彩票摇奖
  • React 性能优化
  • 数仓建模(二) 从关系型数据库到数据仓库的演变
  • 淘宝商品详情API返回值说明:Python爬虫代码示例
  • perf:对hutool的BeanUtil工具类做补充
  • 【51单片机零基础-chapter3:按键:独立按键|||附带常见C语句.逻辑运算符】
  • 中国科技产业化促进会深入深圳企业调研
  • gesp(C++一级)(17)洛谷:B4062:[GESP202412 一级] 温度转换
  • 在Linux系统中使用字符图案和VNC运行Qt Widgets程序
  • IDEA Plugins中搜索不到插件解决办法
  • 自动化测试常考的面试题+答案汇总(持续更新)
  • React 网络请求优化
  • CVSS漏洞评分系统曝出严重缺陷
  • 【源码+文档+调试讲解】“健康早知道”微信小程序
  • 生成对抗网络 (Generative Adversarial Network, GAN) 算法MNIST图像生成任务及CelebA图像超分辨率任务
  • 深入理解 Android 中的 ComponentInfo
  • Hive集群安装部署
  • Markdown中流程图的用法
  • 解决 HTML 表单输入框与按钮对齐问题
  • LeetCode 力扣 热题 100道(二十三)找到字符串中所有字母异位词(C++)
  • issue问题全解