当前位置: 首页 > article >正文

算法面经手撕系列(3)--手撕LayerNormlization

LayerNormlization

 在许多的语言模型如Bert里,虽然都是说做的LayerNormlization,但计算均值和方差只会沿着channel维度做,并不是沿着seq_L和channel维度一起做,参考:BERT用的LayerNorm可能不是你认为的那个Layer Norm
 LayerNormlization计算流程:

  1. init里初始化C_in大小的scale和shift向量
  2. 沿Channel维度计算均值和方差
  3. 归一化

代码

 LayerNorm(InstanceNorm)实现如下:

class LayerNormalization(nn.Module):
    def __init__(self,hidden_dim,eps=1e-6):
        super(LayerNormalization, self).__init__()

        self.eps=eps
        self.gamma=nn.Parameter(torch.ones(hidden_dim))
        self.beta=nn.Parameter(torch.zeros(hidden_dim))

    def forward(self,x):
        B,seq_L,C=x.shape

        mean=x.mean(dim=-1,keepdim=True)
        std=x.std(dim=-1,keepdim=True)

        out=(x-mean)/(std+self.eps)
        out=out*self.gamma+self.beta
        return out
if __name__=="__main__":
 tensor_input=torch.rand(5,10,8)
 model=LayerNormalization(8)
 res=model(tensor_input)
 print(res)


http://www.kler.cn/a/308145.html

相关文章:

  • Elasticsearch 8.16:适用于生产的混合对话搜索和创新的向量数据量化,其性能优于乘积量化 (PQ)
  • 【算法一周目】双指针(2)
  • 基于碎纸片的拼接复原算法及MATLAB实现
  • SpringBoot参数注解
  • C语言 | Leetcode C语言题解之第557题反转字符串中的单词III
  • HAproxy 详解
  • 【算法】滑动窗口—最小覆盖子串
  • MyBatis的配置文件详解
  • druid jdbc 执行 sql 输出 开销耗时
  • Linux下抓包分析Java应用程序HTTP接口调用:基于tcpdump与Wireshark的综合示例
  • 秒验HarmonyOS NEXT集成指南
  • ERP进销存管理系统的业务全流程 Axure高保真原型源文件分享
  • 仪表盘检测系统源码分享
  • Ubuntu 20.04 部署 NET8 Web - Systemd 的方式 达到外网访问的目的
  • 【运维监控】influxdb 2.0 + grafana 11 监控jmeter 5.6.3 性能指标(2)
  • Git进阶(十五):Git LFS 使用详解
  • Leetcode—740. 删除并获得点数【中等】(unordered_map+set+sort)
  • python提取pdf表格到excel:拆分、提取、合并
  • LLM - 理解 多模态大语言模型 (MLLM) 的预训练与相关技术 (三)
  • S-Procedure的基本形式及使用
  • 补题篇--codeforces
  • 安卓将本地日志上传到服务器
  • C语言 | Leetcode C语言题解之题409题最长回文串
  • 深入理解Appium定位策略与元素交互
  • 使用原生HTML的drag实现元素的拖拽
  • Linux C execv/execl函数调用 bash -c