YK Artificial Intelligence (5): A Long-Form Guide to Fine-Tuning PyTorch Models
1. Model Fine-Tuning - timm
Besides the pretrained models in torchvision.models, there is another popular pretrained-model library called timm, created by Ross Wightman. It offers many SOTA computer-vision models and can be seen as an extended version of torchvision, with models that generally reach higher accuracy. In this chapter we focus on using the pretrained models from this library; for the other parts (data augmentation, optimizers, etc.), interested readers can refer to the links below.
- GitHub: https://github.com/rwightman/pytorch-image-models
- Official docs: https://fastai.github.io/timmdocs/
- Official docs: https://rwightman.github.io/pytorch-image-models/
1.1 Installing timm
timm can be installed in either of two ways:
- Via pip
pip install timm
- From source
git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .
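After installation, a quick import serves as a sanity check; a minimal sketch (the exact version string printed will of course depend on your environment):
import timm
# If this prints a version string, the install succeeded
print(timm.__version__)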
1.2 Browsing the Available Pretrained Models
- List the pretrained models provided by timm
import timm
avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models)
1493
- List all variants of a particular model family
Each model family may contain several variants; the ResNet family, for example, includes ResNet18, ResNet50, ResNet101, and more. We can pass the model name we want to query (with wildcards, i.e. fuzzy matching) to timm.list_models(); for example, to query all models in the densenet family:
all_densenet_models = timm.list_models("*densenet*")
all_densenet_models
['densenet121',
'densenet161',
'densenet169',
'densenet201',
'densenet264d',
'densenetblur121d']
As we can see, all models in the densenet family are returned as a list.
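The wildcard query can also be combined with the pretrained filter, so that only variants shipping pretrained weights are returned; a small sketch (the exact list depends on your timm version):
# Only densenet variants that come with pretrained weights
timm.list_models("*densenet*", pretrained=True)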
- Inspect a model's specific configuration
When we want to look at a model's specific configuration, we can access its default_cfg attribute, as follows:
import timm
model = timm.create_model('resnet34', pretrained=True)
model.default_cfg
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth',
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.875,
'interpolation': 'bilinear',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'first_conv': 'conv1',
'classifier': 'fc',
'architecture': 'resnet34'}
In addition, we can follow this link to check the accuracy and other details of the provided pretrained models.
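The default_cfg fields (input size, mean/std, interpolation, crop percentage) are exactly what timm's data helpers consume to build a matching preprocessing pipeline. A minimal sketch using timm.data (available in recent timm releases):
import timm
from timm.data import resolve_data_config, create_transform

model = timm.create_model('resnet34', pretrained=True)
# Derive eval-time preprocessing from the model's pretraining configuration
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
print(transform)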
1.3 Using and Modifying Pretrained Models
Once we have decided which pretrained model to use, we can create it with timm.create_model(), passing pretrained=True to load the pretrained weights. We can then inspect the model's parameters and structure in the same way as with torchvision models.
import timm
import torch
model = timm.create_model('resnet34',pretrained=True)
x = torch.randn(1,3,224,224)
output = model(x)
output.shape
torch.Size([1, 1000])
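Beyond classification, a timm model can also be used directly as a feature extractor. A small sketch: passing num_classes=0 strips the classification head, so the forward pass returns the pooled feature vector instead of logits.
import timm
import torch

feature_model = timm.create_model('resnet34', pretrained=True, num_classes=0)
x = torch.randn(1, 3, 224, 224)
features = feature_model(x)
print(features.shape)  # expected: torch.Size([1, 512]) for resnet34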
- Inspect the parameters of a particular layer (the first conv layer, for example)
model = timm.create_model('resnet34',pretrained=True)
list(dict(model.named_children())['conv1'].parameters())
[Parameter containing:
tensor([[[[ 1.3241e-02,  6.6559e-03,  5.7815e-02,  ..., -3.8790e-02,
            1.9017e-02,  9.5395e-03],
          ...,
          [-1.6616e-02, -1.9411e-02,  1.8363e-02,  ..., -3.4171e-02,
            4.5334e-02,  6.7183e-03]],
         ...,
          [ 1.1795e-01,  1.4593e-01,  1.3136e-01,  ...,  1.7697e-01,
            1.2030e-01,  8.8541e-02]]]], requires_grad=True)]
(output truncated: the full printout lists all 64 x 3 x 7 x 7 weights of the first convolution)
- Modify the model (change the 1000-class output to 10 classes)
model = timm.create_model('resnet34',num_classes=10,pretrained=True)
x = torch.randn(1,3,224,224)
output = model(x)
output.shape
torch.Size([1, 10])
- Change the number of input channels (for example, our images are single-channel, but the model expects three-channel input)
We can do this by passing in_chans=1:
model = timm.create_model('resnet34',num_classes=10,pretrained=True,in_chans=1)
x = torch.randn(1,1,224,224)
output = model(x)
print(output)
tensor([[ 0.1117, -0.0241, 0.1028, -0.0363, -0.0025, 0.1028, -0.2013, 0.1746,
-0.1014, -0.1916]], grad_fn=<AddmmBackward0>)
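Under the hood, timm adapts the pretrained first-layer weights to the new channel count (for a single channel, by aggregating the RGB kernels) rather than discarding them; we can at least verify the resulting shape with a small sketch:
model = timm.create_model('resnet34', num_classes=10, pretrained=True, in_chans=1)
# The first convolution now accepts 1 input channel instead of 3
print(model.conv1.weight.shape)  # expected: torch.Size([64, 1, 7, 7])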
1.4 Saving the Model
Models created by timm are subclasses of torch.nn.Module, so we can save and load their parameters directly with PyTorch's built-in methods, as shown below. Note that torch.save does not create missing directories (otherwise it raises RuntimeError: Parent directory ./checkpoint does not exist), so we create the target directory first.
import os
os.makedirs('./checkpoint', exist_ok=True)
torch.save(model.state_dict(), './checkpoint/timm_model.pth')
model.load_state_dict(torch.load('./checkpoint/timm_model.pth'))
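When loading the weights later (in a fresh process, say), remember to re-create the model with the same configuration first; a minimal sketch, reusing the 10-class, single-channel setup from above:
import timm
import torch

# Re-create the architecture exactly as it was configured for training
model = timm.create_model('resnet34', num_classes=10, in_chans=1, pretrained=False)
model.load_state_dict(torch.load('./checkpoint/timm_model.pth'))
model.eval()  # switch to inference mode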
2. Model Fine-Tuning - torchvision
As deep learning develops, models keep growing, and many open-source models are trained on large datasets such as ImageNet-1k, ImageNet-11k, or even ImageNet-21k. In practice, however, our own dataset may contain only a few thousand images, and training a large neural network with tens of millions of parameters from scratch on it is unrealistic: the larger the model, the more data it demands, and overfitting becomes unavoidable.
Suppose we want to recognize different kinds of chairs in images and then recommend purchase links to users. One possible approach is to first pick 100 common kinds of chairs, take 1,000 photos of each kind from different angles, and then train a classification model on the collected images. Although this chair dataset may be larger than Fashion-MNIST, it still has fewer than one tenth as many samples as ImageNet. A complex model suited to ImageNet is therefore likely to overfit the chair dataset, and with so little data, the trained model's accuracy may not reach a practical level.
One obvious remedy for the problem above is to collect more data. However, collecting and labeling data costs a great deal of time and money. For example, assembling the ImageNet dataset cost researchers millions of dollars in funding. Although data collection has become much cheaper, the cost is still far from negligible.
Another solution is to apply transfer learning: transferring knowledge learned from a source dataset to a target dataset. For example, although most images in ImageNet have nothing to do with chairs, a model trained on it can extract fairly general image features, which help identify edges, textures, shapes, and object compositions. These features may be just as effective for recognizing chairs.
One major application of transfer learning is model fine-tuning (finetune). Put simply, we find a model that someone else has already trained on a similar task, swap in our own data, and adjust its parameters through training. PyTorch provides many pretrained network models (VGG, the ResNet family, the MobileNet family, ...), all trained by the PyTorch team on the corresponding large datasets. Learning how to fine-tune lets us quickly adapt pretrained models to our own tasks.
After working through this section, you will:
- Understand the model fine-tuning workflow
- Know the common models provided by PyTorch
- Know how to train only specific layers of a model
2.1 The Model Fine-Tuning Workflow
- Pretrain a neural network on a source dataset (such as ImageNet); this is the source model.
- Create a new neural network, the target model. It copies the entire model design of the source model, along with its parameters, except for the output layer. We assume these parameters contain knowledge learned from the source dataset that applies equally well to the target dataset, and that the source model's output layer is closely tied to the source labels, so it is not carried over.
- Add to the target model an output layer whose size equals the number of classes in the target dataset, and randomly initialize its parameters.
- Train the target model on the target dataset. The output layer is trained from scratch, while the parameters of all other layers are fine-tuned from the source model's parameters. A code sketch of these four steps follows right after this list.
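A minimal sketch of these four steps with torchvision, using resnet18 as the source model and a hypothetical 4-class target task (train_loader stands in for your own data loader):
import torch
import torch.nn as nn
import torchvision.models as models

# Steps 1-2: an ImageNet-pretrained source model becomes the target model's body
model = models.resnet18(pretrained=True)

# Step 3: replace the output layer with a randomly initialized 4-class head
model.fc = nn.Linear(model.fc.in_features, 4)

# Step 4: train on the target dataset (train_loader is assumed, not defined here)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
# for images, labels in train_loader:
#     optimizer.zero_grad()
#     loss = criterion(model(images), labels)
#     loss.backward()
#     optimizer.step()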
2.2 Using Existing Model Architectures
Here we take the common models in torchvision as an example and show how to use the model architectures and parameters that PyTorch provides for image classification. Other tasks and network architectures are used in a similar way:
- Instantiating the network
Notes:
- PyTorch model weight files usually have the extension .pt or .pth. At run time, the program first checks the default path for weights that have already been downloaded; once the weights are downloaded, subsequent loads skip the download.
- Downloading pretrained weights is usually rather slow. We can look up the model's model_urls here, then download the file manually (with Thunder or any other download tool). On Linux and Mac the default download path for pretrained weights is the .cache folder under the user's home directory; on Windows it is C:\Users\<username>\.cache\torch\hub\checkpoint. We can also use torch.utils.model_zoo.load_url() to set where the weights are downloaded from; see the sketch after this list.
- If that feels cumbersome, we can also download the weights ourselves, place them in the same folder, and then load the parameters into the network:
self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))
- If a download is forcibly interrupted midway, be sure to delete the partial weight file from the corresponding path, otherwise you may run into errors.
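As a sketch of the torch.utils.model_zoo.load_url() route mentioned above (the URL is the resnet50 weight URL quoted earlier; model_dir chooses where the file is cached):
import torch
import torchvision.models as models

url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
# Downloads the file to ./weights on first use and returns the state dict
state_dict = torch.utils.model_zoo.load_url(url, model_dir='./weights')
model = models.resnet50(pretrained=False)
model.load_state_dict(state_dict)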
2.3 Training Specific Layers
By default, every parameter has the attribute requires_grad = True; if we are training from scratch or fine-tuning all layers, we need not pay attention to this. But if we are extracting features and only want to compute gradients for the newly initialized layers, leaving the other parameters unchanged, we need to freeze part of the model by setting requires_grad = False. The official PyTorch tutorial provides such a helper routine (defined below, after the model listings).
import torchvision.models as models
resnet18 = models.resnet18()
# resnet18 = models.resnet18(pretrained=False) is equivalent to the line above
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet_v2 = models.mobilenet_v2()
mobilenet_v3_large = models.mobilenet_v3_large()
mobilenet_v3_small = models.mobilenet_v3_small()
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()
- Passing the pretrained parameter
Whether the pretrained weights are used is decided by passing True or False. The default is pretrained = False, meaning the pretrained weights are not used; with pretrained = True, weights pretrained on certain datasets are loaded.
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
Running the cell above downloads each weight file to the local cache the first time, for example:
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /Users/yaliu/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [01:03<00:00, 8.74MB/s]
(similar download lines and progress bars are printed for the other models)
It also emits a series of warnings: since torchvision 0.13 the pretrained argument is deprecated in favor of weights (e.g. weights=ResNet18_Weights.IMAGENET1K_V1, or weights=ResNet18_Weights.DEFAULT for the most up-to-date weights), and FutureWarnings note that the default weight initialization of inception_v3 and GoogLeNet will change in a future torchvision release (pass init_weights=True to keep the old behavior).
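Following the deprecation warnings above, on torchvision >= 0.13 the same instantiation is better written with a weights enum; a small sketch:
import torchvision.models as models
from torchvision.models import ResNet18_Weights

# Equivalent to pretrained=True, but explicit about which weights are loaded
resnet18 = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
# Or always fetch the most up-to-date weights for this architecture
resnet18 = models.resnet18(weights=ResNet18_Weights.DEFAULT)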
def set_parameter_requires_grad(model, feature_extracting):
    # When extracting features, freeze every parameter of the backbone
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
Below we again take resnet18 as the example and change the 1000-class output to 4 classes, but train only the parameters of the final layer, leaving the feature-extraction parameters untouched. Note the order of operations: we first freeze the gradients of the model's parameters, and only then replace the fully connected output layer, so the parameters of the new layer remain trainable.
import torch.nn as nn
import torchvision.models as models
# Freeze the parameters' gradients
feature_extract = True
model = models.resnet18(pretrained=True)
set_parameter_requires_grad(model, feature_extract)
# Replace the classification head
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=4, bias=True)
During subsequent training, gradients still flow backward through the model, but parameter updates happen only in the fc layer. By setting the requires_grad attribute of parameters, we have achieved the goal of training only specific layers, which is essential for model fine-tuning.
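To make this explicit, we can hand the optimizer only the parameters that still require gradients, following the pattern in the official PyTorch fine-tuning tutorial; a minimal sketch:
import torch

# Collect only the parameters left trainable by the freezing step above
params_to_update = [p for p in model.parameters() if p.requires_grad]
print(f"training {len(params_to_update)} parameter tensors")  # fc.weight and fc.bias

optimizer = torch.optim.SGD(params_to_update, lr=1e-3, momentum=0.9)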
References for this section
- Parameter updates
- Assigning different learning rates to different layers
- https://www.aiuai.cn/aifarm1967.html
- https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055
- https://chowdera.com/2022/03/202203170834122729.html