当前位置: 首页 > article >正文

Axial Transformer笔记

来源:

[1912.12180] Axial Attention in Multidimensional Transformers

相关工作:

#transformer #Attention-mechanism

创新点:

9cdd0bc518a125ebbb38f671c225dd30.png

6dac50b223678783969f829540402519.png

8a82d0c0e641c6385773f29b049c5165.png

贡献:

  1. 轴向注意力(Axial Attention):每次只沿多维张量的单个轴应用注意力机制,而不是把整个张量展平成一维序列。例如,在图像中分别沿行和列应用注意力;对于含 N = H×W 个像素的方形图像(H = W = √N),计算复杂度从全注意力的 O(N²) 降低到 O(N√N)。

  2. 模型结构:Axial Transformer 由内解码器(Inner Decoder)和外解码器(Outer Decoder)组成。内解码器负责当前行内的建模(通过掩码行注意力,只关注该行中已生成的像素);外解码器通过轴向注意力层汇总之前所有行的信息,并作为上下文提供给内解码器,从而在保持自回归性质的同时实现全局上下文建模。

  3. 多通道图像和视频建模:通过额外的轴向注意力层对之前通道进行编码,为当前通道建模提供上下文信息。

代码:

# ---------------------------------------  
# 论文:Axial Attention in Multidimensional Transformers. ArXiv, abs/1912.12180.  
# ---------------------------------------  
  
