当前位置: 首页 > article >正文

MUSE: PARALLEL MULTI-SCALE ATTENTION FOR SEQUENCE TO SEQUENCE LEARNING 笔记

来源:

MUSE: Parallel Multi-Scale Attention for Sequence to Sequence Learning

相关工作:

#自注意力机制 #卷积神经网络 #动态卷积 #局部注意力 #并行计算

创新点:

3iosf5sv.nnz.png

j0wx3bfj.zhc.png

izxiqlqk.5ib.png

bhncuctb.pwa.png

0lis4twz.dfp.png

贡献:

  1. 并行多尺度注意力机制

    • MUSE引入了并行多尺度注意力机制,该机制同时捕获序列数据中的长距离和短距离语言结构。这是通过对序列进行不同尺度的并行编码来实现的,利用自注意力和逐点变换。
  2. 结合卷积和自注意力

    • 为了解决自注意力在处理长序列时可能忽略局部信息的问题,MUSE模型将卷积操作和自注意力结合起来,以更有效地学习序列的局部和全局特征。
  3. 共享投影空间

    • 文章强调了在自注意力和卷积操作中使用共享投影空间的重要性。这种设计使得两种操作能够在相同的隐藏空间中进行,从而有助于整合局部和全局特征表示。
  4. 动态卷积核选择

    • MUSE引入了一种门控机制,用于自动选择不同卷积单元的权重,这允许模型动态地选择最适合当前层的卷积核大小。

代码:


# ---------------------------------------  
# 论文: MUSE: PARALLEL MULTI-SCALE ATTENTION FOR SEQUENCE TO SEQUENCE LEARNING (arxiv 2019)  
# ---------------------------------------  
import numpy as np  
import torch  
from torch import nn  
from torch.nn import init  
  
  
class Depth_Pointwise_Conv1d(nn.Module):  
    def __init__(self, in_ch, out_ch, k):  
        super().__init__()  
        if (k == 1):  
            self.depth_conv = nn.Identity()  
        else:  
            self.depth_conv = nn.Conv1d(  
                in_channels=in_ch,  
                out_channels=in_ch,  
                kernel_size=k,  
                groups=in_ch,  
                padding=k // 2  
            )  
        self.pointwise_conv = nn.Conv1d(  
            in_channels=in_ch,  
            out_channels=out_ch,  
            kernel_size=1,  
            groups=1  
        )  
  
    def forward(self, x):  
        out = self.pointwise_conv(self.depth_conv(x))  
        return out  
  
  
class MUSEAttention(nn.Module):  
  
    def __init__(self, d_model, d_k, d_v, h, dropout=.1):  
  
        super(MUSEAttention, self).__init__()  
        self.fc_q = nn.Linear(d_model, h * d_k)  
        self.fc_k = nn.Linear(d_model, h * d_k)  
        self.fc_v = nn.Linear(d_model, h * d_v)  
        self.fc_o = nn.Linear(h * d_v, d_model)  
        self.dropout = nn.Dropout(dropout)  
  
        self.conv1 = Depth_Pointwise_Conv1d(h * d_v, d_model, 1)  
        self.conv3 = Depth_Pointwise_Conv1d(h * d_v, d_model, 3)  
        self.conv5 = Depth_Pointwise_Conv1d(h * d_v, d_model, 5)  
        self.dy_paras = nn.Parameter(torch.ones(3))  
        self.softmax = nn.Softmax(-1)  
  
        self.d_model = d_model  
        self.d_k = d_k  
        self.d_v = d_v  
        self.h = h  
  
        self.init_weights()  
  
    def init_weights(self):  
        for m in self.modules():  
            if isinstance(m, nn.Conv2d):  
                init.kaiming_normal_(m.weight, mode='fan_out')  
                if m.bias is not None:  
                    init.constant_(m.bias, 0)  
            elif isinstance(m, nn.BatchNorm2d):  
                init.constant_(m.weight, 1)  
                init.constant_(m.bias, 0)  
            elif isinstance(m, nn.Linear):  
                init.normal_(m.weight, std=0.001)  
                if m.bias is not None:  
                    init.constant_(m.bias, 0)  
  
    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):  
  
        # Self Attention  
        b_s, nq = queries.shape[:2]  
        nk = keys.shape[1]  
  
        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)  
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)  
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)  
  
        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)  
        if attention_weights is not None:  
            att = att * attention_weights  
        if attention_mask is not None:  
            att = att.masked_fill(attention_mask, -np.inf)  
        att = torch.softmax(att, -1)  
        att = self.dropout(att)  
  
        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)  
        out = self.fc_o(out)  # (b_s, nq, d_model)  
  
        v2 = v.permute(0, 1, 3, 2).contiguous().view(b_s, -1, nk)  # bs,dim,n  
        self.dy_paras = nn.Parameter(self.softmax(self.dy_paras))  
        out2 = self.dy_paras[0] * self.conv1(v2) + self.dy_paras[1] * self.conv3(v2) + self.dy_paras[2] * self.conv5(v2)  
        out2 = out2.permute(0, 2, 1)  # bs.n.dim  
  
        out = out + out2  
        return out  
  
  
# 输入 B N C,  输出 B N Cif __name__ == '__main__':  
    block = MUSEAttention(d_model=32, d_k=32, d_v=32, h=8).cuda()  
    print("开始运行")  
    input = torch.rand(3, 64, 32).cuda()  
    output = block(input, input, input)  
    print(input.size(), output.size())  
    print('结束运行')


http://www.kler.cn/a/517510.html

相关文章:

  • Python Pandas数据清洗与处理
  • Spring 源码学习(七)——注解后处理器-2
  • 【数据结构】_不带头非循环单向链表
  • vue3中自定一个组件并且能够用v-model对自定义组件进行数据的双向绑定
  • 关于deepin上运行Qt开发的程序
  • 利用 SAM2 模型探测卫星图像中的农田边界
  • Go语言中变量在栈和堆上分配情况分析
  • 论文:深度可分离神经网络存内计算处理芯片
  • [MySQL]数据库表内容的增删查改操作大全
  • Word 中实现方框内点击自动打 √ ☑
  • -bash: ./uninstall.command: /bin/sh^M: 坏的解释器: 没有那个文件或目录
  • Kotlin泛型学习篇
  • 机器学习-线性回归(参数估计之经验风险最小化)
  • Hive之加载csv格式数据到hive
  • 设计模式的艺术-命令模式
  • 嵌入式知识点总结 ARM体系与架构 专题提升(四)-编程
  • 【Java】阿里云OSS上传、删除文件
  • git基础使用命令
  • YOLOv10-1.1部分代码阅读笔记-val.py
  • 《罗宾逊-旅途VR》Build2108907官方学习版
  • Oracle 机器宕机之后启动数据库
  • 大数据,Hadoop,HDFS的简单介绍
  • 从根源分析,调试,定位和解决MacOS ld: unsupported tapi file type ‘!tapi-tbd‘ in YAML file
  • Leecode刷题C语言之购买水果需要的最小金币数
  • 【实践】Python实现气象数据分析与可视化
  • Ubuntu 安装 QGIS LTR 3.34