当前位置: 首页 > article >正文

upsample nearest 临近上采样实现方式

nearest 最近邻居像素值来

实现 原理

最近邻上采样的原理非常简单,它直接取输入图像中最近的像素值作为输出图像中对应位置的像素值。这种方法在放大图像时保持了像素的原始值,不涉及像素值之间的平滑过渡,因此可能会产生明显的锯齿状边缘。

优点

计算量小:最近邻插值法通过直接选取离目标点最近的点的值作为新的插入点的值,计算过程相对简单,计算量小。
速度快:
算法简单:易于理解和实现。

缺点

图像质量差:插值后的图像在边缘处容易出现明显的锯齿现象,图像质量差。
灰度值不连续:由于插值点直接取最近点的值,因此插值后的图像在灰度值上能会出现不连续的情况,导致图像不够自然。
放大效果有限:当需要大幅度放大图像时,简单的复制最近点的值无法有效地恢复图像中的高频信息。
如下面11010的张量二倍nearest 上采样后生成12020的结果

请添加图片描述
代码示例

import torch  
import matplotlib.pyplot as plt
import seaborn as sns

def custom_nearest_interpolate(input_tensor, scale_factor):  
    """  
    对输入张量进行最近邻上采样。  
      
    参数:  
    - input_tensor: 输入张量,形状为 [batch_size, channels, height, width]  
    - scale_factor: 缩放因子,一个整数或整数元组,表示在每个维度上放大多少倍  
      
    返回:  
    - output_tensor: 上采样后的张量  
    """  
    batch_size, channels, height, width = input_tensor.size()  
      
    # 确保scale_factor是整数元组,且长度为2(针对height和width)  
    if isinstance(scale_factor, int):  
        scale_factor = (scale_factor, scale_factor)  
      
    new_height, new_width = int(height * scale_factor[0]), int(width * scale_factor[1])  
      
    # 创建一个空的输出张量,并用0填充(这里也可以用其他值填充,但最近邻插值通常不需要)  
    # 注意:我们实际上不需要用0填充,因为我们会直接从上采样中复制值  
    # 但为了与PyTorch的行为一致(虽然它内部不这样做),我们还是创建一个形状正确的张量  
    output_tensor = torch.zeros(batch_size, channels, new_height, new_width, dtype=input_tensor.dtype, device=input_tensor.device)  
      
    # 执行最近邻上采样  
    for b in range(batch_size):  
        for c in range(channels):  
            for y in range(new_height):  
                for x in range(new_width):  
                    # 计算原始图像中的对应位置(向下取整)  
                    orig_y = y // scale_factor[0]  
                    orig_x = x // scale_factor[1]  
                      
                    # 注意边界情况,这里简单处理为边缘复制(也可以选择其他方式,如填充)  
                    orig_y = min(orig_y, height - 1)  
                    orig_x = min(orig_x, width - 1)  
                      
                    # 复制值  
                    output_tensor[b, c, y, x] = input_tensor[b, c, orig_y, orig_x]  
      
    return output_tensor  
  
# 示例使用  
input_tensor = torch.randn(1, 1, 10, 10)



# 提取张量的最后两个维度
tensor_2d = input_tensor[0, 0].numpy()  # 转换为 NumPy 数组以便绘图

# 使用 seaborn 的 heatmap 函数绘制热力图
sns.set()  # 设置 seaborn 的默认样式
plt.figure(figsize=(10, 8))  # 设置图形大小
sns.heatmap(tensor_2d, annot=True, fmt=".2f", cmap='coolwarm', cbar=True,
            xticklabels=False, yticklabels=False, square=True)

plt.title('Tensor Shape Visualization (20x20)')
plt.xlabel('Width')
plt.ylabel('Height')

plt.show()




output_tensor = custom_nearest_interpolate(input_tensor, scale_factor=2)  
print(output_tensor.shape)  # 应该输出 torch.Size([1, 1, 20, 20])

# 获取张量的形状
shape = output_tensor.shape

# 创建一个条形图来表示张量的每个维度大小
plt.bar(range(len(shape)), shape)
plt.xlabel('Dimension')
plt.ylabel('Size')
plt.title('Tensor Shape')
plt.show()
# 提取张量的最后两个维度
tensor_2d = output_tensor[0, 0]

# 绘制点阵
plt.imshow(tensor_2d, cmap='gray', interpolation='nearest')
plt.colorbar()  # 显示颜色条
plt.title('Tensor Shape Visualization (20x20)')
plt.xlabel('Width')
plt.ylabel('Height')
plt.xticks([])  # 隐藏坐标轴刻度
plt.yticks([])
plt.show()

# 提取张量的最后两个维度
tensor_2d = output_tensor[0, 0].numpy()  # 转换为 NumPy 数组以便绘图

# 使用 seaborn 的 heatmap 函数绘制热力图
sns.set()  # 设置 seaborn 的默认样式
plt.figure(figsize=(15, 12))  # 设置图形大小
sns.heatmap(tensor_2d, annot=True, fmt=".2f", cmap='coolwarm', cbar=True,
            xticklabels=False, yticklabels=False, square=True)

plt.title('Tensor Shape Visualization (20x20)')
plt.xlabel('Width')
plt.ylabel('Height')

plt.show()

http://www.kler.cn/a/325624.html

相关文章:

  • 架构师:使用 Atomix 实现分布式协调服务的技术指南
  • Excel如何把两列数据合并成一列,4种方法
  • Java 核心技术卷 I 学习记录八
  • AJAX笔记 (速通精华版)
  • Jaskson处理复杂的泛型对象
  • day-17 反转字符串中的单词
  • Python: RAII:函数执行完毕,socket对象主动发送fin
  • golang Get: context deadline exceeded (Client.Timeout exceeded while aw
  • 第四届机器人、自动化与智能控制国际会议(ICRAIC 2024)征稿
  • Python 学习之生成图形验证码
  • 谷神后端$vs.proc.invoke.stock.loadMap
  • golang语法基础
  • 【大数据应用开发】2023年全国职业院校技能大赛赛题第01套
  • 基于php的助农生鲜销售系统
  • vmware 操作系统安装
  • 常见框架漏洞复现
  • IT运维挑战与对策:构建高效一体化运维管理体系
  • Chapter 2 - 1. Understanding Congestion in Fibre Channel Fabrics
  • Redis: RDB与AOF的选择和容灾备份以及Redis数据持久化的优化方案
  • X86架构(九)——保护模式的进入
  • Hive数仓操作(三)
  • 使用Vue.extend( ) 模仿 elementui 创建一个类似 message 消息提示框
  • AI大模型之旅-最强开源文生图工具Stable Diffusion WebUI 教程
  • Safari 浏览器中的 <audio> 标签的控件无效 - 解决方法
  • linux信号 | 学习信号三步走 | 全解析信号的产生方式
  • 数据结构双链表和循环链表