当前位置: 首页 > article >正文

医学图像分割任务的测试代码

测试集进行测试

import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from sklearn.metrics import (
    precision_score,
    recall_score,
    f1_score,
    roc_curve,
    auc,
    confusion_matrix,
)
import matplotlib.pyplot as plt
from utils import NiiDataset
from model.UNet import UNet

# 加载最佳模型
best_unet_model = r"D:\PytnonProject\Segment\best_unet_model.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=1, out_channels=1).to(device)
model.load_state_dict(torch.load(best_unet_model))
model.eval()  # 设置为评估模式

# 定义测试数据集
test_image_paths = [
    r"D:\Data\DegmentData\OriginalNii\DCE\CAO_ZHAN_GUO.nii",  # 替换为实际的测试图像路径
    r"D:\Data\DegmentData\OriginalNii\DCE\CHAI_GUI_LAN.nii",
    # 添加其他测试数据路径...
]

test_mask_paths = [
    r"D:\Data\DegmentData\ROI\CAO_ZHAN_GUO-label.nii",  # 替换为实际的测试掩码路径
    r"D:\Data\DegmentData\ROI\CHAI_GUI_LAN-label.nii",
    # 添加其他测试掩码路径...
]

# 创建测试数据集和数据加载器
test_dataset = NiiDataset(test_image_paths, test_mask_paths)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)  # 批量大小为 1

# 定义评估指标
def dice_coefficient(y_true, y_pred):
    intersection = np.sum(y_true * y_pred)
    denominator = np.sum(y_true) + np.sum(y_pred)
    if denominator == 0:
        return 1.0  # 如果分母为零,返回 1(表示完全匹配)
    return (2.0 * intersection) / denominator

def iou(y_true, y_pred):
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    if union == 0:
        return 1.0  # 如果分母为零,返回 1(表示完全匹配)
    return intersection / union

# 初始化指标
dice_scores = []
iou_scores = []
precisions = []
recalls = []
f1_scores = []
sensitivities = []
specificities = []
auc_scores = []

# 用于 ROC 曲线的数据
all_masks = []
all_predictions = []
all_probabilities = []

# 测试过程
with torch.no_grad():  # 禁用梯度计算
    for images, masks in test_dataloader:
        images = images.to(device)  # 形状: [1, 1, 480, 480]
        masks = masks.to(device)    # 形状: [1, 1, 480, 480]

        # 前向传播
        outputs = model(images)  # 形状: [1, 1, 480, 480]

        # 将输出转换为概率和二进制掩码
        probabilities = outputs.cpu().numpy().flatten()  # 形状: [480 * 480]
        predictions = (outputs > 0.5).float().cpu().numpy().flatten()  # 形状: [480 * 480]
        masks = masks.cpu().numpy().flatten()  # 形状: [480 * 480]

        # 保存用于 ROC 曲线的数据
        all_masks.extend(masks)
        all_predictions.extend(predictions)
        all_probabilities.extend(probabilities)

        # 检查 masks 和 predictions 是否只包含一个类别
        if np.all(masks == 0) and np.all(predictions == 0):
            # 如果 masks 和 predictions 都为全 0,则跳过该样本
            continue

        # 计算指标
        dice = dice_coefficient(masks, predictions)
        iou_score = iou(masks, predictions)
        precision = precision_score(masks, predictions, zero_division=0)
        recall = recall_score(masks, predictions, zero_division=0)
        f1 = f1_score(masks, predictions, zero_division=0)

        # 计算混淆矩阵
        cm = confusion_matrix(masks, predictions)
        if cm.size == 1:
            # 如果混淆矩阵只有一个值(全 0 或全 1)
            if np.all(masks == 0):
                tn, fp, fn, tp = cm[0, 0], 0, 0, 0
            else:
                tn, fp, fn, tp = 0, 0, 0, cm[0, 0]
        else:
            tn, fp, fn, tp = cm.ravel()

        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0  # 灵敏度
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0  # 特异度

        # 保存指标
        dice_scores.append(dice)
        iou_scores.append(iou_score)
        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)
        sensitivities.append(sensitivity)
        specificities.append(specificity)

# 计算 AUC
fpr, tpr, thresholds = roc_curve(all_masks, all_probabilities)
roc_auc = auc(fpr, tpr)
auc_scores.append(roc_auc)

# 打印平均指标
print(f"Average Dice Coefficient: {np.mean(dice_scores):.4f}")
print(f"Average IoU: {np.mean(iou_scores):.4f}")
print(f"Average Precision: {np.mean(precisions):.4f}")
print(f"Average Recall: {np.mean(recalls):.4f}")
print(f"Average F1 Score: {np.mean(f1_scores):.4f}")
print(f"Average Sensitivity: {np.mean(sensitivities):.4f}")
print(f"Average Specificity: {np.mean(specificities):.4f}")
print(f"Average AUC: {np.mean(auc_scores):.4f}")

# 创建 checkpoint 文件夹(如果不存在)
checkpoint_dir = "checkpoint"
os.makedirs(checkpoint_dir, exist_ok=True)

# 保存 ROC 曲线的数据到 checkpoint 文件夹下的 .npz 文件
roc_data_path = os.path.join(checkpoint_dir, "roc_data.npz")
np.savez(roc_data_path, fpr=fpr, tpr=tpr, thresholds=thresholds, roc_auc=roc_auc)
print(f"ROC 曲线的数据已保存到文件: {roc_data_path}")

# 绘制 ROC 曲线
plt.figure()
plt.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic (ROC) Curve")
plt.legend(loc="lower right")
plt.show()


http://www.kler.cn/a/529455.html

相关文章:

  • C++ 中的类(class)和对象(object)
  • nginx目录结构和配置文件
  • 创新创业计划书|建筑垃圾资源化回收
  • C++ 堆栈分配的区别
  • 快速提升网站收录:避免常见SEO误区
  • Android Studio 正式版 10 周年回顾,承载 Androider 的峥嵘十年
  • C语言中的线程本地变量
  • 无用知识之:std::initializer_list的秘密
  • 【Java源码】基于SpringBoot+小程序的电影购票选座系统
  • vue入门到实战 二
  • 实战技巧:如何快速提高网站收录的多样性?
  • Baklib在企业知识管理领域的领先地位与三款竞品的深度剖析
  • 函数与递归
  • vue2和vue3路由封装及区别
  • 蛇年说蛇,革旧图新
  • VSCode插件HTML CSS Support
  • MyBatis-Plus笔记-快速入门
  • 于动态规划的启幕之章,借 C++ 笔触绘就算法新篇
  • 深度学习模型在汽车自动驾驶领域的应用
  • 二叉树——429,515,116
  • 031.关于后续更新和指纹浏览器成品
  • HTB:Alert[WriteUP]
  • 实现C语言的原子操作
  • 【机器学习】自定义数据集,使用scikit-learn 中K均值包 进行聚类
  • 第12章:基于TransUnet和SwinUnet网络实现的医学图像语义分割:腹部13器官分割(网页推理)
  • 成绩案例demo