统计模型的Flops和Params
1、方法一 thop
from thop import profile, clever_format
model = Model() ## 实例化模型
input = torch.randn(1, 3, 128, 128) ## 模拟输入
flops, params = profile(model, inputs=(input,))
flops, params = clever_format([flops, params], "%.3f")
print('flops: {}, params: {}'.format(flops, params))
-
thop
是一个用于计算模型 FLOPs 和参数量的库。 -
profile
函数用于计算模型的 FLOPs 和参数量。 -
clever_format
函数用于将计算得到的 FLOPs 和参数量格式化为更易读的形式(例如,将1000000
转换为1.000M
)。
2、方法二 torchinfo
from torchinfo import summary
input = torch.randn(1, 3, 128, 128) ## 模拟输入
# 使用 torchinfo 计算参数量
summary(model, input_data=input)
3、方法三 手工计算
total_params = sum(p.numel() for p in model.parameters())
print(f"总参数量: {total_params / 1e6:.3f}M")