nn.Upsample
nn.Upsample
是 PyTorch 中的一个模块,用于对张量进行上采样(增加空间分辨率)。它常用于图像生成、分割等需要调整张量尺寸的任务中。模块支持多种插值模式来改变分辨率。
基本用法
torch.nn.Upsample(size=None, scale_factor=None, mode='nearest', align_corners=None)
参数说明
-
size
(tuple[int] 或 int, 可选):
目标输出尺寸,指定输出张量的高度和宽度。例如,size=(10, 10)
将输出张量调整为 10x10 的大小。如果指定了size
,则会忽略scale_factor
。 -
scale_factor
(float 或 tuple[float], 可选):
空间维度的倍率。例如,scale_factor=2
表示将输入张量的高度和宽度扩大两倍。 -
mode
(str, 默认值='nearest'
):
插值模式,可选值包括:'nearest'
: 最近邻插值。'linear'
: 一维线性插值(用于 1D 输入)。'bilinear'
: 双线性插值(用于 2D 输入)。'bicubic'
: 双三次插值(用于 2D 输入)。'trilinear'
: 三线性插值(用于 3D 输入)。'area'
: 面积插值(仅用于下采样)。
-
align_corners
(bool, 可选):- 在
'linear'
、'bilinear'
、'bicubic'
和'trilinear'
模式下有效。 - 如果为
True
,输入和输出张量的角像素会严格对齐。 - 如果为
False
(默认值),插值以像素的中心为基准,通常能带来更准确的结果。
- 在
示例用法
使用 scale_factor
import torch
import torch.nn as nn
# 输入张量:[batch, channels, height, width]
input_tensor = torch.randn(1, 3, 4, 4)
# 上采样倍率为 2
upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
output_tensor = upsample(input_tensor)
print(f"输入尺寸: {input_tensor.shape}") # [1, 3, 4, 4]
print(f"输出尺寸: {output_tensor.shape}") # [1, 3, 8, 8]
使用 size
# 调整到固定大小
upsample = nn.Upsample(size=(10, 10), mode='nearest')
output_tensor = upsample(input_tensor)
print(f"输出尺寸: {output_tensor.shape}") # [1, 3, 10, 10]
注意事项
-
Upsample
是 PyTorch 中F.interpolate
函数的简单封装。 -
虽然
Upsample
使用方便,但推荐直接使用F.interpolate
,因为它提供更多的灵活性。例如:
import torch.nn.functional as F
output_tensor = F.interpolate(input_tensor, scale_factor=2, mode='bilinear', align_corners=True)
- 上采样可能会引入伪影。选择合适的
mode
,必要时可进行额外的后处理来减轻伪影影响。