Fixing the warning "torch.nn.utils.weight_norm is deprecated"
Preface
My test environment:
Ubuntu 20.04
1. Problem Description
When running a PyTorch program, the following warning appears:
/home/wong/ProgramFiles/anaconda3/envs/pytorch_env/lib/python3.8/site-packages/torch/nn/utils/weight_norm.py:30: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
2. Solution
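The warning message says that torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm. In other words, torch.nn.utils.weight_norm is deprecated, and torch.nn.utils.parametrizations.weight_norm should be used instead.
The call therefore needs to be changed to torch.nn.utils.parametrizations.weight_norm. The PyTorch documentation [1] also says that every reference to weight_g in the program must be changed to parametrizations.weight.original0, and every reference to weight_v must be changed to parametrizations.weight.original1. In the new API, original0 holds the magnitude g and original1 holds the direction v.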
For example, the original code was:
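import torch
import torch.nn as nn

# d_in and num_projections are defined elsewhere in the original program
theta = torch.nn.utils.weight_norm(nn.Linear(d_in, num_projections, bias=False), dim=0)
if num_projections <= d_in:
    torch.nn.init.eye_(theta.weight_v)  # direction v
else:
    torch.nn.init.normal_(theta.weight_v)
# Fix the magnitude g to 1 and freeze it
theta.weight_g.data = torch.ones_like(theta.weight_g.data, requires_grad=False)
theta.weight_g.requires_grad = False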
It should be changed to:
import torch
import torch.nn as nn

# d_in and num_projections are defined elsewhere in the original program
theta = torch.nn.utils.parametrizations.weight_norm(nn.Linear(d_in, num_projections, bias=False), dim=0)
if num_projections <= d_in:
    torch.nn.init.eye_(theta.parametrizations.weight.original1)  # direction, formerly weight_v
else:
    torch.nn.init.normal_(theta.parametrizations.weight.original1)
# Magnitude, formerly weight_g: fix it to 1 and freeze it
theta.parametrizations.weight.original0.data = torch.ones_like(theta.parametrizations.weight.original0.data, requires_grad=False)
theta.parametrizations.weight.original0.requires_grad = False
References
[1] pytorch. torch.nn.utils.weight_norm [EB/OL]. https://pytorch.org/docs/2.1/generated/torch.nn.utils.weight_norm.html, 2023-xx-xx/2024-12-18.