当前位置: 首页 > article >正文

原生稀疏注意力NSA 替换transformer 注意力进行文本生成训练

DeepSeek-R1这篇文章,聚焦范围更加小,R1的重点在于提出了一个文本生成的训练策略和蒸馏策略,这篇文章则是提出了一个注意力机制NSA,主要解决的是长序列做注意力时带来的效率问题。通篇文章看下来,它的实际意义可能比较局限,因此本文仅关注其主要内容,对于具体细节和实验结果并不进一步细究。

论文标题:Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention
论文链接:[2502.11089] Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention

简单总结起来就是

实验效果惊艳:性能不降反升,速度大幅提升!

实验结果令人振奋!在通用基准测试、长文本任务和指令推理方面,使用 NSA 预训练的模型性能不仅没有下降,反而超越了 Full Attention 模型!

NSA 的核心亮点可以概括为以下两点:

1.动态分层稀疏策略: NSA 采用了一种动态分层的稀疏策略,结合了粗粒度的 Token 压缩 和 细粒度的 Token 选择。这种策略既能保证模型对全局上下文的感知,又能兼顾局部信息的精确性

2.关键创新:

算术强度平衡的算法设计与硬件优化: NSA 通过精巧的算法设计,并针对现代硬件进行了实现优化,显著提升了计算速度

端到端可训练: NSA 支持端到端训练,这意味着它不仅在推理阶段高效,还能减少预训练的计算量,同时不牺牲模型性能!

 

Attention的稀疏特性,其实从BERT时代开始就已经被广泛验证了。最早像Longformer、BigBird这些模型提出的几种稀疏Attention Pattern(比如Sliding Window、Global Attention——现在叫Attention Sink),直到今天依然被广泛使用。Attention天然的稀疏性,意味着每个词元在计算时,只需要从海量的上文中选出top-k相关的部分进行Attention计算。这个思路很简单,但难点就在于如何快速找到top-k的相关上文。如果逐token去选,计算和访存的过程又会回到Full-Attention的复杂度。

 

稀疏Attn为什么还能超过Full-Attn?
长文本具有天然的高稀疏性与富噪音性。处理每个token确实不要把全文都过一遍,而Full-attention机制,总是能确保每两个token之间的相关性不为0。这也就带来了计算上的噪音。所以不难理解,一个well-trained 稀疏Attn能够为每个token屏蔽掉部分噪音,效果也能带来些许提升。但效果的有限提升外,还是效率的提升更让人惊喜。
 

知乎上也有一篇介绍

https://zhuanlan.zhihu.com/p/24604821449

逛github时已经有大神做了论文复现,不仅提供了SparseAttention 还替换了原有Transformer 模型里的attention层

 

稀疏注意力SparseAttention 模型网络定义

class SparseAttention(Module):
    def __init__(
        self,
        dim,
        dim_head,
        heads,
        sliding_window_size,
        compress_block_size,
        selection_block_size,
        num_selected_blocks,
        num_compressed_mem_kv = 4,
        norm = True,
        use_diff_topk = False,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5

        assert compress_block_size == selection_block_size, 'start off with compressed being equal to selection block sizes'

        dim_inner = dim_head * heads

        self.norm = nn.RMSNorm(dim) if norm else nn.Identity()

        # rotary

        self.rotary_emb = RotaryEmbedding(dim_head)

        # qkv

        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)

        # sliding window strategy

        self.sliding_window = LocalAttention(
            dim = dim_head,
            window_size = sliding_window_size,
            causal = True,
            exact_windowsize = True,
            autopad = True
        )

        # compress strategy

        self.compress_block_size = compress_block_size

        assert num_compressed_mem_kv > 0

        self.compress_mem_kv = nn.Parameter(torch.zeros(2, heads, num_compressed_mem_kv, dim_head))
        self.k_intrablock_positions = nn.Parameter(torch.zeros(heads, compress_block_size, dim_head))
        self.v_intrablock_positions = nn.Parameter(torch.zeros(heads, compress_block_size, dim_head))

        self.k_compress = nn.Sequential(
            Rearrange('b h n d -> b (h d) n'),
            nn.Conv1d(dim_head * heads, dim_head * heads, compress_block_size, stride = compress_block_size, groups = heads),
            Rearrange('b (h d) nc -> b h nc d', h = heads)
        )

        self.v_compress = nn.Sequential(
            Rearrange('b h n d -> b (h d) n'),
            nn.Conv1d(dim_head * heads, dim_head * heads, compress_block_size, stride = compress_block_size, groups = heads),
            Rearrange('b (h d) nc -> b h nc d', h = heads)
        )

        # selection related
        self.use_diff_topk = use_diff_topk
        self.selection_block_size = selection_block_size
        self.num_selected_blocks = num_selected_blocks

        # they combine the three sparse branches through a learned combine with sigmoid activation

        self.to_strategy_combine = nn.Sequential(
            nn.Linear(dim, 3 * heads),
            nn.Sigmoid(),
            Rearrange('b n (h s) -> b h n s', h = heads)
        )

        # split and merging heads

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        # combining heads

        self.combine_heads = nn.Linear(dim_inner, dim, bias = False)
  • dim: 输入特征的维度。
  • dim_head: 每个注意力头的维度。
  • heads: 注意力头的数量。
  • sliding_window_size: 滑动窗口的大小,用于局部注意力。
  • compress_block_size: 压缩块的大小。
  • selection_block_size: 选择块的大小。
  • num_selected_blocks: 选择的块数量。
  • num_compressed_mem_kv: 压缩的记忆键值对数量(默认为 4)。
  • norm: 是否使用归一化(默认为 True)。
  • use_diff_topk: 是否使用不同的 Top-k 策略(默认为 False)。

 

