1.模型权重保存 torch.save
# Build the network selected on the command line, then persist its weights.
# NOTE(review): the nesting below is reconstructed -- the original paste had
# lost all indentation; confirm against the real training script.
model_name = args.model
if model_name in ("ResNet18", "ResNet34"):
    # Imported lazily so that only the selected architecture's module is loaded.
    from models.ResNet1 import BasicBlock
    from models.ResNet1 import ResNet as PATCHMODEL

    if model_name == "ResNet18":
        # [2, 2, 2, 2] is the per-stage block count passed to the ResNet constructor.
        net = PATCHMODEL(BasicBlock, [2, 2, 2, 2], num_classes=num_classes).cuda()
    # NOTE(review): no "ResNet34" branch is shown in this excerpt; `net` is
    # only defined for "ResNet18" here -- presumably the other branches were
    # elided, verify before reuse.

# Save only the state_dict (parameters/buffers), not the pickled module.
torch.save(net.state_dict(), weights_dir + '/' + model_name + '_train_loss_min_numCls{}.pth'.format(num_classes))
2.模型权重加载 load_state_dict
# Rebuild the same architecture that was used for training, then restore its
# weights from the checkpoint at `model_path`.
# NOTE(review): the nesting below is reconstructed -- the original paste had
# lost all indentation; confirm against the real inference script.
model_name = args.model
if model_name in ("ResNet18", "ResNet34"):
    # Imported lazily so that only the selected architecture's module is loaded.
    from models.ResNet1 import BasicBlock
    from models.ResNet1 import ResNet as PATCHMODEL

    if model_name == "ResNet18":
        # Constructor arguments must match the ones used when the checkpoint
        # was saved, otherwise parameter names/shapes will not line up.
        model = PATCHMODEL(BasicBlock, [2, 2, 2, 2], num_classes=num_classes).cuda()
    # NOTE(review): no "ResNet34" branch is shown in this excerpt; `model` is
    # only defined for "ResNet18" here -- presumably the other branches were
    # elided, verify before reuse.

# strict=False tolerates missing/unexpected keys in the checkpoint instead of
# raising; load_state_dict returns the lists of such keys if they need checking.
model.load_state_dict(torch.load(model_path), strict=False)