import torch  
from torch import nn  
from operator import itemgetter  
from torch.autograd.function import Function  
from torch.utils.checkpoint import get_device_states, set_device_states  
  
  
class Deterministic(nn.Module):
    """Wrap a module so a forward pass can be replayed with the same RNG state.

    Call with ``record_rng=True`` to snapshot the RNG state before running the
    wrapped ``net``. Later, call with ``set_rng=True`` to re-run ``net`` under
    that recorded state, so stochastic layers (e.g. dropout) behave exactly as
    they did in the recorded call. The replay happens inside
    ``torch.random.fork_rng``, so the caller's global RNG state is restored
    afterwards.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net
        # RNG snapshots, filled in by ``record_rng``.
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        """Snapshot the CPU RNG state and, when CUDA is initialized, the RNG
        states of the CUDA devices holding the tensors in ``args``."""
        self.cpu_state = torch.get_rng_state()
        # Public API rather than the private ``torch.cuda._initialized`` flag.
        if torch.cuda.is_initialized():
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng=False, set_rng=False, **kwargs):
        """Run the wrapped net.

        Args:
            *args: positional inputs for ``net``; also used to locate the CUDA
                devices whose RNG state is recorded.
            record_rng: snapshot the RNG state before running ``net``.
            set_rng: run ``net`` under the previously recorded RNG state.
            **kwargs: keyword inputs for ``net``.

        Returns:
            Whatever ``net`` returns.
        """
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = self.gpu_devices if self.cuda_in_fwd else []

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)
  
  
# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py  
# once multi-GPU is confirmed working, refactor and send PR back to source  
class ReversibleBlock(nn.Module):
    """Additive-coupling reversible residual block.

    The input is split into two chunks along dim 1 and mapped as::

        y1 = x1 + f(x2)
        y2 = x2 + g(y1)

    ``forward`` runs without recording an autograd graph. ``backward_pass``
    later inverts the mapping to rebuild the input from the output and
    re-runs ``f`` and ``g`` to compute gradients.
    """

    def __init__(self, f, g):
        super().__init__()
        # Deterministic wrappers let backward_pass replay f/g with the RNG
        # state recorded during forward.
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args=None, g_args=None):
        """Apply the coupling under ``torch.no_grad``.

        Args:
            x: tensor split into two chunks along dim 1 (via ``torch.chunk``).
            f_args: optional extra keyword arguments for ``f``.
            g_args: optional extra keyword arguments for ``g``.

        Returns:
            ``torch.cat([y1, y2], dim=1)``.
        """
        # None defaults instead of shared mutable ``{}`` defaults.
        f_args = {} if f_args is None else f_args
        g_args = {} if g_args is None else g_args

        x1, x2 = torch.chunk(x, 2, dim=1)

        # No graph is built here; gradients come from backward_pass.
        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=1)

    def backward_pass(self, y, dy, f_args=None, g_args=None):
        """Reconstruct this block's input from its output and backpropagate.

        Inverts the coupling (``x2 = y2 - g(y1)``, then ``x1 = y1 - f(x2)``)
        while re-running ``g`` and ``f`` under their recorded RNG state. The
        ``torch.autograd.backward`` calls also accumulate gradients into the
        parameters of ``g`` and ``f``.

        Args:
            y: this block's output.
            dy: gradient of the loss with respect to ``y``.
            f_args: optional extra keyword arguments for ``f``.
            g_args: optional extra keyword arguments for ``g``.

        Returns:
            ``(x, dx)``: the reconstructed input and the gradient w.r.t. it.
        """
        f_args = {} if f_args is None else f_args
        g_args = {} if g_args is None else g_args

        y1, y2 = torch.chunk(y, 2, dim=1)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=1)
        del dy

        # Re-run g on y1 with a graph so dy2 can be pushed through it; this
        # fills y1.grad (and g's parameter grads).
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            # y1 reaches the loss directly and through g: total gradient is
            # dy1 plus g's contribution. Since y1 = x1 + f(x2), this is also
            # the gradient w.r.t. x1.
            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # Same procedure for f on the reconstructed x2.
        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=1)
            dx = torch.cat([dx1, dx2], dim=1)

        return x, dx
  
  
class IrreversibleBlock(nn.Module):
    """Additive coupling computed with ordinary autograd.

    The input is split into two chunks along dim 1 and mapped as
    ``y1 = x1 + f(x2)``, ``y2 = x2 + g(y1)``. Activations are kept by autograd
    as usual, so no custom backward is needed.
    """

    def __init__(self, f, g):
        super().__init__()
        self.f = f
        self.g = g

    def forward(self, x, f_args=None, g_args=None):
        """Apply the coupling.

        Args:
            x: tensor split into two chunks along dim 1 (via ``torch.chunk``).
            f_args: optional extra keyword arguments for ``f``.
            g_args: optional extra keyword arguments for ``g``.

        Returns:
            ``torch.cat([y1, y2], dim=1)``.
        """
        # Optional, like ReversibleBlock.forward; previously both were required.
        f_args = {} if f_args is None else f_args
        g_args = {} if g_args is None else g_args

        x1, x2 = torch.chunk(x, 2, dim=1)
        y1 = x1 + self.f(x2, **f_args)
        y2 = x2 + self.g(y1, **g_args)
        return torch.cat([y1, y2], dim=1)
  
  
class _ReversibleFunction(Function):
    """Custom autograd Function that runs a chain of reversible blocks.

    Forward keeps only the final output (detached) on ``ctx`` rather than
    every intermediate activation. Backward walks the blocks in reverse,
    calling each block's ``backward_pass`` to rebuild its input from its
    output and to propagate the gradient.
    """

    @staticmethod
    def forward(ctx, x, blocks, kwargs):
        # kwargs: dict of keyword arguments passed to every block call and
        # every backward_pass call.
        ctx.kwargs = kwargs
        for block in blocks:
            x = block(x, **kwargs)
        # Only the detached final output is stored; earlier activations are
        # reconstructed block by block during backward.
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        kwargs = ctx.kwargs
        # Reverse order: each block turns (its output, grad) into
        # (its input, grad w.r.t. its input).
        for block in ctx.blocks[::-1]:
            y, dy = block.backward_pass(y, dy, **kwargs)
        # A gradient is returned only for x; blocks and kwargs get None.
        return dy, None, None
  
  
class ReversibleSequence(nn.Module):
    """Run ``(f, g)`` pairs as a chain of reversible coupling blocks.

    The input is duplicated along dim 1 to create the two coupling streams.
    After the chain, the two halves are averaged back to the input width.
    """

    def __init__(self, blocks, ):
        super().__init__()
        wrapped = [ReversibleBlock(f, g) for f, g in blocks]
        self.blocks = nn.ModuleList(wrapped)

    def forward(self, x, arg_route=(True, True), **kwargs):
        """Apply the reversible chain.

        Args:
            x: input tensor.
            arg_route: two flags; the first decides whether ``kwargs`` are
                forwarded to every ``f``, the second to every ``g``.
            **kwargs: extra keyword arguments routed per ``arg_route``.
        """
        route_f, route_g = arg_route
        block_kwargs = {
            'f_args': kwargs if route_f else {},
            'g_args': kwargs if route_g else {},
        }

        doubled = torch.cat((x, x), dim=1)
        out = _ReversibleFunction.apply(doubled, self.blocks, block_kwargs)

        halves = out.chunk(2, dim=1)
        return torch.stack(halves).mean(dim=0)
  
  
# helper functions  
  
def exists(val):
    """Return True when *val* is anything other than None."""
    if val is None:
        return False
    return True
  
  
def map_el_ind(arr, ind):
    """Return a list holding element ``ind`` of each item in ``arr``."""
    return [item[ind] for item in arr]
  
  
def sort_and_return_indices(arr):
    """Sort ``arr`` and report the original position of each sorted value.

    Returns:
        ``(sorted_values, original_indices)`` as two lists. Ties are broken by
        original position. When ``arr`` is a permutation of ``0..n-1``, the
        second list is its inverse permutation.
    """
    pairs = sorted((value, position) for position, value in enumerate(arr))
    sorted_values = [value for value, _ in pairs]
    original_indices = [position for _, position in pairs]
    return sorted_values, original_indices
  
  
# calculates the permutation to bring the input tensor to something attend-able  
# also calculates the inverse permutation to bring the tensor back to its original shape  
  
def calculate_permutations(num_dimensions, emb_dim):
    """Build one permutation per axial dim of a ``num_dimensions + 2``-dim tensor.

    Each permutation moves one axial dim and then the embedding dim to the
    last two positions. All remaining dims keep their original ascending
    order. Index 0 (presumably the batch dim) is never treated as axial.

    Args:
        num_dimensions: number of axial dims.
        emb_dim: index of the embedding dim; non-positive values are offset by
            the total number of dims (so -1 means the last dim).

    Returns:
        A list with one permutation (list of ints) per axial dim.
    """
    total_dimensions = num_dimensions + 2
    emb_dim = emb_dim if emb_dim > 0 else (emb_dim + total_dimensions)
    axial_dims = [ind for ind in range(1, total_dimensions) if ind != emb_dim]

    permutations = []
    for axial_dim in axial_dims:
        last_two_dims = [axial_dim, emb_dim]
        # Explicit ordered list instead of iterating a set, whose iteration
        # order is an implementation detail rather than a language guarantee.
        dims_rest = [ind for ind in range(total_dimensions) if ind not in last_two_dims]
        permutations.append([*dims_rest, *last_two_dims])

    return permutations
  
  
# helper classes  
  
class ChanLayerNorm(nn.Module):
    """Layer normalization across dim 1 (the channel dim).

    Mean and (biased) variance are computed over dim 1 only. ``eps`` is added
    to the standard deviation, not to the variance. The learnable scale ``g``
    and shift ``b`` have shape ``(1, dim, 1, 1)``, which assumes a 4-D,
    channel-second input (e.g. N C H W).
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        param_shape = (1, dim, 1, 1)
        self.g = nn.Parameter(torch.ones(*param_shape))
        self.b = nn.Parameter(torch.zeros(*param_shape))

    def forward(self, x):
        mu = x.mean(dim=1, keepdim=True)
        sigma = x.var(dim=1, unbiased=False, keepdim=True).sqrt()
        normed = (x - mu) / (sigma + self.eps)
        return normed * self.g + self.b
  
  
