使用本地预训练模型可视化卷积每一部分
整体代码如下:
# 加载图像并转换为数组的形式,调整大小
def figure_trans():
# 加载图像并转换为数组的形式,调整大小
img_path = "E:/IRTC/ceshi/img.png"
image = Image.open(img_path)
transform = transforms.ToTensor()
tensor_image = transform(image)
transform_size = transforms.Resize((512, 512))
out = transform_size(tensor_image)
tensor_image_batch = out.unsqueeze(0)
return tensor_image_batch
# 加载模型并绘制模型
def yml_plot():
# yaml_file_path = 'F:/code/DEKR-main/experiments/coco/w32/w32_4x_reg03_bs10_512_adam_lr1e-3_coco_x140.yaml'
# yaml_file = Path(yaml_file_path)
tensor_image_batch = figure_trans()
# if yaml_file.exists():
# with yaml_file.open('r') as file:
# cfg_dict = yaml.safe_load(file)
# cfg = OmegaConf.create(cfg_dict)
out_img3 = pose(pretrained=True)
# out_img3 = PoseHigherResolutionNet(cfg)
# input = torch.randn(1, 3, 512, 512)
out_img3 = out_img3.forward(tensor_image_batch)
# 选择要可视化的通道索引
selected_channels = [1, 15, 30, 45, 63] # 例如,选择第0、15、30、45和63个通道
# 创建一个包含多个子图的图形
fig, axes = plt.subplots(1, len(selected_channels), figsize=(15, 5))
# 遍历选定的通道并绘制它们
for i, ax in enumerate(axes):
channel_image = out_img3[selected_channels[i]] # 提取选定通道的图像
print(channel_image.shape)
out_img1 = transforms.ToPILImage()(channel_image)
ax.imshow(out_img1, cmap='viridis') # 使用 viridis 颜色图显示图像
ax.set_title(f'Channel {selected_channels[i]}') # 设置子图的标题
ax.axis('off') # 关闭坐标轴
# 显示图形
plt.tight_layout() # 调整子图参数, 使之填充整个图像区域
plt.show()
# print(pose.forward(input))
# else:
# print(f"Error: The YAML file {yaml_file_path} does not exist.")
model_urls = {
'pose': 'F:/code/DEKR-main/model/pose_coco/pose_dekr_hrnetw32_coco.pth',
}
# 加载本地预训练模型并返回模型
def pose(pretrained=False):
yaml_file_path = 'F:/code/DEKR-main/experiments/coco/w32/w32_4x_reg03_bs10_512_adam_lr1e-3_coco_x140.yaml'
yaml_file = Path(yaml_file_path)
if yaml_file.exists():
with yaml_file.open('r') as file:
cfg_dict = yaml.safe_load(file)
cfg = OmegaConf.create(cfg_dict)
model = PoseHigherResolutionNet(cfg)
if pretrained:
# 加载本地预训练模型
model_path = model_urls['pose']
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # 如果你在GPU上运行,可能不需要map_location
return model
plt_figure = yml_plot()
这里新加的就是加载本地训练好的模型文件pth
model_urls = {
'pose': 'F:/code/DEKR-main/model/pose_coco/pose_dekr_hrnetw32_coco.pth',
}
并加载进来
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # 如果你在GPU上运行,可能不需要map_location