当前位置: 首页 > article >正文

Numpy实现BatchNorm2d

假设 x[B,C,H,W],

1. 计算不同通道的均值

如何理解 np.mean(x, axis=(0, 2, 3), keepdims=True)

import numpy as np

def cal_mean(x:np.ndarray):
    B,C,H,W = x.shape
    batch_mean = np.zeros((1,C,1,1))

    for c in range(C):
        total_sum = 0.0
        count = 0
        for b in range(B):
            for h in range(H):
                for w in range(W):
                    total_sum += x[b,c,h,w]
                    count += 1
        # 实际上,count=B*H*W
        batch_mean[0,c,0,0] = total_sum / count
    return batch_mean


B, C, H, W = 4, 3, 8, 8
x = np.random.randint(-10,10, (B, C, H, W) )

mean0 = np.mean(x,axis=(0,2,3),keepdims=True)
mean1 = cal_mean(x)
print(mean0==mean1)  # True

2.计算不同通道的方差 

如何理解np.var(x, axis=(0, 2, 3), keepdims=True)?

def cal_variance(x:np.ndarray):
    batch_mean = cal_mean(x)
    B,C,H,W = x.shape
    batch_var = np.zeros((1,C,1,1))

    for c in range(C):
        sum_squared_diff = 0.0
        count = 0
        for b in range(B):
            for h in range(H):
                for w in range(W):
                    diff = x[b,c,h,w] - batch_mean[0,c,0,0]
                    sum_squared_diff += diff ** 2
                    count += 1
        # 实际上,count=B*H*W
        batch_var[0,c,0,0] = sum_squared_diff / count
    return batch_var

B, C, H, W = 4, 3, 8, 8
x = np.random.randint(-10,10, (B, C, H, W) )
var0 = np.var(x, axis=(0, 2, 3), keepdims=True)
var1 = cal_variance(x)

print(var0 == var1) # True

3. NumPy计算BatchNorm2d

class BatchNorm2d:
    def __init__(self, num_channels, epsilon=1e-5, momentum=0.1):
        self.epsilon = epsilon
        self.momentum = momentum
        
        # 初始化 gamma (scale) and beta (shift) as learnable parameters
        # gamma和beta是learnable参数
        self.gamma = np.ones((1, num_channels, 1, 1))  
        self.beta = np.zeros((1, num_channels, 1, 1))  
        
        # 推理阶段使用的running_mean和running_var
        self.running_mean = np.zeros((1, num_channels, 1, 1))
        self.running_var = np.ones((1, num_channels, 1, 1))
    
    def forward(self, x, training=True):
        if training:
            # 计算均值和方差
            batch_mean = np.mean(x, axis=(0, 2, 3), keepdims=True)
            batch_var = np.var(x, axis=(0, 2, 3), keepdims=True)

            # 更新running_mean 和 running_var, 为推理阶段使用
            self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * batch_mean
            self.running_var = self.momentum * self.running_var + (1 - self.momentum) * batch_var
            
            # x[4,3,8,8],batch_mean[1,3,1,1],batch_var[1,3,1,1]
            x_normalized = (x - batch_mean) / np.sqrt(batch_var + self.epsilon)
        else:
            # 在推理阶段, 使用 running mean and variance
            x_normalized = (x - self.running_mean) / np.sqrt(self.running_var + self.epsilon)

        # Scale and shift,normalized之后,再放缩和+偏置
        # 这就是为什么batchnorm之前的Conv2d(bias=False),因为BN层有偏置
        out = self.gamma * x_normalized + self.beta
        return out


batch, channels, height, width = 4, 3, 8, 8
x = np.random.randint(-10,10, (batch, channels, height, width) )

bn_customize = BatchNorm2d(num_channels=channels)
output0 = bn_customize.forward(x, training=True)



x_tensor = torch.tensor(x).float()
batchnorm = nn.BatchNorm2d(num_features=channels,eps=1e-5,momentum=0.1)
batchnorm.train()
output1 = batchnorm(x_tensor)
output2 = output1.detach().numpy()

# 检测2者误差是否在10的负4次方
comparison = np.allclose(output0,output2,atol=1e-4)
print("Are the outputs close enough?",comparison)

部分结果对比 

output0[0,0,0:3,0:4]
array([[ 1.19182343, -0.03069952, -0.20534565,  0.84253116],
       [ 1.19182343,  1.5411157 , -0.72928406, -1.4278686 ],
       [-1.4278686 , -0.55463792, -1.07857633,  0.14394662]])
output1[0,0,0:3,0:4]
tensor([[ 1.1918, -0.0307, -0.2053,  0.8425],
        [ 1.1918,  1.5411, -0.7293, -1.4279],
        [-1.4279, -0.5546, -1.0786,  0.1439]], grad_fn=<SliceBackward0>)


http://www.kler.cn/a/378244.html

相关文章:

  • 计算机网络三张表(ARP表、MAC表、路由表)总结
  • 6. 马科维茨资产组合模型+政策意图AI金融智能体(DeepSeek-V3)增强方案(理论+Python实战)
  • HTML根元素<html>的语言属性lang:<html lang=“en“>
  • Scrapy之一个item包含多级页面的处理方案
  • 嵌入式知识点总结 操作系统 专题提升(一)-进程和线程
  • 【深度学习】嘿马深度学习笔记第11篇:卷积神经网络,学习目标【附代码文档】
  • springboot Lark扫码登录
  • WPF+MVVM案例实战(十七)- 自定义字体图标按钮的封装与实现(ABC类)
  • React v19 革新功能:2024 年需要了解的所有信息
  • 安装Go和配置镜像
  • Web Broker(Web服务应用程序)入门教程(4)
  • K3S 全面解析
  • 从0开始本地部署大模型
  • MyBatis-Plus条件构造器:构建安全、高效的数据库查询
  • NVR小程序接入平台/设备EasyNVR多个NVR同时管理视频监控新选择
  • C语言中的快速排序
  • DNA、蛋白质、生物语义语言模型的介绍
  • ARM cpu算力KDMIPS测试
  • 用 Ray 扩展 AI 应用
  • Django+Vue全栈开发旅游网项目景点详情
  • Linux系统-僵尸孤儿进程
  • Android平台RTSP转RTMP推送之采集麦克风音频转发
  • 【C++】多态的语法与底层原理
  • MATLAB算法实战应用案例精讲-【数模应用】PageRank(附MATLAB、C++、python和R语言代码实现)
  • 《Java 实现快速排序:原理剖析与代码详解》
  • thinkphp中命令行工具think使用,可用于快速生成控制器,模型,中间件等