Native Sparse Attention (NSA): replacing Transformer attention for text-generation training
Compared with the DeepSeek-R1 paper, this one has a much narrower focus. R1's contribution is a training and distillation strategy for text generation; this paper instead proposes an attention mechanism, NSA, aimed at the efficiency problem of attention over long sequences. Reading through it, its practical significance may be fairly limited, so this post only covers the main ideas and does not dig into the details or the experimental results.
Paper title: Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention
Paper link: [2502.11089] Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention
To sum it up briefly:
Impressive experiments: performance goes up rather than down, and speed improves dramatically.
On general benchmarks, long-context tasks and instruction reasoning, models pretrained with NSA not only avoid degradation but actually beat the Full Attention baseline.
NSA's core highlights boil down to two points:
1. Dynamic hierarchical sparse strategy: NSA uses a dynamic, hierarchical sparsity scheme that combines coarse-grained token compression with fine-grained token selection, preserving awareness of the global context while keeping local information precise. (A rough sketch of how the resulting branches are gated together is given right after this list.)
2. Key innovations:
   - Arithmetic-intensity-balanced algorithm design with hardware optimization: careful algorithm design plus an implementation tuned to modern hardware yields a substantial speedup.
   - End-to-end trainable: NSA can be trained end to end, so it is not only efficient at inference time but also reduces pretraining compute, without sacrificing model quality.
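To make point 1 concrete, here is a minimal sketch (my own illustration, not code from the paper) of how the outputs of NSA's three branches, compressed attention, selected-block attention, and sliding-window attention, are mixed through learned sigmoid gates:

import torch

def combine_branches(out_cmp, out_sel, out_win, gates):
    # out_*: (batch, heads, seq, dim_head); gates: (batch, heads, seq, 3), values in [0, 1]
    g_cmp, g_sel, g_win = gates.unbind(dim=-1)
    return (
        g_cmp.unsqueeze(-1) * out_cmp
        + g_sel.unsqueeze(-1) * out_sel
        + g_win.unsqueeze(-1) * out_win
    )

b, h, n, d = 1, 8, 128, 64
outs = [torch.randn(b, h, n, d) for _ in range(3)]
gates = torch.sigmoid(torch.randn(b, h, n, 3))   # in the real module these come from a learned projection
mixed = combine_branches(*outs, gates)           # (1, 8, 128, 64)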
The sparsity of attention has been widely validated since the BERT era. The sparse attention patterns introduced by models like Longformer and BigBird (sliding window, global attention, the latter now usually called the attention sink) are still in wide use today. Attention's natural sparsity means that each token only needs to attend to the top-k most relevant parts of a huge context. The idea is simple; the hard part is finding those top-k relevant spans quickly. If you select them token by token, the compute and memory access fall right back to full-attention complexity.
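NSA's answer is to do the selection at block granularity: whole blocks are scored via their compressed summaries, and only the top-k blocks are attended to at full token resolution. A minimal sketch of that idea (hypothetical shapes and names, not the paper's kernel):

import torch

def select_blocks(q, k_compressed, num_selected_blocks):
    # q: (batch, heads, seq, dim); k_compressed: (batch, heads, num_blocks, dim)
    scores = torch.einsum('bhid,bhjd->bhij', q, k_compressed)   # query-to-block relevance
    topk = scores.topk(num_selected_blocks, dim=-1).indices     # (batch, heads, seq, k)
    return topk  # indices of the blocks each query will attend to at full resolution

q = torch.randn(1, 8, 128, 64)
k_cmp = torch.randn(1, 8, 32, 64)                # 128 tokens compressed into 32 block summaries
print(select_blocks(q, k_cmp, num_selected_blocks=4).shape)  # torch.Size([1, 8, 128, 4])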
Why can sparse attention even beat full attention?
Long texts are naturally both highly sparse and noisy. Processing a given token really does not require scanning the entire text, yet the full-attention mechanism guarantees that the relevance between every pair of tokens is nonzero, which introduces computational noise. So it is not surprising that a well-trained sparse attention, by masking out part of that noise for every token, can yield a small quality gain. Still, beyond the modest quality gain, the efficiency gain is the more exciting part.
There is also an introduction on Zhihu:
https://zhuanlan.zhihu.com/p/24604821449
Browsing GitHub, someone has already reproduced the paper, providing not only a SparseAttention module but also a drop-in replacement for the attention layer in a standard Transformer (judging from the imports used later in this post, this is likely lucidrains' native-sparse-attention-pytorch).
SparseAttention module definition
class SparseAttention(Module):
    def __init__(
        self,
        dim,
        dim_head,
        heads,
        sliding_window_size,
        compress_block_size,
        selection_block_size,
        num_selected_blocks,
        num_compressed_mem_kv = 4,
        norm = True,
        use_diff_topk = False,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5

        assert compress_block_size == selection_block_size, 'start off with compressed being equal to selection block sizes'

        dim_inner = dim_head * heads

        self.norm = nn.RMSNorm(dim) if norm else nn.Identity()

        # rotary
        self.rotary_emb = RotaryEmbedding(dim_head)

        # qkv
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)

        # sliding window strategy
        self.sliding_window = LocalAttention(
            dim = dim_head,
            window_size = sliding_window_size,
            causal = True,
            exact_windowsize = True,
            autopad = True
        )

        # compress strategy
        self.compress_block_size = compress_block_size

        assert num_compressed_mem_kv > 0
        self.compress_mem_kv = nn.Parameter(torch.zeros(2, heads, num_compressed_mem_kv, dim_head))
        self.k_intrablock_positions = nn.Parameter(torch.zeros(heads, compress_block_size, dim_head))
        self.v_intrablock_positions = nn.Parameter(torch.zeros(heads, compress_block_size, dim_head))

        self.k_compress = nn.Sequential(
            Rearrange('b h n d -> b (h d) n'),
            nn.Conv1d(dim_head * heads, dim_head * heads, compress_block_size, stride = compress_block_size, groups = heads),
            Rearrange('b (h d) nc -> b h nc d', h = heads)
        )

        self.v_compress = nn.Sequential(
            Rearrange('b h n d -> b (h d) n'),
            nn.Conv1d(dim_head * heads, dim_head * heads, compress_block_size, stride = compress_block_size, groups = heads),
            Rearrange('b (h d) nc -> b h nc d', h = heads)
        )

        # selection related
        self.use_diff_topk = use_diff_topk
        self.selection_block_size = selection_block_size
        self.num_selected_blocks = num_selected_blocks

        # they combine the three sparse branches through a learned combine with sigmoid activation
        self.to_strategy_combine = nn.Sequential(
            nn.Linear(dim, 3 * heads),
            nn.Sigmoid(),
            Rearrange('b n (h s) -> b h n s', h = heads)
        )

        # split and merging heads
        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        # combining heads
        self.combine_heads = nn.Linear(dim_inner, dim, bias = False)
Constructor parameters:

- `dim`: dimension of the input features.
- `dim_head`: dimension of each attention head.
- `heads`: number of attention heads.
- `sliding_window_size`: size of the sliding window used for local attention.
- `compress_block_size`: size of each compression block.
- `selection_block_size`: size of each selection block.
- `num_selected_blocks`: number of blocks to select.
- `num_compressed_mem_kv`: number of compressed memory key/value pairs (default 4).
- `norm`: whether to apply normalization (default `True`).
- `use_diff_topk`: whether to use differentiable top-k selection (default `False`).
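Given the constructor above, a minimal usage sketch (values chosen only for illustration, and assuming the module's forward takes a (batch, seq_len, dim) tensor and returns the same shape, as the residual connection in the Transformer below implies; the import path is also an assumption):

import torch
from native_sparse_attention_pytorch import SparseAttention  # assumed import path

attn = SparseAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    sliding_window_size = 16,
    compress_block_size = 16,
    selection_block_size = 16,
    num_selected_blocks = 4,
)

x = torch.randn(1, 256, 512)   # (batch, seq_len, dim)
out = attn(x)                  # expected: same shape as x, i.e. (1, 256, 512)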
SparseAttention flow summary

- Heads and scaling: `self.heads = heads`, `self.scale = dim_head ** -0.5`
  - `heads` stores the number of attention heads; `scale` rescales the attention scores so they do not grow too large.
- Normalization: `self.norm = nn.RMSNorm(dim) if norm else nn.Identity()`
  - Uses RMSNorm, or an identity mapping when normalization is disabled.
- Rotary embedding: `self.rotary_emb = RotaryEmbedding(dim_head)`
  - Implements rotary position encoding, strengthening the model's sense of token positions.
- QKV projection: `self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)`
  - Maps the input features into the query (Q), key (K) and value (V) spaces.
- Sliding-window attention: `self.sliding_window = LocalAttention(...)`
  - Restricts attention to a local sliding window to reduce the computational complexity.
- Compression strategy: `self.compress_mem_kv = nn.Parameter(torch.zeros(2, heads, num_compressed_mem_kv, dim_head))`
  - Initializes the compressed memory key/value store.
- Intra-block position parameters: `self.k_intrablock_positions` / `self.v_intrablock_positions`
  - Learnable per-head position offsets for the tokens inside each compression block.
- Compression operators: `self.k_compress` / `self.v_compress`
  - Compress keys and values with a strided, grouped convolution to cut the amount of computation.
- Strategy combination: `self.to_strategy_combine = nn.Sequential(...)`
  - A linear layer followed by a sigmoid produces per-head gates that combine the different attention strategies.
- Splitting and merging heads: `self.split_heads` / `self.merge_heads`
  - `split_heads` splits the input tensor into multiple heads; `merge_heads` merges them back into a single tensor.
- Combining head outputs: `self.combine_heads = nn.Linear(dim_inner, dim, bias = False)`
  - Projects the concatenated head outputs back to the model dimension to form the final output.
Analyzing SparseAttention against the paper

Tying the code above back to the paper, here is how the individual layers realize NSA's core highlights:

1. Dynamic hierarchical sparse strategy

- Compression layers (`self.k_compress` and `self.v_compress`):

  self.k_compress = nn.Sequential(
      Rearrange('b h n d -> b (h d) n'),
      nn.Conv1d(dim_head * heads, dim_head * heads, compress_block_size, stride = compress_block_size, groups = heads),
      Rearrange('b (h d) nc -> b h nc d', h = heads)
  )
  self.v_compress = nn.Sequential(
      Rearrange('b h n d -> b (h d) n'),
      nn.Conv1d(dim_head * heads, dim_head * heads, compress_block_size, stride = compress_block_size, groups = heads),
      Rearrange('b (h d) nc -> b h nc d', h = heads)
  )

  Explanation: these layers implement the coarse-grained token compression. A strided grouped convolution compresses the keys and values block by block, reducing computation while retaining the important information; this lets the model dynamically adjust how many tokens it processes and balance global context against local detail.
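A quick sanity check on the shapes, using a hypothetical toy configuration (not the values from the paper or the training run below): with compress_block_size = 4, 16 tokens per head collapse into 4 compressed key blocks.

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

heads, dim_head, block = 8, 64, 4
k_compress = nn.Sequential(
    Rearrange('b h n d -> b (h d) n'),
    nn.Conv1d(dim_head * heads, dim_head * heads, block, stride = block, groups = heads),
    Rearrange('b (h d) nc -> b h nc d', h = heads),
)

k = torch.randn(1, heads, 16, dim_head)      # 16 tokens per head
print(k_compress(k).shape)                   # torch.Size([1, 8, 4, 64]) -> 16 / 4 = 4 blocks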
2. Key innovations

- Arithmetic-intensity-balanced algorithm design with hardware optimization:

  - Local attention layer (`self.sliding_window`):

    self.sliding_window = LocalAttention(
        dim = dim_head,
        window_size = sliding_window_size,
        causal = True,
        exact_windowsize = True,
        autopad = True
    )

    Explanation: restricting attention to a sliding window sharply lowers the computational complexity. Besides speeding up computation, it also reduces memory traffic, which matters most on long sequences.

  - Combining head outputs (`self.combine_heads`):

    self.combine_heads = nn.Linear(dim_inner, dim, bias = False)

    Explanation: a single linear layer combines the outputs of all heads, keeping the model flexible and expressive while avoiding redundant computation and helping balance the arithmetic intensity.
3. End-to-end trainable

- Normalization layer (`self.norm`):

  self.norm = nn.RMSNorm(dim) if norm else nn.Identity()

  Explanation: normalization keeps training stable, which is a prerequisite for training the sparse attention end to end rather than bolting it on only at inference time; it also helps keep pretraining and fine-tuning efficient.

- Strategy combination layer (`self.to_strategy_combine`):

  self.to_strategy_combine = nn.Sequential(
      nn.Linear(dim, 3 * heads),
      nn.Sigmoid(),
      Rearrange('b n (h s) -> b h n s', h = heads)
  )

  Explanation: the sigmoid gates that mix the sparse branches are themselves learned, so the model can adapt the mixture to different tasks and data during training. This is what makes the whole mechanism end-to-end trainable, practical and efficient.
Transformer model definition, with the attention layer replaced by SparseAttention
class Transformer(Module):
    def __init__(
        self,
        num_tokens,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        ff_expansion_factor = 4.,
        use_sparse_attn = True,
        sparse_attn_kwargs: dict = dict(
            sliding_window_size = 32,
            compress_block_size = 4,
            selection_block_size = 4,
            num_selected_blocks = 4,
        )
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)

        layers = []
        for _ in range(depth):
            if use_sparse_attn:
                attn = SparseAttention(
                    dim = dim,
                    dim_head = dim_head,
                    heads = heads,
                    **sparse_attn_kwargs
                )
            else:
                attn = Attention(dim = dim, dim_head = dim_head, heads = heads)

            ff = FeedForward(dim = dim, expansion_factor = ff_expansion_factor)
            layers.append(ModuleList([attn, ff]))

        self.layers = ModuleList(layers)
        self.norm = RMSNorm(dim)
        self.to_logits = Linear(dim, num_tokens, bias = False)

    def forward(
        self,
        ids,
        return_loss = False
    ):
        if return_loss:
            ids, labels = ids[:, :-1], ids[:, 1:]

        tokens = self.token_emb(ids)

        for attn, ff in self.layers:
            tokens = attn(tokens) + tokens
            tokens = ff(tokens) + tokens

        embed = self.norm(tokens)
        logits = self.to_logits(embed)

        if not return_loss:
            return logits

        return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
Training a Transformer + NSA attention model on Chinese Wikipedia text
Download zhwiki-latest-abstract.xml.gz from the Wikimedia dump index (Index of /zhwiki/latest/).
Install OpenCC:
pip install opencc-python-reimplemented
Convert Traditional Chinese to Simplified Chinese:
import gzip
import opencc
import os
from tqdm import tqdm

# Locate the OpenCC config file (Traditional -> Simplified)
opencc_path = os.path.join(
    os.path.dirname(opencc.__file__), 'config', 't2s.json'
)

# Initialize the OpenCC converter
converter = opencc.OpenCC(opencc_path)

# Count the lines first so the progress bar knows the total
with gzip.open('zhwiki-latest-abstract.xml.gz', 'rt', encoding='utf-8') as infile:
    total_lines = sum(1 for _ in infile)

# Write the converted text into a new gz file
with gzip.open('zhwiki-latest-abstract-simplified.xml.gz', 'wt', encoding='utf-8') as outfile:
    with gzip.open('zhwiki-latest-abstract.xml.gz', 'rt', encoding='utf-8') as infile:
        for line in tqdm(infile, total=total_lines, desc="Processing"):
            simplified_line = converter.convert(line)
            outfile.write(simplified_line)

print("Conversion finished; saved as zhwiki-latest-abstract-simplified.xml.gz")
Load the dataset; for Chinese text we need a tokenizer:
tokenizer = BertTokenizer.from_pretrained('./base_model/bert-base-chinese')
print(f"Vocabulary size: {len(tokenizer.vocab)}")

model = Transformer(
    num_tokens=len(tokenizer.vocab),
    dim=512,
    depth=6,
    use_sparse_attn=USE_SPARSE_ATTN,
    sparse_attn_kwargs=dict(
        sliding_window_size=16,   # smaller block sizes than the repo defaults
        compress_block_size=16,
        selection_block_size=16,
        num_selected_blocks=4,
        use_diff_topk=False
    )
).cuda()

# Data processing
with gzip.open('./data/zhwiki-latest-abstract-simplified.xml.gz', 'rb') as file:
    data = np.frombuffer(file.read(int(10e6)), dtype=np.uint8).copy()

# ignore a possibly truncated multi-byte character at the 10 MB cut point
decoded_string = data.tobytes().decode('utf-8', errors='ignore')

tokens = []
chunk_size = 10000
vocab_size = len(tokenizer.vocab)
for i in range(0, len(decoded_string), chunk_size):
    chunk = decoded_string[i:i + chunk_size]
    tokens.extend(tokenizer.tokenize(chunk))

token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_ids = [tid for tid in token_ids if 0 <= tid < vocab_size]

token_tensor = torch.tensor(token_ids)
split_idx = int(len(token_tensor) * 0.8)
data_train = token_tensor[:split_idx]
data_val = token_tensor[split_idx:]
print("Train shape:", data_train.shape)
print("Validation shape:", data_val.shape)
class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        return (self.data.size(0) - self.seq_len) // self.seq_len

    def __getitem__(self, index):
        rand_start = index * self.seq_len
        if rand_start + self.seq_len + 1 > self.data.size(0):
            raise IndexError("Index out of range for dataset.")
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1]
        return full_seq.long().cuda()
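The training loop further below pulls batches with `next(train_loader)` / `next(val_loader)`, which are not shown in the snippets above. A minimal sketch of how they could be built (the names `SEQ_LEN`, `BATCH_SIZE` and the `cycle` helper are assumptions, following the usual pattern of such training scripts):

from torch.utils.data import DataLoader

SEQ_LEN = 256
BATCH_SIZE = 4

def cycle(loader):
    # iterate over a DataLoader endlessly so next() never raises StopIteration
    while True:
        for batch in loader:
            yield batch

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True))
val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))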
Decoding helper:
def decode_tokens(tokens):
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.cpu().tolist()
    token_list = tokenizer.convert_ids_to_tokens(tokens)
    filtered_tokens = [token for token in token_list if token not in ['[CLS]', '[SEP]', '[PAD]']]
    return ''.join(filtered_tokens)
Training loop:
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
    model.train()

    for _ in range(GRAD_ACCUM_EVERY):
        data = next(train_loader)
        input_data = data[:, :-1]
        target_data = data[:, 1:]
        logits = model(input_data)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target_data.reshape(-1))
        (loss / GRAD_ACCUM_EVERY).backward()

    wandb.log(dict(loss=loss.item()), step=i)
    print(f"training loss: {loss.item():.3f}")

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            valid_data = next(val_loader)
            input_data = valid_data[:, :-1]
            target_data = valid_data[:, 1:]
            logits = model(input_data)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target_data.reshape(-1))
            wandb.log(dict(valid_loss=loss.item()), step=i)
            print(f"validation loss: {loss.item():.3f}")

            if loss.item() < min_loss:
                min_loss = loss.item()
                torch.save(model.state_dict(), 'model_min_loss.pt')
                print(f'Model saved at validation loss: {min_loss:.3f}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:PRIME_LENGTH]
        inp = inp.cuda()
        print(f"Input token IDs: {inp}")
        prime = decode_tokens(inp)
        print(f"\nprime: {prime}\n")
        prompt = inp[None, ...]
        sampled = base_decoding(model, prompt, GENERATE_LENGTH)
        base_decode_output = decode_tokens(sampled[0])
        decoded_str = urllib.parse.unquote(base_decode_output)
        print(f'output: {decoded_str}')
The model is saved whenever the validation loss reaches a new minimum, and the training run is synced to wandb.
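The loop references a handful of constants and objects (NUM_BATCHES, GRAD_ACCUM_EVERY, VALIDATE_EVERY, GENERATE_EVERY, PRIME_LENGTH, GENERATE_LENGTH, min_loss, optim, wandb) that come from the surrounding training script. A minimal setup sketch with assumed, illustrative values (not the ones used for the run shown below):

import random
import urllib.parse
import tqdm
import torch
import torch.nn.functional as F
import wandb
from torch.optim import Adam

NUM_BATCHES = int(1e5)
GRAD_ACCUM_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
PRIME_LENGTH = 64
GENERATE_LENGTH = 256

min_loss = float('inf')
optim = Adam(model.parameters(), lr=LEARNING_RATE)
wandb.init(project='nsa-transformer-zhwiki')  # hypothetical project name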
Sample generation during training:
training loss: 0.026
training loss: 0.018
validation loss: 0.787
Input token IDs: tensor([ 110, 130, 8168, 110, 12888, 110, 8416, 110, 144, 8129,
110, 147, 8159, 110, 10322, 110, 10322, 108, 1146, 2357,
133, 120, 9025, 135, 133, 120, 11541, 8204, 9989, 135,
133, 11541, 8204, 9989, 9025, 11085, 134, 107, 11469, 8225,
107, 135, 133, 9064, 8370, 8372, 135, 4495, 3833, 133,
120, 9064, 8370, 8372, 135, 133, 9025, 135, 8532, 131,
120, 120, 9998, 119], device='cuda:0')
prime: %9##d%e7%94%b##0%e##4%ba%ba#分布</link></su##b##link><su##b##linklink##type="na##v"><an##ch##or>生活</an##ch##or><link>https://zh.
output: wikipedia.org/wiki/�%8##c��%b##0%e##4��#生活</link></su##b##link><su##b##linklink##type="na##v"><an##ch##or>理论</an##ch##or><link>https://zh.wikipedia.org/wiki/�%8##c�_(�%b##0%e##4%b##f��%8##c�#理论</link></su##b##link><su##b##linklink##type="na##v"><an##ch##or>参考文献</an##ch##or><link>https://zh.wikipedia.org/wiki/�%8##c��%b##0%e##4%b##f�#参考
After training, load the model and run inference:
import math
import torch
from pytorch_pretrained_bert import BertTokenizer
from native_sparse_attention_pytorch.transformer import Transformer

# Constants (kept consistent with training)
PRIME_LENGTH = 64
GENERATE_LENGTH = 256
SEQ_LEN = 256
USE_SPARSE_ATTN = True

# Sampling helpers
def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature=1., dim=-1, keepdim=True):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim, keepdim=keepdim)

def top_k(logits, thres=0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs

def base_decoding(net, prompt: torch.Tensor, seq_len: int, temperature=1., filter_thres=0.9):
    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
    sample_num_times = max(0, seq_len - prompt_seq_len)

    for _ in range(sample_num_times):
        logits = net(out)
        logits = logits[:, -1]
        logits = top_k(logits, thres=filter_thres)
        sample = gumbel_sample(logits, temperature=temperature, dim=-1)
        out = torch.cat((out, sample), dim=-1)

    return out[..., prompt_seq_len:]

# Decoding helper
def decode_tokens(tokens, tokenizer):
    if isinstance(tokens, torch.Tensor):
        tokens = tokens.cpu().tolist()
    token_list = tokenizer.convert_ids_to_tokens(tokens)
    filtered_tokens = [token for token in token_list if token not in ['[CLS]', '[SEP]', '[PAD]']]
    return ''.join(filtered_tokens)

# Load tokenizer and model
tokenizer = BertTokenizer.from_pretrained('./base_model/bert-base-chinese')
vocab_size = len(tokenizer.vocab)
print(f"Vocabulary size: {vocab_size}")

model = Transformer(
    num_tokens=vocab_size,
    dim=512,
    depth=6,
    use_sparse_attn=USE_SPARSE_ATTN,
    sparse_attn_kwargs=dict(
        sliding_window_size=16,
        compress_block_size=16,
        selection_block_size=16,
        num_selected_blocks=4,
        use_diff_topk=False
    )
)

# Load the trained weights
model_path = 'model_min_loss.pt'
state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict)
model = model.cuda()
model.eval()

# Example input
input_text = "阿氏吻鳐"
input_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(input_text))

min_length = max(PRIME_LENGTH, 16)  # make sure the length is >= compress_block_size
if len(input_tokens) < min_length:
    pad_id = tokenizer.convert_tokens_to_ids(['[PAD]'])[0]
    input_tokens = input_tokens + [pad_id] * (min_length - len(input_tokens))

input_tensor = torch.tensor(input_tokens[:PRIME_LENGTH], dtype=torch.long).cuda()
prompt = input_tensor[None, :]

# Run inference
print(f"Input text: {input_text}")
print(f"Input token IDs: {input_tensor}")
with torch.no_grad():
    generated_tokens = base_decoding(model, prompt, GENERATE_LENGTH, temperature=0.7, filter_thres=0.9)
generated_text = decode_tokens(generated_tokens[0], tokenizer)

# Print the result
print(f"\nGenerated text: {generated_text}")
Generated text: 为软骨鱼纲鳐目鳐科吻鳐属的一种[1],分布于中西大西洋美国佛罗里达州到墨西哥犹加敦半岛海域,深度32至384米,本鱼体盘宽圆形,上表面颜色苍白并有暗斑,每个胸鳍上的眼斑通常为椭圆形,下表面白色,无深色斑纹