SparseAttention流程总结

  • 头部和缩放

    self.heads = heads
    self.scale = dim_head ** -0.5
    
    • heads 保存注意力头的数量。
    • scale 用于缩放注意力的分数,防止数值过大。
  • 归一化

    self.norm = nn.RMSNorm(dim) if norm else nn.Identity()
    
    • 使用 RMSNorm 进行归一化,或者使用身份函数(如果不需要归一化)。
  • 旋转嵌入

    self.rotary_emb = RotaryEmbedding(dim_head)
    
    • 用于实现旋转位置编码,增强模型对序列位置信息的理解。
  • QKV 线性变换

    self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
    
    • 将输入特征映射到查询(Q)、键(K)和值(V)的线性空间。
  • 滑动窗口注意力

    self.sliding_window = LocalAttention(...)
    
    • 实现局部注意力机制,限制注意力计算在滑动窗口内,以减少计算复杂度。
  • 压缩策略

    self.compress_mem_kv = nn.Parameter(torch.zeros(2, heads, num_compressed_mem_kv, dim_head))
    
    • 初始化压缩后的键值存储。
  • 内块位置参数

    self.k_intrablock_positions = nn.Parameter(torch.zeros(heads, compress_block_size, dim_head))
    self.v_intrablock_positions = nn.Parameter(torch.zeros(heads, compress_block_size, dim_head))
    
    • 用于保存每个头部在压缩块中的位置。
  • 压缩操作

    self.k_compress = nn.Sequential(...)
    self.v_compress = nn.Sequential(...)
    
    • 使用卷积层对键和值进行压缩,减少计算量。
  • 选择策略

    self.to_strategy_combine = nn.Sequential(...)
    
    • 通过线性层和 Sigmoid 激活函数对不同的注意力策略进行组合。
  • 头部的分割与合并

    self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
    self.merge_heads = Rearrange('b h n d -> b n (h d)')
    
    • split_heads 将输入张量拆分为多个头部。
    • merge_heads 将多个头部合并回一个张量。
  • 组合头部输出

    self.combine_heads = nn.Linear(dim_inner, dim, bias = False)
    
    • 将多个头部的输出通过线性层组合成最终的输出。

 

结合论文分析SparseAttention

结合上述代码,以下是具体神经网络层如何体现 NSA 的核心亮点:

1. 动态分层稀疏策略

  • 压缩层 (self.k_compressself.v_compress)
    self.k_compress = nn.Sequential(
        Rearrange('b h n d -> b (h d) n'),
        nn.Conv1d(dim_head * heads, dim_head * heads, compress_block_size, stride=compress_block_size, groups=heads),
        Rearrange('b (h d) nc -> b h nc d', h=heads)
    )
    
    self.v_compress = nn.Sequential(
        Rearrange('b h n d -> b (h d) n'),
        nn.Conv1d(dim_head * heads, dim_head * heads, compress_block_size, stride=compress_block_size, groups=heads),
        Rearrange('b (h d) nc -> b h nc d', h=heads)
    )
    
    • 说明:这些层实现了粗粒度的 Token 压缩。使用卷积层对键和值进行压缩,从而减少计算量,同时保持重要信息。这种设计使模型能够动态调整处理的 Token 数量,兼顾全局上下文和局部信息的捕捉。

