Algorithm Interview Hand-Coding Series (3) -- Hand-Coding LayerNormalization
LayerNormalization
In many language models such as BERT, although the operation is called LayerNormalization, the mean and variance are computed only along the channel (hidden) dimension, not jointly over the seq_L and channel dimensions. Reference: BERT用的LayerNorm可能不是你认为的那个Layer Norm (the LayerNorm used in BERT may not be the LayerNorm you think it is).
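A quick sketch of the difference in PyTorch (the shapes here are arbitrary, chosen just for illustration): nn.LayerNorm(hidden_dim) normalizes each token vector independently, whereas passing normalized_shape=(seq_L, hidden_dim) would pool statistics over the whole sequence as well.

import torch
import torch.nn as nn

x = torch.rand(2, 10, 8)            # (batch, seq_L, hidden_dim), example shape only

# BERT-style LayerNorm: statistics per token, over the hidden dimension only
ln_channel = nn.LayerNorm(8)
out_channel = ln_channel(x)         # mean/var taken over the last dim (8 values per token)

# Normalizing over seq_L and channel together would instead look like this
ln_full = nn.LayerNorm((10, 8))
out_full = ln_full(x)               # mean/var taken over the last two dims (10*8 values)

print(out_channel.shape, out_full.shape)  # both (2, 10, 8), but the statistics differ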
LayerNormalization computation steps:
- In __init__, initialize scale and shift vectors of size C_in (the hidden dimension)
- Compute the mean and variance along the channel dimension
- Normalize, then apply the scale and shift (see the formula below)
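Concretely, for each token vector $x \in \mathbb{R}^{C}$ the steps above compute (this is the standard per-token LayerNorm formula, with $\gamma$ and $\beta$ the learnable scale and shift):

$$
\mu = \frac{1}{C}\sum_{i=1}^{C} x_i,\qquad
\sigma^2 = \frac{1}{C}\sum_{i=1}^{C} (x_i - \mu)^2,\qquad
y = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
$$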
Code
The LayerNorm (InstanceNorm-style, per the reference above) implementation is as follows:
import torch
import torch.nn as nn

class LayerNormalization(nn.Module):
    def __init__(self, hidden_dim, eps=1e-6):
        super(LayerNormalization, self).__init__()
        self.eps = eps
        # learnable per-channel scale (gamma) and shift (beta)
        self.gamma = nn.Parameter(torch.ones(hidden_dim))
        self.beta = nn.Parameter(torch.zeros(hidden_dim))

    def forward(self, x):
        B, seq_L, C = x.shape  # (batch, sequence length, channels/hidden)
        # statistics over the channel dimension only; biased variance, matching nn.LayerNorm
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        out = (x - mean) / torch.sqrt(var + self.eps)
        out = out * self.gamma + self.beta
        return out
if __name__ == "__main__":
    tensor_input = torch.rand(5, 10, 8)
    model = LayerNormalization(8)
    res = model(tensor_input)
    print(res)
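As a quick sanity check (not part of the original snippet), the handwritten module can be compared against torch.nn.LayerNorm with the same eps; since both use the biased variance over the last dimension, the outputs should match closely:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(5, 10, 8)

ours = LayerNormalization(8)       # the class defined above, default gamma=1, beta=0
ref = nn.LayerNorm(8, eps=1e-6)    # PyTorch built-in, default weight=1, bias=0

print(torch.allclose(ours(x), ref(x), atol=1e-5))  # expected: True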