TensorFlow Code Walkthrough and Transformer Analysis
This article is a walkthrough of Google's Transformer model, recorded in the order in which I worked through it and reflecting my own understanding.
It also analyzes Kyubyong's TensorFlow implementation, available at: GitHub - Kyubyong/transformer: A TensorFlow Implementation of the Transformer: Attention Is All You Need. The article does not describe the Transformer's internals in full detail; if you are not yet familiar with the model, read "Attention Is All You Need" and some of the recommended reference blogs first.
Layer Normalization
Layer Normalization (LN) differs from Batch Normalization (BN). Although BN speeds up model convergence, it has the following drawbacks:
- BN depends heavily on the batch size and works poorly with small batches, which is often unavoidable when GPU memory is limited.
- BN is ill-suited to sequential data such as RNN inputs, where sequence lengths vary across samples.
- BN's batch statistics are computed only during training; at inference it falls back to stored moving averages. This is a characteristic rather than a drawback.
BN normalizes along the batch dimension, whereas LN normalizes along the feature dimension of each individual sample; LN can be viewed as a transposed BN that standardizes the outputs of the same layer. A minimal sketch follows.
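For illustration, here is a minimal layer-normalization sketch in the TF1 style used by the repository; the helper name ln and its arguments are assumptions, so the actual function in the repo may differ in detail.

import tensorflow as tf

def ln(inputs, epsilon=1e-8, scope="ln"):
    '''Layer-normalize the last dimension of `inputs` (rough sketch).'''
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        params_shape = inputs.get_shape()[-1:]
        # mean/variance are taken per sample over the feature axis,
        # not over the batch axis as in Batch Normalization
        mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
        beta = tf.get_variable("beta", params_shape, initializer=tf.zeros_initializer())
        gamma = tf.get_variable("gamma", params_shape, initializer=tf.ones_initializer())
        normalized = (inputs - mean) / ((variance + epsilon) ** 0.5)
        return gamma * normalized + beta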
Mask
The mask is one of the key pieces. Kyubyong's early code had a bug in the mask implementation, which was fixed in later versions. Masks are used to screen out specific positions so that they do not receive attention or influence parameter updates. The Transformer uses two kinds of mask: the padding mask and the sequence mask.
- Padding mask: padded positions are given a very large negative score, so their weights are close to 0 after the softmax and no attention is allocated to padding.
- Sequence mask: prevents a decoder position from seeing future tokens; it is built from a triangular matrix. A sketch of both masks follows this list.
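A rough sketch of the two masks, assuming attention scores of shape (N, T_q, T_k) and a pad token id of 0; the helper names are made up for illustration and do not match the repository exactly.

import tensorflow as tf

NEG_INF = -2 ** 32 + 1.0  # very large negative value placed at masked positions

def apply_padding_mask(scores, key_ids, pad_id=0):
    '''Add NEG_INF to attention scores wherever the key token is padding.
    scores: (N, T_q, T_k), key_ids: (N, T_k); pad_id=0 is an assumption.'''
    key_masks = tf.to_float(tf.equal(key_ids, pad_id))   # (N, T_k), 1.0 at pads
    key_masks = tf.expand_dims(key_masks, 1)             # (N, 1, T_k), broadcast over queries
    return scores + key_masks * NEG_INF                  # softmax pushes these weights to ~0

def apply_sequence_mask(scores):
    '''Mask scores above the diagonal so that position t only sees positions <= t.'''
    diag_vals = tf.ones_like(scores[0, :, :])                              # (T_q, T_k)
    tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()   # lower-triangular keep-matrix
    masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(scores)[0], 1, 1])  # (N, T_q, T_k)
    paddings = tf.ones_like(masks) * NEG_INF
    return tf.where(tf.equal(masks, 0), paddings, scores)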
Context-Attention
Context-attention is the attention between the encoder and the decoder, and it is another application of scaled dot-product attention: the similarity between queries and keys determines the weight distribution placed over the values; see the sketch below.
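A minimal sketch of scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V; masking is omitted for brevity and the function name is an assumption.

import tensorflow as tf

def scaled_dot_product_attention(Q, K, V, dropout_rate=0.0, training=True):
    '''softmax(Q K^T / sqrt(d_k)) V, without any masking for brevity.'''
    d_k = Q.get_shape().as_list()[-1]
    scores = tf.matmul(Q, tf.transpose(K, [0, 2, 1]))        # (N, T_q, T_k)
    scores /= d_k ** 0.5                                     # scaling keeps the softmax from saturating
    # the padding / sequence masks from the previous section would be applied here
    weights = tf.nn.softmax(scores)                          # attention distribution over keys
    weights = tf.layers.dropout(weights, rate=dropout_rate, training=training)
    return tf.matmul(weights, V)                             # (N, T_q, d_v)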
Multi-head Attention
Multi-head self-attention is the core of the Transformer. The inputs are linearly projected into query, key, and value (Q, K, V) representations, attention is computed in several heads in parallel, the heads are concatenated, and the block is finished with a residual connection and Layer Normalization; a sketch follows.
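A rough sketch of the multi-head wrapper; it reuses the ln and scaled_dot_product_attention sketches above, and the name and defaults are assumptions rather than the repository's exact signature.

import tensorflow as tf

def multihead_attention_sketch(queries, keys, values, num_heads=8,
                               dropout_rate=0.0, training=True):
    '''Project, split into heads, attend, merge, then residual + LN (sketch).'''
    d_model = queries.get_shape().as_list()[-1]
    # linear projections to Q, K, V
    Q = tf.layers.dense(queries, d_model)   # (N, T_q, d_model)
    K = tf.layers.dense(keys, d_model)      # (N, T_k, d_model)
    V = tf.layers.dense(values, d_model)    # (N, T_k, d_model)
    # split the model dimension into heads and stack them on the batch axis
    Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)   # (h*N, T_q, d_model/h)
    K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)
    V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)
    # scaled dot-product attention per head (sketch from the previous section)
    outputs = scaled_dot_product_attention(Q_, K_, V_, dropout_rate, training)
    # merge the heads back: (N, T_q, d_model)
    outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)
    # residual connection followed by layer normalization (ln sketch above)
    outputs += queries
    return ln(outputs)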
Positional Embedding
To give the model access to word order, which plain self-attention ignores, the Transformer generates positional embeddings with sine and cosine functions (sine on even dimensions, cosine on odd dimensions); a sketch is shown below.
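A sketch of the sinusoidal encoding table, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)); the function name is chosen for illustration.

import numpy as np
import tensorflow as tf

def positional_encoding_table(max_len, d_model):
    '''Sinusoidal table of shape (max_len, d_model): sin on even dims, cos on odd dims.'''
    position_enc = np.array([
        [pos / np.power(10000, (i - i % 2) / d_model) for i in range(d_model)]
        for pos in range(max_len)])
    position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])   # dimension 2i
    position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])   # dimension 2i + 1
    return tf.convert_to_tensor(position_enc, tf.float32)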
其他模块
The remaining modules, such as the position-wise feed-forward network and label smoothing, are straightforward. The code also uses the Noam schedule, which warms the learning rate up and then decays it; a sketch of the schedule follows.
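A sketch of the Noam schedule: linear warmup for warmup_steps steps, then decay proportional to 1/sqrt(step). The name mirrors the noam_scheme call in train(), but treat this as an illustration rather than the repository's exact code.

import tensorflow as tf

def noam_scheme_sketch(init_lr, global_step, warmup_steps=4000.):
    '''Noam schedule: linear warmup for warmup_steps, then decay with 1/sqrt(step).'''
    step = tf.cast(global_step + 1, dtype=tf.float32)
    return init_lr * warmup_steps ** 0.5 * tf.minimum(step * warmup_steps ** -1.5,
                                                      step ** -0.5)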
utils code and data loading
- num_batch: computed as total_num divided by batch_size, rounded down, plus 1.
- postprocess: post-processes the translation hypotheses and undoes the byte-pair encoding (BPE) segmentation.
- Data loading: the vocabulary and the source/target sentences are loaded, and an encode function converts token sequences into integer id sequences; a small sketch is given after this list.
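A rough sketch of those two utilities, following the description above rather than the exact code in the repository (the <unk> fallback is an assumption):

def calc_num_batches(total_num, batch_size):
    '''Number of batches per epoch: integer division, plus one for the remainder.'''
    return total_num // batch_size + 1

def encode(tokens, token2idx, unk="<unk>"):
    '''Map token strings to integer ids, falling back to the <unk> id.'''
    return [token2idx.get(t, token2idx[unk]) for t in tokens]

The complete Transformer class from model.py is reproduced below; it assembles the encoder, decoder, training, and autoregressive evaluation graphs.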
class Transformer:
    '''
    xs: tuple of
        x: int32 tensor. (N, T1)
        x_seqlens: int32 tensor. (N,)
        sents1: str tensor. (N,)
    ys: tuple of
        decoder_input: int32 tensor. (N, T2)
        y: int32 tensor. (N, T2)
        y_seqlen: int32 tensor. (N, )
        sents2: str tensor. (N,)
    training: boolean.
    '''
    def __init__(self, hp):
        self.hp = hp
        self.token2idx, self.idx2token = load_vocab(hp.vocab)
        self.embeddings = get_token_embeddings(self.hp.vocab_size, self.hp.d_model, zero_pad=True)

    def encode(self, xs, training=True):
        '''
        Returns
        memory: encoder outputs. (N, T1, d_model)
        '''
        with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
            x, seqlens, sents1 = xs
            # embedding
            enc = tf.nn.embedding_lookup(self.embeddings, x)  # (N, T1, d_model)
            enc *= self.hp.d_model**0.5  # scale
            enc += positional_encoding(enc, self.hp.maxlen1)
            enc = tf.layers.dropout(enc, self.hp.dropout_rate, training=training)
            ## Blocks
            for i in range(self.hp.num_blocks):
                with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
                    # self-attention
                    enc = multihead_attention(queries=enc,
                                              keys=enc,
                                              values=enc,
                                              num_heads=self.hp.num_heads,
                                              dropout_rate=self.hp.dropout_rate,
                                              training=training,
                                              causality=False)
                    # feed forward
                    enc = ff(enc, num_units=[self.hp.d_ff, self.hp.d_model])
        memory = enc
        return memory, sents1

    def decode(self, ys, memory, training=True):
        '''
        memory: encoder outputs. (N, T1, d_model)
        Returns
        logits: (N, T2, V). float32.
        y_hat: (N, T2). int32
        y: (N, T2). int32
        sents2: (N,). string.
        '''
        with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
            decoder_inputs, y, seqlens, sents2 = ys
            # embedding
            dec = tf.nn.embedding_lookup(self.embeddings, decoder_inputs)  # (N, T2, d_model)
            dec *= self.hp.d_model ** 0.5  # scale
            dec += positional_encoding(dec, self.hp.maxlen2)
            dec = tf.layers.dropout(dec, self.hp.dropout_rate, training=training)
            # Blocks
            for i in range(self.hp.num_blocks):
                with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
                    # Masked self-attention (Note that causality is True at this time)
                    dec = multihead_attention(queries=dec,
                                              keys=dec,
                                              values=dec,
                                              num_heads=self.hp.num_heads,
                                              dropout_rate=self.hp.dropout_rate,
                                              training=training,
                                              causality=True,
                                              scope="self_attention")
                    # Vanilla attention
                    dec = multihead_attention(queries=dec,
                                              keys=memory,
                                              values=memory,
                                              num_heads=self.hp.num_heads,
                                              dropout_rate=self.hp.dropout_rate,
                                              training=training,
                                              causality=False,
                                              scope="vanilla_attention")
                    ### Feed Forward
                    dec = ff(dec, num_units=[self.hp.d_ff, self.hp.d_model])
        # Final linear projection (embedding weights are shared)
        weights = tf.transpose(self.embeddings)  # (d_model, vocab_size)
        logits = tf.einsum('ntd,dk->ntk', dec, weights)  # (N, T2, vocab_size)
        y_hat = tf.to_int32(tf.argmax(logits, axis=-1))
        return logits, y_hat, y, sents2

    def train(self, xs, ys):
        '''
        Returns
        loss: scalar.
        train_op: training operation
        global_step: scalar.
        summaries: training summary node
        '''
        # forward
        memory, sents1 = self.encode(xs)
        logits, preds, y, sents2 = self.decode(ys, memory)
        # train scheme
        y_ = label_smoothing(tf.one_hot(y, depth=self.hp.vocab_size))
        ce = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=y_)
        nonpadding = tf.to_float(tf.not_equal(y, self.token2idx["<pad>"]))  # 0: <pad>
        loss = tf.reduce_sum(ce * nonpadding) / (tf.reduce_sum(nonpadding) + 1e-7)

        global_step = tf.train.get_or_create_global_step()
        lr = noam_scheme(self.hp.lr, global_step, self.hp.warmup_steps)
        optimizer = tf.train.AdamOptimizer(lr)
        train_op = optimizer.minimize(loss, global_step=global_step)

        tf.summary.scalar('lr', lr)
        tf.summary.scalar("loss", loss)
        tf.summary.scalar("global_step", global_step)
        summaries = tf.summary.merge_all()
        return loss, train_op, global_step, summaries

    def eval(self, xs, ys):
        '''Predicts autoregressively
        At inference, input ys is ignored.
        Returns
        y_hat: (N, T2)
        '''
        decoder_inputs, y, y_seqlen, sents2 = ys
        decoder_inputs = tf.ones((tf.shape(xs[0])[0], 1), tf.int32) * self.token2idx["<s>"]
        ys = (decoder_inputs, y, y_seqlen, sents2)

        memory, sents1 = self.encode(xs, False)

        logging.info("Inference graph is being built. Please be patient.")
        for _ in tqdm(range(self.hp.maxlen2)):
            logits, y_hat, y, sents2 = self.decode(ys, memory, False)
            # intended early stop once every hypothesis emits <pad>; note that this `==`
            # compares a symbolic Tensor with an int at graph-build time, not at run time
            if tf.reduce_sum(y_hat, 1) == self.token2idx["<pad>"]: break
            _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
            ys = (_decoder_inputs, y, y_seqlen, sents2)

        # monitor a random sample
        n = tf.random_uniform((), 0, tf.shape(y_hat)[0]-1, tf.int32)
        sent1 = sents1[n]
        pred = convert_idx_to_token_tensor(y_hat[n], self.idx2token)
        sent2 = sents2[n]

        tf.summary.text("sent1", sent1)
        tf.summary.text("pred", pred)
        tf.summary.text("sent2", sent2)
        summaries = tf.summary.merge_all()
        return y_hat, summaries
Model assembly
model.py ties the modules together, and the resulting code is compact. Note the difference between tf.nn.dropout and tf.layers.dropout; a short comparison follows.
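A tiny comparison of the two TF1 APIs (the values here are chosen arbitrarily for illustration): tf.nn.dropout takes keep_prob, the probability of keeping a unit, and always applies dropout; tf.layers.dropout takes rate, the probability of dropping a unit, and does nothing unless training=True.

import tensorflow as tf

x = tf.ones([2, 4])
# tf.nn.dropout: the argument is keep_prob and dropout is always applied
a = tf.nn.dropout(x, keep_prob=0.9)
# tf.layers.dropout: the argument is rate (= 1 - keep_prob) and nothing is
# dropped unless training=True, which is why the model passes `training` around
b = tf.layers.dropout(x, rate=0.1, training=True)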