当前位置: 首页 > article >正文

算法面经手撕系列(3)--手撕LayerNormlization

LayerNormlization

 在许多的语言模型如Bert里,虽然都是说做的LayerNormlization,但计算均值和方差只会沿着channel维度做,并不是沿着seq_L和channel维度一起做,参考:BERT用的LayerNorm可能不是你认为的那个Layer Norm
 LayerNormlization计算流程:

  1. init里初始化C_in大小的scale和shift向量
  2. 沿Channel维度计算均值和方差
  3. 归一化

代码

 LayerNorm(InstanceNorm)实现如下:

class LayerNormalization(nn.Module):
    def __init__(self,hidden_dim,eps=1e-6):
        super(LayerNormalization, self).__init__()

        self.eps=eps
        self.gamma=nn.Parameter(torch.ones(hidden_dim))
        self.beta=nn.Parameter(torch.zeros(hidden_dim))

    def forward(self,x):
        B,seq_L,C=x.shape

        mean=x.mean(dim=-1,keepdim=True)
        std=x.std(dim=-1,keepdim=True)

        out=(x-mean)/(std+self.eps)
        out=out*self.gamma+self.beta
        return out
if __name__=="__main__":
 tensor_input=torch.rand(5,10,8)
 model=LayerNormalization(8)
 res=model(tensor_input)
 print(res)


http://www.kler.cn/news/308145.html

相关文章:

  • 【算法】滑动窗口—最小覆盖子串
  • MyBatis的配置文件详解
  • druid jdbc 执行 sql 输出 开销耗时
  • Linux下抓包分析Java应用程序HTTP接口调用:基于tcpdump与Wireshark的综合示例
  • 秒验HarmonyOS NEXT集成指南
  • ERP进销存管理系统的业务全流程 Axure高保真原型源文件分享
  • 仪表盘检测系统源码分享
  • Ubuntu 20.04 部署 NET8 Web - Systemd 的方式 达到外网访问的目的
  • 【运维监控】influxdb 2.0 + grafana 11 监控jmeter 5.6.3 性能指标(2)
  • Git进阶(十五):Git LFS 使用详解
  • Leetcode—740. 删除并获得点数【中等】(unordered_map+set+sort)
  • python提取pdf表格到excel:拆分、提取、合并
  • LLM - 理解 多模态大语言模型 (MLLM) 的预训练与相关技术 (三)
  • S-Procedure的基本形式及使用
  • 补题篇--codeforces
  • 安卓将本地日志上传到服务器
  • C语言 | Leetcode C语言题解之题409题最长回文串
  • 深入理解Appium定位策略与元素交互
  • 使用原生HTML的drag实现元素的拖拽
  • Linux C execv/execl函数调用 bash -c
  • 【疑难杂症2024-005】docker-compose中设置容器的ip为固定ip后,服务无法启动
  • supermap iclient3d for cesium中entity使用
  • 【梯度下降|链式法则】卷积神经网络中的参数是如何传输和更新的?
  • 常用压接线端子教程
  • 力扣爆刷第176天之贪心全家桶(共15道题)
  • Java 入门指南:JVM(Java虚拟机)垃圾回收机制 —— 内存分配和回收规则
  • Linux 基础入门操作-实验二 makefile使用介绍 和 实验三 hello 输出
  • 【计算机网络】HTTP相关问题与解答
  • 深度学习:入门简介
  • ESP01的AT指令连接到阿里云平台