当前位置: 首页 > article >正文

如何解决小尺寸图像分割中的样本不均衡问题


1. 生成对抗数据增强(Copy-Paste Augmentation)

原理:将稀有目标的像素块复制粘贴到其他图像中,低成本生成平衡数据。
适用场景:小目标(如车辆、船只)或极端稀疏类别(如灾害损毁区域)。
PyTorch 实现

import random

def copy_paste_augment(image, mask, paste_image, paste_mask):
    # 从粘贴数据中随机选择一个目标实例
    obj_ids = torch.unique(paste_mask)
    obj_id = random.choice(obj_ids[1:])  # 跳过背景0
    obj_mask = (paste_mask == obj_id)
    
    # 随机选择粘贴位置
    h, w = image.shape[-2:]
    paste_h, paste_w = obj_mask.sum(dim=1).max(), obj_mask.sum(dim=2).max()
    x = random.randint(0, w - paste_w)
    y = random.randint(0, h - paste_h)
    
    # 将目标粘贴到主图像
    image[:, y:y+paste_h, x:x+paste_w] = paste_image * obj_mask + image[:, y:y+paste_h, x:x+paste_w] * (~obj_mask)
    mask[y:y+paste_h, x:x+paste_w] = paste_mask * obj_mask + mask[y:y+paste_h, x:x+paste_w] * (~obj_mask)
    return image, mask

# 使用示例
image, mask = next(iter(dataloader))  # 主图像
paste_image, paste_mask = next(iter(paste_loader))  # 粘贴源
aug_image, aug_mask = copy_paste_augment(image, mask, paste_image, paste_mask)

2. 自监督预训练(Self-Supervised Pretraining)

原理:利用无标签数据预训练模型,增强特征提取能力,缓解小样本学习压力。
适用场景:标注成本高、有大量未标注遥感数据的场景。
工具推荐:使用 lightly 库实现自监督对比学习:

# 安装:pip install lightly
from lightly.models.modules import SimCLRProjectionHead
from lightly.loss import NTXentLoss

class SelfSupervisedModel(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone  # 例如 ResNet-18
        self.projection = SimCLRProjectionHead(512, 512, 128)
    
    def forward(self, x):
        features = self.backbone(x).flatten(1)
        return self.projection(features)

# 对比学习训练
model = SelfSupervisedModel(backbone)
criterion = NTXentLoss(temperature=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for images, _ in dataloader:  # 无需标签
    images = torch.cat(images, dim=0)  # 假设已应用两种增强
    features = model(images)
    loss = criterion(features)
    loss.backward()
    optimizer.step()

3. 动态类别权重(Class-Balanced Loss)

原理:根据每个 batch 的实时类别分布动态调整损失权重,适应局部不均衡。
代码实现

class DynamicWeightedCE(nn.Module):
    def __init__(self, num_classes, beta=0.9):
        super().__init__()
        self.num_classes = num_classes
        self.beta = beta  # 平滑系数
        self.running_counts = torch.zeros(num_classes)
    
    def forward(self, inputs, targets):
        # 统计当前batch的类别频率
        batch_counts = torch.bincount(targets.flatten(), minlength=self.num_classes).float()
        self.running_counts = self.beta * self.running_counts + (1 - self.beta) * batch_counts
        
        # 计算动态权重(频率倒数)
        weights = 1.0 / (self.running_counts + 1e-6)
        weights = weights / weights.sum() * self.num_classes  # 归一化
        
        # 加权交叉熵
        loss = F.cross_entropy(inputs, targets, weight=weights.to(inputs.device))
        return loss

# 使用示例
criterion = DynamicWeightedCE(num_classes=5)

4. Transformer + 像素对比学习(Pixel-wise Contrastive Learning)

原理:利用 Transformer 的长距离建模能力,结合像素级对比学习增强边界区分。
代码框架(使用 timm 和自定义模块):

import timm
from einops import rearrange

class PixelContrastiveHead(nn.Module):
    def __init__(self, in_channels, proj_dim=128):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, proj_dim, kernel_size=1)
    
    def forward(self, x):
        return F.normalize(self.proj(x), p=2, dim=1)

# 模型定义
encoder = timm.create_model("vit_small_patch16_224", pretrained=True, num_classes=0)
contrast_head = PixelContrastiveHead(encoder.embed_dim)
seg_head = nn.Conv2d(encoder.embed_dim, num_classes, kernel_size=1)

# 对比损失(基于像素相似度)
def pixel_contrast_loss(feats, mask, temperature=0.1):
    feats = rearrange(feats, "b c h w -> (b h w) c")
    mask = rearrange(mask, "b h w -> (b h w)")
    
    # 同类别像素为正样本对
    same_class = (mask.unsqueeze(0) == mask.unsqueeze(1))  # [N, N]
    similarity = torch.matmul(feats, feats.t()) / temperature  # [N, N]
    
    # 排除自身对比
    same_class = same_class.fill_diagonal_(False)
    pos_pairs = similarity[same_class]
    neg_pairs = similarity[~same_class]
    
    loss = -torch.log(torch.exp(pos_pairs).sum() / (torch.exp(neg_pairs).sum() + 1e-6))
    return loss

# 训练时联合优化分割和对比损失
feats = encoder(images)  # [B, C, H, W]
seg_loss = F.cross_entropy(seg_head(feats), masks)
contrast_loss = pixel_contrast_loss(contrast_head(feats), masks)
total_loss = seg_loss + 0.5 * contrast_loss

5. 不确定性加权多任务学习(Uncertainty Weighting)

原理:自动学习各损失函数的权重,平衡不同任务(如分割、边界检测)的贡献。
实现代码

class UncertaintyWeightedLoss(nn.Module):
    def __init__(self, num_tasks=2):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))
    
    def forward(self, *losses):
        loss_sum = 0
        for i, loss in enumerate(losses):
            precision = torch.exp(-self.log_vars[i])
            loss_sum += precision * loss + self.log_vars[i]
        return loss_sum