2. 关键创新

  • 算术强度平衡的算法设计与硬件优化

    • 局部注意力层 (self.sliding_window)
      • 说明:局部注意力机制通过限制注意力计算在滑动窗口内,显著降低了计算复杂度。这种设计不仅提高了计算速度,还优化了内存使用,特别是在处理长序列时。
    self.sliding_window = LocalAttention(
        dim=dim_head,
        window_size=sliding_window_size,
        causal=True,
        exact_windowsize=True,
        autopad=True
    )
    
  • 组合头部输出 (self.combine_heads)

    self.combine_heads = nn.Linear(dim_inner, dim, bias=False)
    
    • 说明:通过线性层组合多个头部的输出,保持了模型的灵活性和表达能力,同时减少了冗余计算,进一步提升了算术强度的平衡。

3. 端到端可训练

  • 归一化层 (self.norm)

    self.norm = nn.RMSNorm(dim) if norm else nn.Identity()
    
    • 说明:归一化层的使用确保了模型在训练过程中的稳定性,支持端到端的训练方式,使得模型能够在推理阶段高效,同时优化了预训练和微调过程,从而减少计算量。
  • 策略组合层 (self.to_strategy_combine)

    self.to_strategy_combine = nn.Sequential(
        nn.Linear(dim, 3 * heads),
        nn.Sigmoid(),
        Rearrange('b n (h s) -> b h n s', h=heads)
    )
    
    • 说明:这一层通过组合不同的稀疏策略,确保模型在训练过程中能够灵活适应不同的任务和数据,支持端到端训练,提升了模型的实用性和效率。

 

Transformer模型结构, attn 使用了SparseAttention

class Transformer(Module):
    def __init__(
        self,
        num_tokens,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        ff_expansion_factor = 4.,
        use_sparse_attn = True,
        sparse_attn_kwargs: dict = dict(
            sliding_window_size = 32,
            compress_block_size = 4,
            selection_block_size = 4,
            num_selected_blocks = 4,
        )
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)

        layers = []
        for _ in range(depth):

            if use_sparse_attn:
                attn = SparseAttention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    **sparse_attn_kwargs
                )
            else:
                attn = Attention(dim = dim, dim_head = dim_head, heads = heads)

            ff = FeedForward(dim = dim, expansion_factor = ff_expansion_factor)

            layers.append(ModuleList([attn, ff]))

        self.layers = ModuleList(layers)

        self.norm = RMSNorm(dim)
        self.to_logits = Linear(dim, num_tokens, bias = False)
 
    def forward(
        self,
        ids,
        return_loss = False
    ):
        if return_loss:
            ids, labels = ids[:, :-1], ids[:, 1:]

        tokens = self.token_emb(ids)

        for attn, ff in self.layers:
            tokens = attn(tokens) + tokens
            tokens = ff(tokens) + tokens

        embed = self.norm(tokens)

        logits = self.to_logits(embed)

        if not return_loss:
            return logits

        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)

 

使用wiki百科中文语料来训练下Transformer + NSA 注意力模型

Index of /zhwiki/latest/ 下载 zhwiki-latest-abstract.xml.gz

安装opencc 

pip install opencc-python-reimplemented

把繁体转发简体中文

import gzip
import opencc
import os
from tqdm import tqdm

# 检查 OpenCC 配置文件的路径
opencc_path = os.path.join(
    os.path.dirname(opencc.__file__), 'config', 't2s.json'
)

# 初始化 OpenCC 转换器
converter = opencc.OpenCC(opencc_path)

# 计算文件行数以便显示进度条
with gzip.open('zhwiki-latest-abstract.xml.gz', 'rt', encoding='utf-8') as infile:
    total_lines = sum(1 for _ in infile)  # 计算总行数
    infile.seek(0)  # 重置文件指针

# 压缩为新的 gz 文件
with gzip.open('zhwiki-latest-abstract-simplified.xml.gz', 'wt', encoding='utf-8') as outfile:
    with gzip.open('zhwiki-latest-abstract.xml.gz', 'rt', encoding='utf-8') as infile:
        for line in tqdm(infile, total=total_lines, desc="Processing"):
            simplified_line = converter.convert(line)
            outfile.write(simplified_line)

print("转换完成,已保存为 zhwiki-latest-abstract-simplified.xml.gz")

 

加载数据集,中文处理需要使用tokenizer

tokenizer = BertTokenizer.from_pretrained('./base_model/bert-base-chinese')
print(f"Vocabulary size: {len(tokenizer.vocab)}")
model = Transformer(
    num_tokens=len(tokenizer.vocab),
    dim=512,
    depth=6,
    use_sparse_attn=USE_SPARSE_ATTN,
    sparse_attn_kwargs=dict(
        sliding_window_size=16,  # 调整为更小的块大小
        compress_block_size=16,
        selection_block_size=16,
        num_selected_blocks=4,
        use_diff_topk=False
    )
).cuda()

