upsample nearest 临近上采样实现方式
nearest 最近邻居像素值来
实现 原理
最近邻上采样的原理非常简单,它直接取输入图像中最近的像素值作为输出图像中对应位置的像素值。这种方法在放大图像时保持了像素的原始值,不涉及像素值之间的平滑过渡,因此可能会产生明显的锯齿状边缘。
优点
计算量小:最近邻插值法通过直接选取离目标点最近的点的值作为新的插入点的值,计算过程相对简单,计算量小。
速度快:
算法简单:易于理解和实现。
缺点
图像质量差:插值后的图像在边缘处容易出现明显的锯齿现象,图像质量差。
灰度值不连续:由于插值点直接取最近点的值,因此插值后的图像在灰度值上能会出现不连续的情况,导致图像不够自然。
放大效果有限:当需要大幅度放大图像时,简单的复制最近点的值无法有效地恢复图像中的高频信息。
如下面11010的张量二倍nearest 上采样后生成12020的结果
代码示例
import torch
import matplotlib.pyplot as plt
import seaborn as sns
def custom_nearest_interpolate(input_tensor, scale_factor):
"""
对输入张量进行最近邻上采样。
参数:
- input_tensor: 输入张量,形状为 [batch_size, channels, height, width]
- scale_factor: 缩放因子,一个整数或整数元组,表示在每个维度上放大多少倍
返回:
- output_tensor: 上采样后的张量
"""
batch_size, channels, height, width = input_tensor.size()
# 确保scale_factor是整数元组,且长度为2(针对height和width)
if isinstance(scale_factor, int):
scale_factor = (scale_factor, scale_factor)
new_height, new_width = int(height * scale_factor[0]), int(width * scale_factor[1])
# 创建一个空的输出张量,并用0填充(这里也可以用其他值填充,但最近邻插值通常不需要)
# 注意:我们实际上不需要用0填充,因为我们会直接从上采样中复制值
# 但为了与PyTorch的行为一致(虽然它内部不这样做),我们还是创建一个形状正确的张量
output_tensor = torch.zeros(batch_size, channels, new_height, new_width, dtype=input_tensor.dtype, device=input_tensor.device)
# 执行最近邻上采样
for b in range(batch_size):
for c in range(channels):
for y in range(new_height):
for x in range(new_width):
# 计算原始图像中的对应位置(向下取整)
orig_y = y // scale_factor[0]
orig_x = x // scale_factor[1]
# 注意边界情况,这里简单处理为边缘复制(也可以选择其他方式,如填充)
orig_y = min(orig_y, height - 1)
orig_x = min(orig_x, width - 1)
# 复制值
output_tensor[b, c, y, x] = input_tensor[b, c, orig_y, orig_x]
return output_tensor
# 示例使用
input_tensor = torch.randn(1, 1, 10, 10)
# 提取张量的最后两个维度
tensor_2d = input_tensor[0, 0].numpy() # 转换为 NumPy 数组以便绘图
# 使用 seaborn 的 heatmap 函数绘制热力图
sns.set() # 设置 seaborn 的默认样式
plt.figure(figsize=(10, 8)) # 设置图形大小
sns.heatmap(tensor_2d, annot=True, fmt=".2f", cmap='coolwarm', cbar=True,
xticklabels=False, yticklabels=False, square=True)
plt.title('Tensor Shape Visualization (20x20)')
plt.xlabel('Width')
plt.ylabel('Height')
plt.show()
output_tensor = custom_nearest_interpolate(input_tensor, scale_factor=2)
print(output_tensor.shape) # 应该输出 torch.Size([1, 1, 20, 20])
# 获取张量的形状
shape = output_tensor.shape
# 创建一个条形图来表示张量的每个维度大小
plt.bar(range(len(shape)), shape)
plt.xlabel('Dimension')
plt.ylabel('Size')
plt.title('Tensor Shape')
plt.show()
# 提取张量的最后两个维度
tensor_2d = output_tensor[0, 0]
# 绘制点阵
plt.imshow(tensor_2d, cmap='gray', interpolation='nearest')
plt.colorbar() # 显示颜色条
plt.title('Tensor Shape Visualization (20x20)')
plt.xlabel('Width')
plt.ylabel('Height')
plt.xticks([]) # 隐藏坐标轴刻度
plt.yticks([])
plt.show()
# 提取张量的最后两个维度
tensor_2d = output_tensor[0, 0].numpy() # 转换为 NumPy 数组以便绘图
# 使用 seaborn 的 heatmap 函数绘制热力图
sns.set() # 设置 seaborn 的默认样式
plt.figure(figsize=(15, 12)) # 设置图形大小
sns.heatmap(tensor_2d, annot=True, fmt=".2f", cmap='coolwarm', cbar=True,
xticklabels=False, yticklabels=False, square=True)
plt.title('Tensor Shape Visualization (20x20)')
plt.xlabel('Width')
plt.ylabel('Height')
plt.show()