使用 thop 计算模型复杂度（参数量 params 与计算量 FLOPs；注意 thop 实际统计的是 MACs，通常按 FLOPs ≈ 2 × MACs 换算）
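thop 安装
- 在线安装：pip install thop（若因网络等原因在线安装失败，改用下面的离线安装）
- 离线安装：
  1. 下载源码，GitHub 地址：https://github.com/Lyken17/pytorch-OpCounter
     （项目 pytorch-OpCounter: Count the MACs / FLOPs of your PyTorch model，GitCode 上有同名镜像）
  2. 在源码根目录执行：
     python setup.py install
     （较新的 pip/setuptools 已不推荐直接运行 setup.py，也可在源码根目录执行 pip install .）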
测试代码：
# Measure the complexity (MACs / FLOPs / parameter count) of ``Net`` with thop.
# Net is fed dummy inputs: a random cover image, a random +/-0.5 message and
# the cover's 8x8 block DCT.
from options import config as c
import os

# Restrict the visible GPUs before torch initialises CUDA; setting the variable
# ahead of the torch import is the reliable way to make it take effect.
# c.os_environ must be a str (os.environ only accepts strings), e.g. "0".
os.environ["CUDA_VISIBLE_DEVICES"] = c.os_environ
import torch
from modules.NET import Net
from utils.utils import load
from utils.yml import parse_yml, dict_to_nonedict
import numpy as np
from thop import profile
from modules.DCTGate_fast import DCT_transform

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------------------noise set------------------------------
# Read the noise configuration from the YAML file named in options.config and
# convert it to a NoneDict, which returns None for missing keys.
yml_path = c.noise_opt_yml_path
option_yml = parse_yml(yml_path)
noise_opt = dict_to_nonedict(option_yml)

# -------------------MODEL load---------------------------
model_path = c.load_model_path
net = Net(noise_opt, device).to(device)
load(model_path, net)

# -----------------MODEL input-----------------------
# Dummy inputs with random values. Presumably only their shapes matter for the
# count (assumes Net has no data-dependent branching; TODO confirm).
cover = torch.randn(1, 3, c.cropsize_val, c.cropsize_val).to(device)
# (batch, input_message_length) float32 tensor whose entries are -0.5 or 0.5.
secret = torch.from_numpy(
    np.random.choice([-0.5, 0.5], (cover.shape[0], c.input_message_length))
).float().to(device)
dct_trans = DCT_transform(image_size=c.cropsize_val, block_size=8).to(device)
# The DCT runs outside ``net``, so its cost is NOT included in the numbers
# below. It is computed without autograd because no gradients are needed here.
with torch.no_grad():
    cover_dct = dct_trans(cover)

# ------------cal: MACs, FLOPs, params-----------
# thop's ``profile`` counts multiply-accumulate operations (MACs), not FLOPs.
# The usual convention is FLOPs ~= 2 * MACs, so both are reported.
# thop may also skip parameters of modules it has no counting rule for (it
# warns about them), so the total over ``net.parameters()`` is printed as a
# cross-check of thop's parameter figure.
macs, params = profile(net, inputs=(cover, secret, cover_dct))
flops = 2 * macs
total_params = sum(p.numel() for p in net.parameters())
print(f'\nMACs: {macs}, FLOPs (~2*MACs): {flops}, params: {params}, all parameters: {total_params}\n')
print('the MACs is {}G, the FLOPs is ~{}G, the params is {}M (all parameters: {}M)\n'.format(
    round(macs / (10 ** 9), 2),
    round(flops / (10 ** 9), 2),
    round(params / (10 ** 6), 2),
    round(total_params / (10 ** 6), 2),
))