How to use SageAttention when fine-tuning DeepSeek-VL2 with ms-swift
SageAttention is reportedly much faster than flash_attention_2. But how do you actually use it when training DeepSeek-VL2?
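1. SageAttention computes scaled dot-product attention (SDPA) and is designed as a drop-in replacement for PyTorch's SDPA function. SDPA is a form of multiplicative (dot-product) attention. Put simply, the match between each Query (Q) and every Key (K) decides how the Values (V) are weighted: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V, where d_k is the key dimension. In self-attention, Q, K and V are all projections of the same input, so SDPA essentially recombines information from the input itself.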
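As a plain reference for what SDPA computes, here is a minimal single-head sketch in pure Python (no PyTorch, no batching). It only illustrates softmax(QK^T / sqrt(d)) V with an optional causal mask. The name `sdpa_reference` is made up for this example. It is not SageAttention's kernel, which quantizes its inputs.

```python
import math


def sdpa_reference(q, k, v, is_causal=False):
    """Single-head scaled dot-product attention on lists of vectors.

    q: L query vectors; k, v: S key/value vectors; all of dimension d.
    Returns L output vectors: softmax(q @ k^T / sqrt(d)) @ v.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, qi in enumerate(q):
        # Match score between query i and every key.
        scores = [sum(a * b for a, b in zip(qi, kj)) * scale for kj in k]
        if is_causal:
            # Query i may only attend to keys 0..i.
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        # Numerically stable softmax over the scores.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Output is the weighted sum of the value vectors.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


# Equal scores give equal weights, so the output is the mean of the values.
print(sdpa_reference([[1.0, 2.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 2.0], [3.0, 4.0]]))  # [[2.0, 3.0]]
```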
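2. SageAttention's kernels are written in Triton, which compiles kernels written in Python into machine code for the target GPU. According to the SageAttention authors, much of the speedup also comes from quantizing Q and K to 8-bit integers before computing QK^T.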
3. The official example shows that it works by simply replacing the function torch.nn.functional.scaled_dot_product_attention:
from sageattention import sageattn
import torch.nn.functional as F
...
F.scaled_dot_product_attention = sageattn
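Patching a module attribute works only for code that looks the function up through the module when it is called. Code that grabbed the function object before the patch keeps calling the original. The sketch below shows this with a stand-in module built with `types.ModuleType`, so it is not the real `torch.nn.functional`. The function names are hypothetical.

```python
import types

# Hypothetical stand-in for torch.nn.functional (not the real module).
F = types.ModuleType("fake_functional")


def original_sdpa(q, k, v, **kwargs):
    return "original"


def replacement_sdpa(q, k, v, **kwargs):
    # Plays the role of sageattn in this illustration.
    return "replacement"


F.scaled_dot_product_attention = original_sdpa


def caller_module_lookup(q, k, v):
    # Resolves the attribute on the module at call time, like `F.scaled_dot_product_attention(...)`.
    return F.scaled_dot_product_attention(q, k, v)


# Grabs the function object before the patch, like
# `from torch.nn.functional import scaled_dot_product_attention` executed at import time.
early_ref = F.scaled_dot_product_attention


def caller_early_binding(q, k, v):
    return early_ref(q, k, v)


# The patch from the official example: rebind the module attribute.
F.scaled_dot_product_attention = replacement_sdpa

print(caller_module_lookup(1, 2, 3))  # replacement
print(caller_early_binding(1, 2, 3))  # original
```

If the patch seems to have no effect, check how the model code imports the SDPA function.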
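So, to use SageAttention, the model only needs to support the SDPA attention implementation in the first place.
Unfortunately, the official DeepSeek-VL2 GitHub repository shows that neither DeepseekVLV2ForCausalLM nor DeepseekV2ForCausalLM supports SDPA: neither class declares _supports_sdpa = True. The ATTENTION_CLASSES mapping in modeling_deepseek.py looks like this: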
ATTENTION_CLASSES = {
    "eager": DeepseekV2Attention,
    "flash_attention_2": DeepseekV2FlashAttention2,
    "mla_eager": DeepseekV2Attention,
    "mla_flash_attention_2": DeepseekV2FlashAttention2,
    "mha_eager": LlamaAttention,
    "mha_flash_attention_2": LlamaFlashAttention2
}
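It contains no SDPA implementation either.
Therefore DeepSeek-VL2 cannot use SageAttention out of the box. DeepSeek's open-source code has to be modified first.
The steps are as follows:
1. First, make DeepseekVLV2ForCausalLM and DeepseekV2ForCausalLM declare SDPA support by adding _supports_sdpa = True, so that the _check_and_enable_sdpa check in transformers/modeling_utils.py passes.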
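class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]
    _supports_sdpa = True  # added: lets transformers accept the "sdpa" attention implementation
    # ... rest of the class unchanged


class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel):
    _supports_sdpa = True  # added
    # ... rest of the class unchanged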
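Why a single class attribute is enough: in the transformers versions this post targets, PreTrainedModel sets `_supports_sdpa = False`, and requesting SDPA for a class that has not overridden it is rejected. The sketch below is a heavily simplified, hypothetical stand-in for that gate. The real `_check_and_enable_sdpa` also checks the PyTorch version and other conditions, and its details and messages vary between transformers versions.

```python
class FakePreTrainedModel:
    # Mirrors the transformers default: subclasses must opt in explicitly.
    _supports_sdpa = False

    @classmethod
    def check_sdpa(cls, requested_impl):
        """Hypothetical, simplified stand-in for transformers' _check_and_enable_sdpa."""
        if requested_impl == "sdpa" and not cls._supports_sdpa:
            raise ValueError(f"{cls.__name__} does not declare SDPA support")
        return requested_impl


class BeforePatch(FakePreTrainedModel):
    pass


class AfterPatch(FakePreTrainedModel):
    _supports_sdpa = True  # the one-line change from step 1


print(AfterPatch.check_sdpa("sdpa"))  # sdpa
try:
    BeforePatch.check_sdpa("sdpa")
except ValueError as err:
    print("rejected:", err)
```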
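2. Next, add an SDPA attention class to DeepSeek-VL2/deepseek_vl2/models/modeling_deepseek.py. It inherits from LlamaAttention, and its body is essentially copied from transformers' LlamaSdpaAttention (copied rather than subclassed so that it is easy to modify), with the SDPA call replaced by sageattn: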
# DeepSeekSdpaAttention for modeling_deepseek.py. That file is assumed to already define `logger`
# and import torch and LlamaAttention; add any of the imports below that are missing.
from typing import Optional, Tuple

import torch
from sageattention import sageattn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import LlamaAttention, repeat_kv

# `apply_rotary_pos_emb2` is used below but not defined in this excerpt. It is assumed to be the
# Llama-style helper apply_rotary_pos_emb(q, k, cos, sin), imported under an alias so it does not
# clash with DeepSeek's own apply_rotary_pos_emb, for example:
# from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as apply_rotary_pos_emb2


class DeepSeekSdpaAttention(LlamaAttention):
    """
    Deepseek attention module that calls SageAttention's `sageattn` where LlamaSdpaAttention calls
    torch.nn.functional.scaled_dot_product_attention. It inherits from `LlamaAttention`, so the
    module weights stay untouched; only the forward pass changes.
    """

    # Adapted from LlamaSdpaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # sageattn returns no attention weights, so fall back to the eager LlamaAttention path.
            logger.warning_once(
                "DeepSeekSdpaAttention does not support `output_attentions=True`; falling back to the eager "
                'attention implementation. Load the model with `attn_implementation="eager"` to avoid this warning.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
        # Resulting layout: (batch, heads, seq_len, head_dim)
        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        elif isinstance(position_embeddings, torch.Tensor):
            # Not a (cos, sin) tuple: recompute RoPE from position_ids.
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb2(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # Kept from LlamaSdpaAttention: SDPA's memory-efficient backend (torch==2.1.2) is bugged with
        # non-contiguous inputs and a custom attn_mask. Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # Assumption (check the source of your installed sageattention version): `sageattn` applies only
        # its built-in causal mask. It takes `is_causal`, while `attn_mask` and `dropout_p` land in
        # **kwargs and are not used. The LlamaSdpaAttention condition `causal_mask is None and q_len > 1`
        # would presumably be False here because the model passes a 4D mask, which would leave attention
        # non-causal, so causality is forced on instead. Side effects: padding positions in attention_mask
        # are not masked, and no attention dropout is applied. is_causal=True matches training
        # (q_len == kv_len); cached decoding (q_len == 1) was not covered by this test.
        is_causal = True
        attn_output = sageattn(
            query_states,  # (batch, heads, seq_len, head_dim), sageattn's default "HND" layout
            key_states,
            value_states,
            attn_mask=causal_mask,  # assumed to be ignored by sageattn
            dropout_p=self.attention_dropout if self.training else 0.0,  # assumed to be ignored
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
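Suppose that, as assumed here, `sageattn` in version 1.0.6 applies only its built-in causal mask and ignores `attn_mask`. Padding tokens inside a batch are then visible to later tokens. The pure-Python sketch below uses ordinary softmax attention, not SageAttention's quantized kernel, to show what that means. It compares "causal only" with "causal plus padding mask" at the non-padding positions. With right padding the outputs match, because a padding token only ever comes after the real tokens. With left padding they differ. All function names here are hypothetical.

```python
import math


def attend(q, k, v, allowed):
    """Single-head softmax(q k^T / sqrt(d)) v, where allowed(i, j) says whether query i may see key j."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        scores = [
            sum(a * b for a, b in zip(qi, kj)) * scale if allowed(i, j) else float("-inf")
            for j, kj in enumerate(k)
        ]
        m = max(scores)
        # A fully masked row yields NaN here; such rows are padding and are skipped below.
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / z for c in range(len(v[0]))])
    return out


def causal_only(_pad):
    # What reaches the kernel if only is_causal=True is honoured.
    return lambda i, j: j <= i


def causal_with_padding(pad):
    # What the model's 4D attention mask asks for: causal, and never attend to padding.
    return lambda i, j: j <= i and not pad[j]


def real_rows_match(tokens, pad):
    """Compare outputs at non-padding positions under both masks (q = k = v = tokens)."""
    a = attend(tokens, tokens, tokens, causal_only(pad))
    b = attend(tokens, tokens, tokens, causal_with_padding(pad))
    return all(
        abs(x - y) < 1e-9
        for i, is_pad in enumerate(pad) if not is_pad
        for x, y in zip(a[i], b[i])
    )


real = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
pad_tok = [0.5, -0.5]  # arbitrary vector at the padding position

print(real_rows_match(real + [pad_tok], [False, False, False, True]))  # right padding: True
print(real_rows_match([pad_tok] + real, [True, False, False, False]))  # left padding: False
```

With batch size 1 there is no padding, so the question does not arise. For larger batches, check which padding side your data collator uses before relying on this setup.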
3. Add an SDPA entry to ATTENTION_CLASSES in modeling_deepseek.py:
ATTENTION_CLASSES = {
    "eager": DeepseekV2Attention,
    "flash_attention_2": DeepseekV2FlashAttention2,
    "mla_eager": DeepseekV2Attention,
    "mla_flash_attention_2": DeepseekV2FlashAttention2,
    "mha_eager": LlamaAttention,
    "mha_flash_attention_2": LlamaFlashAttention2,
    "mha_sdpa": DeepSeekSdpaAttention
}
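The log line `using attn_implementation: mha_sdpa` in step 5 suggests that the decoder layer builds the dictionary key from a prefix (`mla_` or `mha_`, depending on whether the config uses MLA) plus the configured implementation name. The sketch below assumes exactly that and uses placeholder classes. It shows why the single "mha_sdpa" entry is enough for deepseek-vl2-tiny, which takes the MHA path according to that log. It also shows that an MLA config requesting SDPA would still fail with a KeyError, because there is no "mla_sdpa" entry. `resolve_attention_class` is hypothetical.

```python
# Placeholder classes; in modeling_deepseek.py these are the real attention modules.
class DeepseekV2Attention:
    pass


class DeepseekV2FlashAttention2:
    pass


class LlamaAttention:
    pass


class LlamaFlashAttention2:
    pass


class DeepSeekSdpaAttention:
    pass


ATTENTION_CLASSES = {
    "eager": DeepseekV2Attention,
    "flash_attention_2": DeepseekV2FlashAttention2,
    "mla_eager": DeepseekV2Attention,
    "mla_flash_attention_2": DeepseekV2FlashAttention2,
    "mha_eager": LlamaAttention,
    "mha_flash_attention_2": LlamaFlashAttention2,
    "mha_sdpa": DeepSeekSdpaAttention,
}


def resolve_attention_class(use_mla: bool, attn_implementation: str):
    """Hypothetical: mirrors how the decoder layer is assumed to pick its attention class."""
    prefix = "mla_" if use_mla else "mha_"
    return ATTENTION_CLASSES[prefix + attn_implementation]


print(resolve_attention_class(False, "sdpa").__name__)  # DeepSeekSdpaAttention
try:
    resolve_attention_class(True, "sdpa")
except KeyError as err:
    print("missing entry:", err)
```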
4. Versions of SageAttention and Triton used:
Name: sageattention
Version: 1.0.6
Name: triton
Version: 3.2.0
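Whether the steps above work depends on package versions, so it can help to print them from inside the training environment itself. The original list covers only sageattention and triton; torch, transformers and ms-swift are included below as extra packages worth recording. The helper is a small stdlib-only sketch, and its name is my own.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


for pkg in ("sageattention", "triton", "torch", "transformers", "ms-swift"):
    print(pkg, installed_version(pkg))
```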
5. Test training:
swift sft --model "deepseek-ai/deepseek-vl2-tiny" --dataset ../TEST.json --attn_impl sdp
...
using attn_implementation: mha_sdpa
[INFO:swift] model.hf_device_map: {'': device(type='cuda', index=0)}
...
[INFO:swift] End time of running main: 2025-03-28 10:43:29.304164
Training completed successfully.