当前位置: 首页 > article >正文

lstm实践

今年华为杯研究生数学建模的C题第四问用到了lstm,这里配合代码简要地讲一下。

数据类型

磁通密度是一个时序数据,包含了一个周期内的磁通密度变化,我们需要对它进行降维,但PCA是不合适的,因为PCA主要关注数据的方差,无法有效捕捉周期性数据的重要特征,而磁通密度是周期性变化的。

在自然语言处理领域中,LSTM可以捕获序列内部元素之间的关联性,并且其隐藏层可以包含前序序列的信息。最后一层的隐藏层就包含了整个序列的信息,所以我们可以将最后一层的隐藏层作为降维后的向量。

我们选择LSTM对1024 维的磁通密度进行降维,具体做法是:训练时对一个周期进行切片,使用LSTM预测切片的下一时刻的磁通密度;降维时使用整个周期,获取最后一 层的hidden state作为该样本的磁通密度特征。

代码

1.数据处理

import pandas as pd
import torch as pt
import os
os.chdir('/home/burger/math/')


df1 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料1')
df2 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料2')
df3 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料3')
df4 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料4')

collom_name = [i for i in range(1,1024)]
B1 = df1[['0(磁通密度B,T)']+collom_name]
B2 = df2[['0(磁通密度,T)']+collom_name]
B3 = df3[['0(磁通密度B,T)']+collom_name]
B4 = df4[['0(磁通密度B,T)']+collom_name]
print(B1.head())

B1_t = pt.tensor(B1.values)
B2_t = pt.tensor(B2.values)
B3_t = pt.tensor(B3.values)
B4_t = pt.tensor(B4.values)
B = pt.cat((B1_t, B2_t, B3_t, B4_t), 0)
print(B.shape)

def create_dataset(data, time_step=64):  
    x, y = [], []  
    for i in range(0, data.shape[1] - time_step, 32):  
        a = data[:, i:(i + time_step)]  
        x.append(a)  
        y.append(data[:, i + time_step])  
    return pt.concat(x).float(), pt.concat(y).float()

X, Y = create_dataset(B)
print(X.shape, Y.shape)
X, Y = X.unsqueeze(-1), Y.unsqueeze(-1)
print(X.shape, Y.shape)

 2.模型定义

import torch.nn as nn 
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm


# LSTM模型定义
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        out = self.fc(lstm_out[:, -1, :])  # 只取最后一个时间步的输出
        return out
    
    def embedding(self, x):
        _, hid_cell = self.lstm(x)
        return hid_cell[0]

3.训练

input_size = 1
hidden_size = 1
output_size = 1
num_layers = 1
num_epochs = 5
batch_size = 2048
gpu = 6
train_dataset = TensorDataset(X.to(gpu), Y.to(gpu))
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

model = LSTMModel(input_size, hidden_size, output_size, num_layers).to(gpu)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


def train(model, dataloader, criterion, optimizer, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        for inputs, labels in tqdm(dataloader, unit='batch'):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
train(model, train_loader, criterion, optimizer, num_epochs)

4.降维

df_san = pd.read_excel('./data/附件三(测试集).xlsx')
B_san = df_san[['0(磁通密度B,T)']+collom_name]
B_san_t = pt.tensor(B_san.values)
B_san_t = B_san_t.unsqueeze(-1).float().to(gpu)
print(B_san_t.shape)

emb_dataset = TensorDataset(B_san_t)
emb_loader = DataLoader(dataset=emb_dataset, batch_size=400, shuffle=False)

def embedding(model, dataloader):
    model.eval()
    embeddings = []
    for inputs in tqdm(dataloader, unit='batch'):
        outputs = model.embedding(inputs[0])
        embeddings.append(outputs)
    return pt.cat(embeddings).cpu().detach().numpy()

embeddings = embedding(model, emb_loader)
print(embeddings.shape)
embeddings = embeddings.reshape(-1)
print(embeddings.shape)

embeddings_df = pd.DataFrame({
    '磁通密度编码': embeddings,
})
embeddings_df.to_excel('./data/磁通密度编码1.xlsx', index=False)

有问题欢迎在评论区讨论!


http://www.kler.cn/news/329645.html

相关文章:

  • 如何在 Windows 10 上恢复未保存/删除的 Word 文档
  • C++ 学习,标准库
  • 结构光编解码—正反格雷码解码代码
  • SQL_create_view
  • VR、AR、MR、XR 领域最新科研资讯获取指南
  • CSS链接
  • 查找与排序-快速排序
  • 数造科技入选中国信通院《高质量数字化转型产品及服务全景图》三大板块
  • OpenCV透视变换:原理、应用与实现
  • Mysql 学习——项目实战
  • 企业级版本管理工具(1)----Git
  • WPF之UI进阶--完整了解wpf的控件和布局容器及应用
  • 栏目一:使用echarts绘制简单图形
  • HttpSession使用方法及原理
  • .c、.cpp、.cc、.cxx、.cp后缀的区别
  • YOLOv8改进,YOLOv8改进主干网络为GhostNetV3(2024年华为的轻量化架构,全网首发),助力涨点
  • C++ STL(3)list
  • 卡夫卡的理解
  • 事务原理,以及MVCC如何实现RC,RR隔离级别的
  • 告别PPT熬夜!Kimi+AIPPT一键生成PPT,效率upup!
  • Docker全家桶:从0到加载本地项目
  • docker 部署 Seatunnel 和 Seatunnel Web
  • 浏览器用户行为集群建设-数仓建模-数据计算
  • 828华为云征文|华为云Flexus云服务器X实例搭建部署H5美妆护肤分销商城、前端uniapp
  • pytorch千问模型源码分析
  • leetcode.每日一题.2516.每种字符至少取 K 个
  • 【C++】C++基础
  • 魔都千丝冥缘——软件终端架构思维———未来之窗行业应用跨平台架构
  • D21【python接口自动化学习】-python基础之内置数据类型
  • Git记录