当前位置: 首页 > article >正文

【作业】LSTM

目录

习题6-4  推导LSTM网络中参数的梯度, 并分析其避免梯度消失的效果

​编辑

习题6-3P  编程实现下图LSTM运行过程

 1. 使用Numpy实现LSTM算子

2. 使用nn.LSTMCell实现 

3. 使用nn.LSTM实现

参考链接


 

习题6-4  推导LSTM网络中参数的梯度, 并分析其避免梯度消失的效果

LSTM框架如下: 

 

 

 

 

        总而言之:LSTM遗忘门值可以选择在[0,1]之间,让LSTM来改善梯度消失的情况。也可以选择接近1,让遗忘门饱和,此时远距离信息梯度不消失。也可以选择接近0,此时模型是故意阻断梯度流,遗忘之前信息。 

习题6-3P  编程实现下图LSTM运行过程

1. 使用Numpy实现LSTM算子

2. 使用nn.LSTMCell实现

3. 使用nn.LSTM实现

 1. 使用Numpy实现LSTM算子

import numpy as np

#定义激活函数
def sigmoid(x):
    return 1/(1+np.exp(-x))

def tanh(x):
    return (np.exp(x)-np.exp(-x))/(np.exp(x)+np.exp(-x))

#权重
input_weight=np.array([1,0,0,0])
inputgate_weight=np.array([0,100,0,-10])
forgetgate_weight=np.array([0,100,0,10])
outputgate_weight=np.array([0,0,100,-10])

#输入
input=np.array([[1,0,0,1],[3,1,0,1],[2,0,0,1],[4,1,0,1],[2,0,0,1],[1,0,1,1],[3,-1,0,1],[6,1,0,1],[1,0,1,1]])

y=[]   #输出
c_t=0  #内部状态

for x in input:
    g_t=tanh(np.matmul(input_weight,x)) #候选状态
    i_t=np.round(sigmoid(np.matmul(inputgate_weight,x)))  #输入门
    after_inputgate=g_t*i_t       #候选状态经过输入门
    f_t=np.round(sigmoid(np.matmul(forgetgate_weight,x))) #遗忘门
    after_forgetgate=f_t*c_t      #内部状态经过遗忘门
    c_t=np.add(after_inputgate,after_forgetgate) #新的内部状态
    o_t=np.round(sigmoid(np.matmul(outputgate_weight,x))) #输出门
    after_outputgate=o_t*tanh(c_t)     #激活后新的内部状态经过输出门
    y.append(round(after_outputgate,2))   #输出

print('输出:',y)

输出:

output:[0.0, 0.0, 0.0, 0.0, 0.0, 0.96, 0.0, 0.0, 0.76]

 

2. 使用nn.LSTMCell实现 

import torch
import torch.nn as nn

#实例化
input_size=4
hidden_size=1
cell=nn.LSTMCell(input_size=input_size,hidden_size=hidden_size)
#修改模型参数 weight_ih.shape=(4*hidden_size, input_size),weight_hh.shape=(4*hidden_size, hidden_size),
#weight_ih、weight_hh分别为输入x、隐层h分别与输入门、遗忘门、候选、输出门的权重
cell.weight_ih.data=torch.tensor([[0,100,0,-10],[0,100,0,10],[1,0,0,0],[0,0,100,-10]],dtype=torch.float32)
cell.weight_hh.data=torch.zeros(4,1)
print('cell.weight_ih.shape:',cell.weight_ih.shape)
print('cell.weight_hh.shape',cell.weight_hh.shape)
#初始化h_0,c_0
h_t=torch.zeros(1,1)
c_t=torch.zeros(1,1)
#模型输入input_0.shape=(batch,seq_len,input_size)
input_0=torch.tensor([[[1,0,0,1],[3,1,0,1],[2,0,0,1],[4,1,0,1],[2,0,0,1],[1,0,1,1],[3,-1,0,1],[6,1,0,1],[1,0,1,1]]],dtype=torch.float32)
#交换前两维顺序,方便遍历input.shape=(seq_len,batch,input_size)
input=torch.transpose(input_0,1,0)
print('input.shape:',input.shape)
output=[]
#调用
for x in input:
    h_t,c_t=cell(x,(h_t,c_t))
    output.append(np.around(h_t.item(), decimals=3))#保留3位小数
print('output:',output)

输出:

output:[0.0, 0.0, 0.0, 0.0, 0.0, 0.96, 0.0, 0.0, 0.76]

 

3. 使用nn.LSTM实现

#LSTM
#实例化
input_size=4
hidden_size=1
lstm=nn.LSTM(input_size=input_size,hidden_size=hidden_size,batch_first=True)
#修改模型参数
lstm.weight_ih_l0.data=torch.tensor([[0,100,0,-10],[0,100,0,10],[1,0,0,0],[0,0,100,-10]],dtype=torch.float32)
lstm.weight_hh_l0.data=torch.zeros(4,1)
#模型输入input.shape=(batch,seq_len,input_size)
input=torch.tensor([[[1,0,0,1],[3,1,0,1],[2,0,0,1],[4,1,0,1],[2,0,0,1],[1,0,1,1],[3,-1,0,1],[6,1,0,1],[1,0,1,1]]],dtype=torch.float32)
#初始化h_0,c_0
h_t=torch.zeros(1,1,1)
c_t=torch.zeros(1,1,1)
#调用
output,(h_t,c_t)=lstm(input,(h_t,c_t))
rounded_output = torch.round(output * 1000) / 1000  # 保留3位小数
print(rounded_output)

 输出结果

output:[0.0, 0.0, 0.0, 0.0, 0.0, 0.96, 0.0, 0.0, 0.7672]

 

参考链接

LSTM参数梯度推导与实现:对抗梯度消失,

LSTM参数梯度推导与编程实现,

李宏毅机器学习笔记:RNN循环神经网络_李宏毅机器学习课程笔记-CSDN博客

HBU-NNDL 作业10:第六章课后题(LSTM | GRU)-CSDN博客


http://www.kler.cn/a/452061.html

相关文章:

  • Websocket客户端从Openai Realtime api Sever只收到部分数据问题分析
  • Ruby+Selenium教程
  • 【Linux】ChatGLM-4-9B模型之All Tools
  • GTID下复制问题和解决
  • SecureCRT汉化版
  • Git远程仓库的多人协作
  • Linux系统编程之目录操作
  • 如何确保数据大屏的交互设计符合用户需求?(附实践资料下载)
  • 3D制图软件SOLIDWORKS:开启高效设计与创新的新纪元
  • python爬虫爬抖音小店商品数据+数据可视化
  • 自动驾驶AVM环视算法--python版本的车轮投影模式
  • 云手机与Temu矩阵:跨境电商运营新引擎
  • 从零到一:如何快速生成和优化Prompt
  • Webpack学习笔记(5)
  • Ubuntu vi(vim)编辑器配置一键补全main函数
  • Linux使用教程及常用命令大全
  • 【西安电子科技大学考研】25官方复试专业课参考书目汇总
  • “家政服务匹配”:家政公司服务平台的智能推荐系统
  • 深度学习笔记2:使用pytorch构建神经网络
  • 服务器压力测试怎么做
  • 如何借助 Babel+TS+ESLint 构建现代 JS 工程环境?
  • 鸿蒙UI开发——全局自定义弹窗实现
  • 选择排序:简单算法的实现与优化探索
  • nvidia docker, nvidia docker2, nvidia container toolkits区别
  • 数据格式之-XML数据查询语句xpath介绍
  • LLaMA-Factory GLM4-9B-CHAT LoRA 微调实战