当前位置: 首页 > article >正文

深度学习:利用随机数据更快地测试一个新的模型在自己数据格式很复杂的时候

技巧:

比如下面一个新的模型deeponet,我自己的数据很复杂,这里在代码最后用用随机生成的数据,两分钟就完成了代码的测试成功。 

import torch
import torch.nn as nn
import torch.optim as optim

# 带偏置项的 DeepONet 结构,包括 Branch 和 Trunk 网络
class DeepONet(nn.Module):
    def __init__(self, branch_input_dim, trunk_input_dim, hidden_dim):
        super(DeepONet, self).__init__()
        
        # Branch 网络,用于处理输入点云的特征(例如位移量、压强)
        self.branch_net = nn.Sequential(
            nn.Linear(branch_input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Trunk 网络,用于处理时间和空间坐标 [x, y, z, t]
        self.trunk_net = nn.Sequential(
            nn.Linear(trunk_input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # 偏置项 bias
        self.bias = nn.Parameter(torch.zeros(1))  # 可训练的偏置项
        
        # 最终的输出层,预测位移或压强等物理状态
        self.fc_output = nn.Linear(hidden_dim, 3)
    
    def forward(self, point_features, coord_time):
        # Branch网络的输出
        branch_output = self.branch_net(point_features)
        
        # Trunk网络的输出
        trunk_output = self.trunk_net(coord_time)
        
        # 将 Branch 和 Trunk 的输出结合,计算最终的输出
        combined = branch_output * trunk_output
        output = self.fc_output(combined) + self.bias  # 加上偏置项
        
        return output

# 数据准备
# 输入的数据格式:
# point_features:3D点云的物理特征(例如位移量 pointDisplacement、压强 p)
# coord_time:空间位置和时间 [x, y, z, t]

# 示例数据的维度设置
branch_input_dim = 3  # 例如 [pointDisplacement, p, ...] 
trunk_input_dim = 4   # [x, y, z, t]
hidden_dim = 64       # 隐藏层维度,可根据需求调整

# 模型初始化
model = DeepONet(branch_input_dim, trunk_input_dim, hidden_dim)

# 损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练流程
def train(model, point_features, coord_time, target, epochs=1000):
    for epoch in range(epochs):
        optimizer.zero_grad()
        
        # 前向传播
        output = model(point_features, coord_time)
        
        # 计算损失
        loss = criterion(output, target)
        
        # 反向传播和优化
        loss.backward()
        optimizer.step()
        
        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item()}")

# 示例数据,实际应用时需要替换为真实数据
N = 1000  # 样本数量
point_features = torch.randn(N, branch_input_dim)  # 3D点云的物理特征
coord_time = torch.randn(N, trunk_input_dim)       # [x, y, z, t]
target = torch.randn(N, 3)                         # 目标物理状态

# 训练模型
train(model, point_features, coord_time, target, epochs=1000)

# 推理:给定新的时空点,预测物理状态
def predict(model, point_features, coord_time):
    model.eval()
    with torch.no_grad():
        prediction = model(point_features, coord_time)
    return prediction

# 示例推理
new_point_features = torch.randn(1, branch_input_dim)
new_coord_time = torch.tensor([[0.5, 0.5, 0.5, 0.1]])  # 在 t=0.1 的 (0.5, 0.5, 0.5) 空间点
prediction = predict(model, new_point_features, new_coord_time)
print("Predicted state:", prediction)

输出如下:

Epoch 0, Loss: 1.0260347127914429
Epoch 100, Loss: 0.7669863104820251
Epoch 200, Loss: 0.5786211490631104
Epoch 300, Loss: 0.4749055504798889
Epoch 400, Loss: 0.41076529026031494
Epoch 500, Loss: 0.36538082361221313
Epoch 600, Loss: 0.39494913816452026
Epoch 700, Loss: 0.30206459760665894
Epoch 800, Loss: 0.2839098572731018
Epoch 900, Loss: 0.2648167908191681
Predicted state: tensor([[-0.2604,  0.2214,  0.5066]])

Process finished with exit code 0


http://www.kler.cn/a/394964.html

相关文章:

  • 【Qt聊天室客户端】消息功能--发布程序
  • AutoDL远程连接技巧
  • 开源项目推荐——OpenDroneMap无人机影像数据处理
  • RabbitMQ高效的消息队列中间件原理及实践
  • 深度学习——优化算法、激活函数、归一化、正则化
  • 排序算法 - 冒泡
  • layui的table组件中,对某一列的文字设置颜色为浅蓝怎么设置
  • anzocapital 昂首资本:外汇机器人趋势判断秘籍
  • 108. UE5 GAS RPG 实现地图名称更新和加载关卡
  • 爱普生机器人EPSON RC
  • python贪心算法实现(纸币找零举例)
  • DNS解析 附实验:DNS正反向解析
  • C++常用的特性-->day05
  • 【JavaEE进阶】Spring AOP 原理
  • vue3【组件封装】S-icon 图标 ( 集成 iconify )
  • 删库跑路,启动!
  • 三:网络为什么要分层:OSI模型与TCP/IP模型
  • 北京大学c++程序设计听课笔记101
  • 握手协议是如何在SSL VPN中发挥作用的?
  • torch.nn.**和torch.nn.functional.**的区别
  • 同局域网ssh连接wsl2
  • 鸿蒙NEXT开发案例:光强仪
  • 【数学二】线性代数-二次型
  • 基于STM32设计的矿山环境监测系统(NBIOT)_262
  • 机器学习——30种常见机器学习算法简要汇总
  • Ue5 umg学习(一)