# Data processing
with gzip.open('./data/zhwiki-latest-abstract-simplified.xml.gz', 'rb') as file:
    data = np.frombuffer(file.read(int(10e6)), dtype=np.uint8).copy()
    decoded_string = data.tobytes().decode('utf-8')
    tokens = []
    chunk_size = 10000
    vocab_size = len(tokenizer.vocab)
    for i in range(0, len(decoded_string), chunk_size):
        chunk = decoded_string[i:i + chunk_size]
        tokens.extend(tokenizer.tokenize(chunk))
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    token_ids = [tid for tid in token_ids if 0 <= tid < vocab_size]
    token_tensor = torch.tensor(token_ids)
    split_idx = int(len(token_tensor) * 0.8)
    data_train = token_tensor[:split_idx]
    data_val = token_tensor[split_idx:]

print("Train shape:", data_train.shape)
print("Validation shape:", data_val.shape)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        return (self.data.size(0) - self.seq_len) // self.seq_len

    def __getitem__(self, index):
        rand_start = index * self.seq_len
        if rand_start + self.seq_len + 1 > self.data.size(0):
            raise IndexError("Index out of range for dataset.")
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1]
        return full_seq.long().cuda()

解码使用

def decode_tokens(tokens):
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.cpu().tolist()
    token_list = tokenizer.convert_ids_to_tokens(tokens)
    filtered_tokens = [token for token in token_list if token not in ['[CLS]', '[SEP]', '[PAD]']]
    return ''.join(filtered_tokens)

 

训练epoch 代码

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()
    for _ in range(GRAD_ACCUM_EVERY):
        data = next(train_loader)
        input_data = data[:, :-1]
        target_data = data[:, 1:]
        logits = model(input_data)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target_data.reshape(-1))
        (loss / GRAD_ACCUM_EVERY).backward()

    wandb.log(dict(loss=loss.item()), step=i)
    print(f"training loss: {loss.item():.3f}")
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            valid_data = next(val_loader)
            input_data = valid_data[:, :-1]
            target_data = valid_data[:, 1:]
            logits = model(input_data)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target_data.reshape(-1))
            wandb.log(dict(valid_loss=loss.item()), step=i)
            print(f"validation loss: {loss.item():.3f}")
            if loss.item() < min_loss:
                min_loss = loss.item()
                torch.save(model.state_dict(), f'model_min_loss.pt')
                print(f'Model saved at validation loss: {min_loss:.3f}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        inp = inp.cuda()
        print(f"Input token IDs: {inp}")
        prime = decode_tokens(inp)
        print(f"\nprime: {prime}\n")
        prompt = inp[None, ...]
        sampled = base_decoding(model, prompt, GENERATE_LENGTH)
        base_decode_output = decode_tokens(sampled[0])
        decoded_str = urllib.parse.unquote(base_decode_output)
        print(f'output: {decoded_str}')

在loss最小的时候保存模型,  训练过程同步到wandb

 

训练中生成测试

training loss: 0.026
training loss: 0.018
validation loss: 0.787
Input token IDs: tensor([  110,   130,  8168,   110, 12888,   110,  8416,   110,   144,  8129,
          110,   147,  8159,   110, 10322,   110, 10322,   108,  1146,  2357,
          133,   120,  9025,   135,   133,   120, 11541,  8204,  9989,   135,
          133, 11541,  8204,  9989,  9025, 11085,   134,   107, 11469,  8225,
          107,   135,   133,  9064,  8370,  8372,   135,  4495,  3833,   133,
          120,  9064,  8370,  8372,   135,   133,  9025,   135,  8532,   131,
          120,   120,  9998,   119], device='cuda:0')

prime: %9##d%e7%94%b##0%e##4%ba%ba#分布</link></su##b##link><su##b##linklink##type="na##v"><an##ch##or>生活</an##ch##or><link>https://zh.

output: wikipedia.org/wiki/�%8##c��%b##0%e##4��#生活</link></su##b##link><su##b##linklink##type="na##v"><an##ch##or>理论</an##ch##or><link>https://zh.wikipedia.org/wiki/�%8##c�_(�%b##0%e##4%b##f��%8##c�#理论</link></su##b##link><su##b##linklink##type="na##v"><an##ch##or>参考文献</an##ch##or><link>https://zh.wikipedia.org/wiki/�%8##c��%b##0%e##4%b##f�#参考

结束训练后调用模型进行推理输出

import torch
from pytorch_pretrained_bert import BertTokenizer
from native_sparse_attention_pytorch.transformer import Transformer

# 常量(与训练时一致)
PRIME_LENGTH = 64
GENERATE_LENGTH = 256
SEQ_LEN = 256
USE_SPARSE_ATTN = True

# 采样辅助函数
def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature=1., dim=-1, keepdim=True):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim, keepdim=keepdim)

