Algorithm Interview Hand-Coding Series (2) -- Implementing BatchNormalization from Scratch
BatchNormalization
The coding workflow for BatchNormalization:
- In `__init__`, create a scale vector and a shift vector of size `C_in`, along with running-mean and running-variance buffers of the same size;
- In `forward`, compute the mean and the biased variance over all non-channel dimensions;
- Normalize the input using the computed mean and biased variance;
- Scale and shift the normalized result.
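Before wrapping these steps in a module, the normalization math itself can be checked against PyTorch's built-in `nn.BatchNorm2d` (a sketch; the batch shape and seed are arbitrary, and the built-in layer's affine parameters are 1 and 0 at init, so it reduces to pure normalization):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(10, 3, 5, 5)

# manual normalization following the steps above:
# statistics over all non-channel dimensions, biased variance
mean = x.mean(dim=[0, 2, 3], keepdim=True)
var = x.var(dim=[0, 2, 3], keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + 1e-5)

# reference: nn.BatchNorm2d in training mode (gamma=1, beta=0 at init)
bn = nn.BatchNorm2d(3, eps=1e-5)
bn.train()
ref = bn(x)

print(torch.allclose(manual, ref, atol=1e-6))
```

In training mode the built-in layer also normalizes with the current batch's biased statistics, so the two results match elementwise.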
Code
The code is as follows:
```python
import torch
import torch.nn as nn

class BN(nn.Module):
    def __init__(self, C_in):
        super(BN, self).__init__()
        # learnable affine parameters, broadcastable over (N, C, H, W)
        self.scale = nn.Parameter(torch.ones(C_in).view(1, -1, 1, 1))
        self.shift = nn.Parameter(torch.zeros(C_in).view(1, -1, 1, 1))
        self.momentum = 0.9
        self.register_buffer('running_mean', torch.zeros(C_in).view(1, -1, 1, 1))
        # running variance starts at ones so eval mode is well-behaved before any update
        self.register_buffer('running_var', torch.ones(C_in).view(1, -1, 1, 1))
        self.eps = 1e-5

    def forward(self, x):
        if self.training:
            # batch statistics over all non-channel dimensions
            mean = x.mean(dim=[0, 2, 3], keepdim=True)
            var = x.var(dim=[0, 2, 3], keepdim=True, unbiased=False)
            x = (x - mean) / torch.sqrt(var + self.eps)
            # update running statistics; detach so no gradient flows into the buffers
            self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mean.detach()
            self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var.detach()
        else:
            x = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps)
        # affine transform: scale and shift the normalized result
        return self.scale * x + self.shift

if __name__ == "__main__":
    input = torch.rand(10, 3, 5, 5)
    model = BN(3)
    res = model(input)
    print(res.shape)  # torch.Size([10, 3, 5, 5])
```
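One interview-worthy detail is the momentum convention. The class above weights the *old* running statistic by `momentum = 0.9`, whereas `nn.BatchNorm2d` weights the *new* batch statistic by its `momentum` argument, so the equivalent setting there is `momentum = 0.1`. A sketch verifying the two conventions agree (note PyTorch updates its running variance with the *unbiased* batch variance):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# nn.BatchNorm2d: running = (1 - momentum) * running + momentum * batch,
# so momentum=0.1 matches the keep-old factor 0.9 used above
bn = nn.BatchNorm2d(3, momentum=0.1)
bn.train()

running_mean = torch.zeros(3)
running_var = torch.ones(3)  # PyTorch initializes running_var to ones
for _ in range(3):
    x = torch.rand(10, 3, 5, 5)
    bn(x)  # updates bn.running_mean / bn.running_var internally
    batch_mean = x.mean(dim=[0, 2, 3])
    batch_var = x.var(dim=[0, 2, 3], unbiased=True)  # unbiased for running stats
    running_mean = 0.9 * running_mean + 0.1 * batch_mean
    running_var = 0.9 * running_var + 0.1 * batch_var

print(torch.allclose(running_mean, bn.running_mean, atol=1e-6))
print(torch.allclose(running_var, bn.running_var, atol=1e-6))
```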