# 示例:联合分割损失和边界损失
seg_criterion = DiceLoss()
edge_criterion = nn.BCEWithLogitsLoss()
weight_module = UncertaintyWeightedLoss(num_tasks=2)

# 计算各任务损失
seg_loss = seg_criterion(outputs, masks)
edges = canny_edge_detector(masks)  # 假设已提取边界
edge_loss = edge_criterion(edge_pred, edges)

# 自动加权总损失
total_loss = weight_module(seg_loss, edge_loss)

6. 在线困难样本挖掘(OHEM, Online Hard Example Mining)

原理:在训练过程中动态筛选对模型当前最难分的样本,针对性加强学习。
PyTorch 自定义实现

class OHEMLoss(nn.Module):
    def __init__(self, criterion, hard_ratio=0.3):
        super().__init__()
        self.criterion = criterion  # 基础损失函数
        self.hard_ratio = hard_ratio
    
    def forward(self, inputs, targets):
        batch_loss = self.criterion(inputs, targets, reduction="none")  # [B, H, W]
        
        # 按像素损失排序,选择前 K% 的困难样本
        num_pixels = batch_loss.numel()
        k = int(num_pixels * self.hard_ratio)
        hard_loss, _ = torch.topk(batch_loss.view(-1), k)
        return hard_loss.mean()

# 使用示例(结合交叉熵)
base_loss = nn.CrossEntropyLoss(reduction="none")
ohem_loss = OHEMLoss(base_loss, hard_ratio=0.3)

工具库推荐总结

方法推荐工具库关键优势
生成对抗增强自定义实现(无需额外库)低成本生成逼真样本
自监督预训练lightly支持多种对比学习算法
Transformer 模型timm提供预训练 Vision Transformer 模型
多任务不确定性加权自定义实现端到端自动平衡多任务
在线困难样本挖掘自定义实现动态关注难分样本,无需额外标注


http://www.kler.cn/a/522739.html

相关文章:

  • wow-agent---task4 MetaGPT初体验
  • ResNeSt: Split-Attention Networks 参考论文
  • 基于dlib/face recognition人脸识别推拉流实现
  • 机器学习(三)
  • Vue 3 + TypeScript 实现父子组件协同工作案例解析
  • Android Studio安装配置
  • 指针的介绍2前
  • 【JavaEE进阶】应用分层
  • 使用Ollama 在Ubuntu运行deepseek大模型:以DeepSeek-coder为例
  • 包管理工具随记
  • 构建1688自动代采系统:PHP开发实战指南
  • 深度学习|表示学习|卷积神经网络|输出维度公式如何理解?|16
  • 宝塔中运行java项目 报权限不足
  • 14-6-2C++STL的list
  • mysql统计每个表行数、大小以及数据库总行数、大小
  • 洛谷题目 P5994 [PA 2014] Kuglarz 题解 (本题较难)
  • 深入浅出 Rust 的强大 match 表达式
  • 怎么样把pdf转成图片模式(不能复制文字)
  • PyCharm介绍
  • 宝塔面板SSL加密访问设置教程
  • 自助设备系统设置——对接POS支付
  • 《程序人生》工作2年感悟
  • 蓝桥杯python语言基础(1)——编程基础
  • (2025 年最新)MacOS Redis Desktop Manager中文版下载,附详细图文
  • 【BQ3568HM开发板】如何在OpenHarmony上通过校园网的上网认证
  • USB鼠标的数据格式