def top_k(logits, thres=0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs

def base_decoding(net, prompt: torch.Tensor, seq_len: int, temperature=1., filter_thres=0.9):
    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
    sample_num_times = max(0, seq_len - prompt_seq_len)
    for _ in range(sample_num_times):
        logits = net(out)
        logits = logits[:, -1]
        logits = top_k(logits, thres=filter_thres)
        sample = gumbel_sample(logits, temperature=temperature, dim=-1)
        out = torch.cat((out, sample), dim=-1)
    return out[..., prompt_seq_len:]

# 解码函数
def decode_tokens(tokens, tokenizer):
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.cpu().tolist()
    token_list = tokenizer.convert_ids_to_tokens(tokens)
    filtered_tokens = [token for token in token_list if token not in ['[CLS]', '[SEP]', '[PAD]']]
    return ''.join(filtered_tokens)

# 加载 tokenizer 和模型
tokenizer = BertTokenizer.from_pretrained('./base_model/bert-base-chinese')
vocab_size = len(tokenizer.vocab)
print(f"Vocabulary size: {vocab_size}")

model = Transformer(
    num_tokens=vocab_size,
    dim=512,
    depth=6,
    use_sparse_attn=USE_SPARSE_ATTN,
    sparse_attn_kwargs=dict(
        sliding_window_size=16,
        compress_block_size=16,
        selection_block_size=16,
        num_selected_blocks=4,
        use_diff_topk=False
    )
)

# 加载训练好的模型权重
model_path = 'model_min_loss.pt'
state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict)
model = model.cuda()
model.eval()

# 输入示例
input_text = "阿氏吻鳐"
input_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(input_text))
min_length = max(PRIME_LENGTH, 16)  # 确保长度 >= compress_block_size
if len(input_tokens) < min_length:
    pad_id = tokenizer.convert_tokens_to_ids(['[PAD]'])[0]
    input_tokens = input_tokens + [pad_id] * (min_length - len(input_tokens))
input_tensor = torch.tensor(input_tokens[:PRIME_LENGTH], dtype=torch.long).cuda()
prompt = input_tensor[None, :]

# 进行推理
print(f"Input text: {input_text}")
print(f"Input token IDs: {input_tensor}")

with torch.no_grad():
    generated_tokens = base_decoding(model, prompt, GENERATE_LENGTH, temperature=0.7, filter_thres=0.9)
    generated_text = decode_tokens(generated_tokens[0], tokenizer)

# 输出结果
print(f"\nGenerated text: {generated_text}")

Generated text: 为软骨鱼纲鳐目鳐科吻鳐属的一种[1],分布于中西大西洋美国佛罗里达州到墨西哥犹加敦半岛海域,深度32至384米,本鱼体盘宽圆形,上表面颜色苍白并有暗斑,每个胸鳍上的眼斑通常为椭圆形,下表面白色,无深色斑纹

 


http://www.kler.cn/a/560620.html

相关文章:

  • 【开源免费】基于SpringBoot+Vue.JS物流管理系统(JAVA毕业设计)
  • 普通人使用生成式语言模型的几个阶段
  • javaweb-vue3基础
  • R Excel 文件:高效数据处理的利器
  • 在CentOS 7下部署NFS的详细教程
  • 一些时间方法
  • 如何保证bug在改完之后不会引起新bug
  • 如何通过阿里云CDN优化网站访问与下载速度?
  • 数据库-事务的ACID
  • Linux 系统内存不足导致服务崩溃的排查方法
  • TCP重传机制
  • 使用 Three.js 转换 GLSL 粒子效果着色器
  • 【C++设计模式】观察者模式(1/2):从基础到优化实现
  • Mesh自组网技术及应用
  • 网络运维学习笔记(DeepSeek优化版)002网工初级(HCIA-Datacom与CCNA-EI)子网划分与协议解析
  • 七.智慧城市数据治理平台架构
  • 【LeetCode 热题100】48. 旋转图像以及旋转任意角度的算法思路及python代码
  • LabVIEW Browser.vi 库说明
  • H5--开发适配
  • Web Developer 1靶场渗透测试