当前位置: 首页 > article >正文

使用本地预训练模型可视化卷积每一部分

 

 

整体代码如下:

# 加载图像并转换为数组的形式,调整大小
def figure_trans():
    # 加载图像并转换为数组的形式,调整大小
    img_path = "E:/IRTC/ceshi/img.png"
    image = Image.open(img_path)
    transform = transforms.ToTensor()
    tensor_image = transform(image)

    transform_size = transforms.Resize((512, 512))
    out = transform_size(tensor_image)
    tensor_image_batch = out.unsqueeze(0)

    return tensor_image_batch

# 加载模型并绘制模型
def yml_plot():
    # yaml_file_path = 'F:/code/DEKR-main/experiments/coco/w32/w32_4x_reg03_bs10_512_adam_lr1e-3_coco_x140.yaml'
    # yaml_file = Path(yaml_file_path)
    tensor_image_batch = figure_trans()

    # if yaml_file.exists():
    #     with yaml_file.open('r') as file:
    #         cfg_dict = yaml.safe_load(file)
    #         cfg = OmegaConf.create(cfg_dict)
    out_img3 = pose(pretrained=True)
    # out_img3 = PoseHigherResolutionNet(cfg)

    # input = torch.randn(1, 3, 512, 512)
    out_img3 = out_img3.forward(tensor_image_batch)
    # 选择要可视化的通道索引
    selected_channels = [1, 15, 30, 45, 63]  # 例如,选择第0、15、30、45和63个通道

    # 创建一个包含多个子图的图形
    fig, axes = plt.subplots(1, len(selected_channels), figsize=(15, 5))

    # 遍历选定的通道并绘制它们
    for i, ax in enumerate(axes):
        channel_image = out_img3[selected_channels[i]]  # 提取选定通道的图像
        print(channel_image.shape)
        out_img1 = transforms.ToPILImage()(channel_image)
        ax.imshow(out_img1, cmap='viridis')  # 使用 viridis 颜色图显示图像
        ax.set_title(f'Channel {selected_channels[i]}')  # 设置子图的标题
        ax.axis('off')  # 关闭坐标轴

    # 显示图形
    plt.tight_layout()  # 调整子图参数, 使之填充整个图像区域
    plt.show()
    # print(pose.forward(input))

    # else:
    #     print(f"Error: The YAML file {yaml_file_path} does not exist.")

model_urls = {
    'pose': 'F:/code/DEKR-main/model/pose_coco/pose_dekr_hrnetw32_coco.pth',
}
# 加载本地预训练模型并返回模型
def pose(pretrained=False):
    yaml_file_path = 'F:/code/DEKR-main/experiments/coco/w32/w32_4x_reg03_bs10_512_adam_lr1e-3_coco_x140.yaml'
    yaml_file = Path(yaml_file_path)
    if yaml_file.exists():
        with yaml_file.open('r') as file:
            cfg_dict = yaml.safe_load(file)
            cfg = OmegaConf.create(cfg_dict)
            model = PoseHigherResolutionNet(cfg)
    if pretrained:
        # 加载本地预训练模型
        model_path = model_urls['pose']
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))  # 如果你在GPU上运行,可能不需要map_location
    return model


plt_figure = yml_plot()

这里新加的就是加载本地训练好的模型文件pth

model_urls = {
    'pose': 'F:/code/DEKR-main/model/pose_coco/pose_dekr_hrnetw32_coco.pth',
}

并加载进来

model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))  # 如果你在GPU上运行,可能不需要map_location


http://www.kler.cn/a/292503.html

相关文章:

  • 容器技术在持续集成与持续交付中的应用
  • 【论文阅读】Virtual Compiler Is All You Need For Assembly Code Search
  • 【vue3中el-table表格高度自适应】
  • 成都睿明智科技有限公司解锁抖音电商新玩法
  • Java面向对象高级2
  • androidstudio下载gradle慢
  • 基于自适应多变量超扭转的 Lyapunov 重新设计 RLV 姿态控制
  • Facebook与区块链:构建更安全的社交网络生态
  • MFC dll无法显示tooltip问题
  • Java-数据结构-链表-习题(三)(๑´ㅂ`๑)
  • java开发简历详解
  • Dubbo缓存
  • HTML 基础知识详解与代码示例
  • C++笔记16•数据结构:关联式容器map和set•
  • Java健康养老智慧相伴养老护理小程序系统源码代办陪诊陪护更安心
  • 阿里云身份证二要素详细使用
  • 第T2周:彩色图片分类
  • 828华为云征文|基于华为云Flexus云服务器X搭建jumpserver堡垒机软件
  • 电子病历、开药发药、住院检查、会员管理,SaaS模式B/S架构的云医院管理系统源码,云计算技术的医疗信息化解决方案
  • 经验笔记:Feeds流设计与实现
  • SpringMVC 第一次复学笔记
  • 解决 EasyExcel BigDecimal 加%的问题
  • 请解释Java Web中的Filter的作用和使用场景。什么是Java Web中的JSP?请解释其与Servlet的关系及各自优势。
  • OPC DA
  • 2024数学建模国赛高教社杯C题:农作物的种植策略 思路代码文章助攻手把手保姆级
  • 编程秘密武器:提升工作效率的关键工具