class PreNorm(nn.Module):
    """Normalize the last dim with ``nn.LayerNorm(dim)``, then apply ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))
  
  
class Sequential(nn.Module):
    """Plain residual execution of ``(f, g)`` pairs.

    For each pair it applies ``x = x + f(x)`` and then ``x = x + g(x)``.
    ``blocks`` is stored as given; its parameters are registered only if it is
    itself a module container (e.g. ``nn.ModuleList``).
    """

    def __init__(self, blocks):
        super().__init__()
        self.blocks = blocks

    def forward(self, x):
        out = x
        for first, second in self.blocks:
            out = out + first(out)
            out = out + second(out)
        return out
  
  
class PermuteToFrom(nn.Module):
    """Apply ``fn`` along one axis of a multi-dim tensor.

    ``forward`` permutes the input with ``permutation``, which puts the target
    axis second-to-last and the feature axis last. It then folds every other
    dim into a single leading dim and calls ``fn`` on the resulting 3-D
    tensor. Finally it restores the shape and applies the inverse permutation.
    """

    def __init__(self, permutation, fn):
        super().__init__()
        self.fn = fn
        self.permutation = permutation
        self.inv_permutation = sort_and_return_indices(permutation)[1]

    def forward(self, x, **kwargs):
        permuted = x.permute(*self.permutation).contiguous()
        permuted_shape = permuted.shape
        *_, seq_len, feat_dim = permuted_shape

        # Fold every dim except the attended axis and the feature axis.
        flat = permuted.reshape(-1, seq_len, feat_dim)
        flat = self.fn(flat, **kwargs)

        restored = flat.reshape(*permuted_shape)
        return restored.permute(*self.inv_permutation).contiguous()
  
  
# axial pos emb  
  
class AxialPositionalEmbedding(nn.Module):
    """Learned axial positional embedding added to the input.

    One parameter ``param_i`` is created per axial dim. Each has size ``dim``
    on the embedding axis, ``shape[i]`` on the i-th axial axis, and 1
    elsewhere. Adding all of them to the input broadcasts into a full
    positional grid.

    Args:
        dim: embedding size.
        shape: sizes of the axial dims, in order.
        emb_dim_index: index of the embedding dim in the input, which has
            ``len(shape) + 2`` dims; negative values count from the end.
    """

    def __init__(self, dim, shape, emb_dim_index=1):
        super().__init__()
        total_dimensions = len(shape) + 2
        # Normalize a negative index. Otherwise it is never excluded from the
        # axial indexes below, and an axial size can overwrite the embedding
        # size (e.g. emb_dim_index=-2).
        if emb_dim_index < 0:
            emb_dim_index += total_dimensions
        ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]

        self.num_axials = len(shape)

        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
            # Separate name so the ``shape`` argument is not shadowed.
            param_shape = [1] * total_dimensions
            param_shape[emb_dim_index] = dim
            param_shape[axial_dim_index] = axial_dim
            # Registered as param_0, param_1, ...; forward looks them up by name.
            setattr(self, f'param_{i}', nn.Parameter(torch.randn(*param_shape)))

    def forward(self, x):
        for i in range(self.num_axials):
            x = x + getattr(self, f'param_{i}')
        return x
  
  
# attention  
  
class SelfAttention(nn.Module):
    """Unmasked multi-head scaled dot-product attention on a 3-D input.

    ``x`` is laid out as ``(batch, length, dim)``. Keys and values are
    computed from ``kv`` when it is given, otherwise from ``x`` itself.

    Args:
        dim: input and output feature size.
        heads: number of attention heads.
        dim_heads: per-head size; defaults to ``dim // heads``.
    """

    def __init__(self, dim, heads, dim_heads=None):
        super().__init__()
        if dim_heads is None:
            dim_heads = dim // heads
        self.dim_heads = dim_heads
        inner = dim_heads * heads

        self.heads = heads
        self.to_q = nn.Linear(dim, inner, bias=False)
        self.to_kv = nn.Linear(dim, 2 * inner, bias=False)
        self.to_out = nn.Linear(inner, dim)

    def forward(self, x, kv=None):
        source = x if kv is None else kv
        q = self.to_q(x)
        k, v = self.to_kv(source).chunk(2, dim=-1)

        batch, _, inner = q.shape
        heads, head_dim = self.heads, self.dim_heads

        def split_heads(t):
            # (b, n, h*e) -> (b*h, n, e)
            return t.reshape(batch, -1, heads, head_dim).transpose(1, 2).reshape(batch * heads, -1, head_dim)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        scores = torch.einsum('bie,bje->bij', q, k) * (head_dim ** -0.5)
        weights = scores.softmax(dim=-1)
        mixed = torch.einsum('bij,bje->bie', weights, v)

        # (b*h, n, e) -> (b, n, h*e)
        mixed = mixed.reshape(batch, heads, -1, head_dim).transpose(1, 2).reshape(batch, -1, inner)
        return self.to_out(mixed)
  
  
# axial attention class  
  
class AxialAttention(nn.Module):
    """Self-attention applied separately along each axial dim of a tensor.

    Args:
        dim: feature size expected at ``dim_index``.
        num_dimensions: number of axial dims; inputs must have
            ``num_dimensions + 2`` dims.
        heads: number of heads for each ``SelfAttention``.
        dim_heads: per-head size for each ``SelfAttention``.
        dim_index: position of the feature dim; non-positive values are offset
            by the number of dims.
        sum_axial_out: if True, every axial attention receives the original
            input and their outputs are summed. Otherwise they are applied one
            after another.
    """

    def __init__(self, dim, num_dimensions=2, heads=8, dim_heads=None, dim_index=-1, sum_axial_out=True):
        assert (dim % heads) == 0, 'hidden dimension must be divisible by number of heads'
        super().__init__()
        self.dim = dim
        self.total_dimensions = num_dimensions + 2
        if dim_index > 0:
            self.dim_index = dim_index
        else:
            self.dim_index = dim_index + self.total_dimensions

        self.axial_attentions = nn.ModuleList(
            PermuteToFrom(perm, SelfAttention(dim, heads, dim_heads))
            for perm in calculate_permutations(num_dimensions, dim_index)
        )
        self.sum_axial_out = sum_axial_out

    def forward(self, x):
        assert len(x.shape) == self.total_dimensions, 'input tensor does not have the correct number of dimensions'
        assert x.shape[self.dim_index] == self.dim, 'input tensor does not have the correct input dimension'

        if not self.sum_axial_out:
            out = x
            for axial_attn in self.axial_attentions:
                out = axial_attn(out)
            return out

        return sum(axial_attn(x) for axial_attn in self.axial_attentions)
  
  
# axial image transformer  
  
class AxialImageTransformer(nn.Module):
    """Stack of axial self-attention and convolutional feed-forward layers.

    Each of the ``depth`` layers adds two two-element ``nn.ModuleList`` entries
    to the layer stack:

    - a pair of pre-normed axial self-attentions, one per axial dim;
    - a pair of conv feed-forward blocks.

    Each entry is unpacked as an ``(f, g)`` pair by ``ReversibleSequence``
    (when ``reversible``) or ``Sequential``.

    Args:
        dim: channel / embedding size.
        depth: number of layers.
        heads: number of attention heads.
        dim_heads: per-head size.
        dim_index: index of the channel dim. The conv feed-forward and the
            dim-1 split used by the reversible path presumably require the
            default of 1.
        reversible: use reversible execution instead of plain residuals.
        axial_pos_emb_shape: when given, an ``AxialPositionalEmbedding`` of
            this axial shape is added to the input; otherwise none is used.
    """

    def __init__(self, dim, depth, heads=8, dim_heads=None, dim_index=1, reversible=True, axial_pos_emb_shape=None):
        super().__init__()
        permutations = calculate_permutations(2, dim_index)

        def make_ff():
            return nn.Sequential(
                ChanLayerNorm(dim),
                nn.Conv2d(dim, dim * 4, 3, padding=1),
                nn.LeakyReLU(inplace=True),
                nn.Conv2d(dim * 4, dim, 3, padding=1),
            )

        if exists(axial_pos_emb_shape):
            self.pos_emb = AxialPositionalEmbedding(dim, axial_pos_emb_shape, dim_index)
        else:
            self.pos_emb = nn.Identity()

        layers = nn.ModuleList([])
        for _ in range(depth):
            attention_pair = nn.ModuleList([
                PermuteToFrom(perm, PreNorm(dim, SelfAttention(dim, heads, dim_heads)))
                for perm in permutations
            ])
            ff_pair = nn.ModuleList([make_ff(), make_ff()])
            layers.extend([attention_pair, ff_pair])

        runner = ReversibleSequence if reversible else Sequential
        self.layers = runner(layers)

    def forward(self, x):
        return self.layers(self.pos_emb(x))
  
  
# Smoke test. Input: N C H W, output: N C H W.
if __name__ == '__main__':
    # Fall back to CPU so the demo also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    block = AxialImageTransformer(dim=64, depth=12, reversible=True).to(device)
    x = torch.rand(3, 64, 32, 32, device=device)
    output = block(x)
    print(x.size(), output.size())

http://www.kler.cn/a/520059.html

相关文章:

  • Couchbase UI: Dashboard
  • Hook 函数
  • 乒乓球日常烧拍日记之四海绵支撑
  • 左右互博02-unidbg主动调用外层so函数
  • 【2024年华为OD机试】(A卷,100分)- 整理扑克牌(JavaScript、Java、Python、C/C++)
  • Spring Data JPA 实战:构建高性能数据访问层
  • 人工智能在医疗领域的应用与挑战
  • 基于dlib/face recognition人脸识别推拉流实现
  • 评估篇| 大模型评测综述
  • 裁员避坑指南(9)
  • wxwidgets直接获取系统图标,效果类似QFileIconProvider
  • 【测试】UI自动化测试
  • python学习笔记(三)
  • 相同的树及延伸题型(C语言详解版)
  • 机器学习-线性回归(对于f(x;w)=w^Tx+b理解)
  • 几种常见的求特殊方程正整数解的方法和示例
  • 第28章 测试驱动开发模式:深入绿条模式及相关技术
  • C++17 命名空间的新特性:简化与优化的典范
  • 详解三种常用标准化:Batch Norm、Layer Norm和RMSNorm
  • centos7执行yum操作时报错Could not retrieve mirrorlist http://mirrorlist.centos.org解决
  • 使用 Redis List 和 Pub/Sub 实现简单的消息队列
  • 代码随想录训练营第五十八天| 拓扑排序精讲 dijkstra(朴素版)精讲
  • Vue3 provide/inject用法总结
  • 解锁.NET Standard库:从0到1的创建与打包秘籍
  • 使用递归函数求1~n之和
  • 基于SpringBoot的网上考试系统