【深度学习】常见模型-Transformer模型
Transformer 是一种深度学习模型,首次由 Vaswani 等人在 2017 年提出(论文《Attention is All You Need》),在自然语言处理(NLP)领域取得了革命性成果。它的核心思想是通过 自注意力机制(Self-Attention Mechanism) 和完全基于注意力的架构来捕捉序列数据中的依赖关系。
Transformer 的基本结构
Transformer 模型由两个主要模块组成:
-
编码器(Encoder):
- 输入序列经过嵌入(Embedding)和位置编码(Positional Encoding)后,逐层通过多个编码块。
- 每个编码块包括两个主要子层:
- 多头自注意力层(Multi-Head Self-Attention)。
- 前馈全连接网络(Feedforward Neural Network)。
-
解码器(Decoder):
- 解码器也由多层解码块组成,结构类似编码器,但有额外的交叉注意力机制。
- 解码块主要包含:
- 多头自注意力层(Masked Multi-Head Self-Attention)。
- 交叉注意力层(Encoder-Decoder Attention)。
- 前馈全连接网络。
Transformer 的输入经过编码器进行特征提取,解码器利用编码器输出生成目标序列。
核心组件
1. 自注意力机制(Self-Attention Mechanism)
- 目标:在序列的每个位置,计算它与其他所有位置的相关性,捕获全局依赖关系。
- 公式:
- Q:查询矩阵(Query)。
- K:键矩阵(Key)。
- V:值矩阵(Value)。
- :键向量的维度(用于缩放防止梯度爆炸)。
2. 多头注意力机制(Multi-Head Attention)
- 将输入数据分为多个头(head),并分别计算注意力。
- 优点:能够从不同的子空间捕获特征,提高模型的表达能力。
3. 位置编码(Positional Encoding)
- 因为 Transformer 不使用 RNN 或 CNN,所以需要显式地表示序列位置。
- 常用正弦和余弦函数来表示:
- pos:位置索引。
- i:维度索引。
- d:嵌入维度。
4. 前馈全连接网络(FFN)
- 每个编码器或解码器块都包含一个独立的全连接网络:
5. 残差连接与层归一化
- 每个子层后加残差连接(Residual Connection)并归一化(Layer Normalization),以加速训练和稳定梯度。
Transformer 的整体结构
Transformer 使用堆叠的编码器和解码器模块处理输入和输出:
- 输入序列(如句子)经过嵌入和位置编码后输入到编码器。
- 编码器生成的上下文向量传递到解码器。
- 解码器通过交叉注意力结合编码器的上下文向量和解码器中间状态生成输出序列。
代码实现
以下是使用 TensorFlow 和 Keras 构建简单 Transformer 的代码示例:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding, LayerNormalization, Dropout
import numpy as np
# 自注意力机制
class MultiHeadAttention(tf.keras.layers.Layer):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.d_model = d_model
assert d_model % self.num_heads == 0
self.depth = d_model // self.num_heads
self.wq = Dense(d_model)
self.wk = Dense(d_model)
self.wv = Dense(d_model)
self.dense = Dense(d_model)
def split_heads(self, x, batch_size):
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
return tf.transpose(x, perm=[0, 2, 1, 3]) # (batch_size, num_heads, seq_len, depth)
def call(self, q, k, v, mask):
batch_size = tf.shape(q)[0]
q = self.wq(q) # (batch_size, seq_len, d_model)
k = self.wk(k)
v = self.wv(v)
q = self.split_heads(q, batch_size)
k = self.split_heads(k, batch_size)
v = self.split_heads(v, batch_size)
# Scaled dot-product attention
matmul_qk = tf.matmul(q, k, transpose_b=True)
dk = tf.cast(tf.shape(k)[-1], tf.float32)
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
if mask is not None:
scaled_attention_logits += (mask * -1e9)
attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (batch_size, num_heads, seq_len_q, seq_len_k)
output = tf.matmul(attention_weights, v) # (batch_size, num_heads, seq_len_q, depth_v)
output = tf.transpose(output, perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth)
concat_attention = tf.reshape(output, (batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model)
return self.dense(concat_attention)
# 示例调用
sample_mha = MultiHeadAttention(d_model=512, num_heads=8)
temp_q = tf.random.uniform((1, 60, 512)) # (batch_size, seq_len, d_model)
temp_k = tf.random.uniform((1, 60, 512))
temp_v = tf.random.uniform((1, 60, 512))
temp_out = sample_mha(temp_q, temp_k, temp_v, None)
print(temp_out.shape) # (1, 60, 512)
Transformer 的应用
-
自然语言处理:
- 机器翻译(Google Translate 使用 Transformer)。
- 文本摘要(如 BERT、GPT)。
- 情感分析、问答系统。
-
计算机视觉:
- 图像分类(如 Vision Transformer)。
- 目标检测、图像生成。
-
音频处理:
- 语音识别(如 Wav2Vec)。
- 音乐生成。
-
其他领域:
- 推荐系统、时间序列预测、生物信息学。
优点与缺点
优点:
- 并行处理能力强,速度快。
- 能捕获长距离依赖关系。
- 通用性强,适用于多种任务。
缺点:
- 计算成本高(尤其是自注意力机制在长序列上的时间复杂度)。
- 对内存需求大,训练大型模型需高性能硬件。
Transformer 以其强大的表达能力和灵活性,已经成为深度学习领域的重要基石,为 NLP 和其他领域带来了巨大变革。