如何解决小尺寸图像分割中的样本不均衡问题
1. 生成对抗数据增强(Copy-Paste Augmentation)
原理:将稀有目标的像素块复制粘贴到其他图像中,低成本生成平衡数据。
适用场景:小目标(如车辆、船只)或极端稀疏类别(如灾害损毁区域)。
PyTorch 实现:
import random
def copy_paste_augment(image, mask, paste_image, paste_mask):
# 从粘贴数据中随机选择一个目标实例
obj_ids = torch.unique(paste_mask)
obj_id = random.choice(obj_ids[1:]) # 跳过背景0
obj_mask = (paste_mask == obj_id)
# 随机选择粘贴位置
h, w = image.shape[-2:]
paste_h, paste_w = obj_mask.sum(dim=1).max(), obj_mask.sum(dim=2).max()
x = random.randint(0, w - paste_w)
y = random.randint(0, h - paste_h)
# 将目标粘贴到主图像
image[:, y:y+paste_h, x:x+paste_w] = paste_image * obj_mask + image[:, y:y+paste_h, x:x+paste_w] * (~obj_mask)
mask[y:y+paste_h, x:x+paste_w] = paste_mask * obj_mask + mask[y:y+paste_h, x:x+paste_w] * (~obj_mask)
return image, mask
# 使用示例
image, mask = next(iter(dataloader)) # 主图像
paste_image, paste_mask = next(iter(paste_loader)) # 粘贴源
aug_image, aug_mask = copy_paste_augment(image, mask, paste_image, paste_mask)
2. 自监督预训练(Self-Supervised Pretraining)
原理:利用无标签数据预训练模型,增强特征提取能力,缓解小样本学习压力。
适用场景:标注成本高、有大量未标注遥感数据的场景。
工具推荐:使用 lightly
库实现自监督对比学习:
# 安装:pip install lightly
from lightly.models.modules import SimCLRProjectionHead
from lightly.loss import NTXentLoss
class SelfSupervisedModel(nn.Module):
def __init__(self, backbone):
super().__init__()
self.backbone = backbone # 例如 ResNet-18
self.projection = SimCLRProjectionHead(512, 512, 128)
def forward(self, x):
features = self.backbone(x).flatten(1)
return self.projection(features)
# 对比学习训练
model = SelfSupervisedModel(backbone)
criterion = NTXentLoss(temperature=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for images, _ in dataloader: # 无需标签
images = torch.cat(images, dim=0) # 假设已应用两种增强
features = model(images)
loss = criterion(features)
loss.backward()
optimizer.step()
3. 动态类别权重(Class-Balanced Loss)
原理:根据每个 batch 的实时类别分布动态调整损失权重,适应局部不均衡。
代码实现:
class DynamicWeightedCE(nn.Module):
def __init__(self, num_classes, beta=0.9):
super().__init__()
self.num_classes = num_classes
self.beta = beta # 平滑系数
self.running_counts = torch.zeros(num_classes)
def forward(self, inputs, targets):
# 统计当前batch的类别频率
batch_counts = torch.bincount(targets.flatten(), minlength=self.num_classes).float()
self.running_counts = self.beta * self.running_counts + (1 - self.beta) * batch_counts
# 计算动态权重(频率倒数)
weights = 1.0 / (self.running_counts + 1e-6)
weights = weights / weights.sum() * self.num_classes # 归一化
# 加权交叉熵
loss = F.cross_entropy(inputs, targets, weight=weights.to(inputs.device))
return loss
# 使用示例
criterion = DynamicWeightedCE(num_classes=5)
4. Transformer + 像素对比学习(Pixel-wise Contrastive Learning)
原理:利用 Transformer 的长距离建模能力,结合像素级对比学习增强边界区分。
代码框架(使用 timm
和自定义模块):
import timm
from einops import rearrange
class PixelContrastiveHead(nn.Module):
def __init__(self, in_channels, proj_dim=128):
super().__init__()
self.proj = nn.Conv2d(in_channels, proj_dim, kernel_size=1)
def forward(self, x):
return F.normalize(self.proj(x), p=2, dim=1)
# 模型定义
encoder = timm.create_model("vit_small_patch16_224", pretrained=True, num_classes=0)
contrast_head = PixelContrastiveHead(encoder.embed_dim)
seg_head = nn.Conv2d(encoder.embed_dim, num_classes, kernel_size=1)
# 对比损失(基于像素相似度)
def pixel_contrast_loss(feats, mask, temperature=0.1):
feats = rearrange(feats, "b c h w -> (b h w) c")
mask = rearrange(mask, "b h w -> (b h w)")
# 同类别像素为正样本对
same_class = (mask.unsqueeze(0) == mask.unsqueeze(1)) # [N, N]
similarity = torch.matmul(feats, feats.t()) / temperature # [N, N]
# 排除自身对比
same_class = same_class.fill_diagonal_(False)
pos_pairs = similarity[same_class]
neg_pairs = similarity[~same_class]
loss = -torch.log(torch.exp(pos_pairs).sum() / (torch.exp(neg_pairs).sum() + 1e-6))
return loss
# 训练时联合优化分割和对比损失
feats = encoder(images) # [B, C, H, W]
seg_loss = F.cross_entropy(seg_head(feats), masks)
contrast_loss = pixel_contrast_loss(contrast_head(feats), masks)
total_loss = seg_loss + 0.5 * contrast_loss
5. 不确定性加权多任务学习(Uncertainty Weighting)
原理:自动学习各损失函数的权重,平衡不同任务(如分割、边界检测)的贡献。
实现代码:
class UncertaintyWeightedLoss(nn.Module):
def __init__(self, num_tasks=2):
super().__init__()
self.log_vars = nn.Parameter(torch.zeros(num_tasks))
def forward(self, *losses):
loss_sum = 0
for i, loss in enumerate(losses):
precision = torch.exp(-self.log_vars[i])
loss_sum += precision * loss + self.log_vars[i]
return loss_sum
# 示例:联合分割损失和边界损失
seg_criterion = DiceLoss()
edge_criterion = nn.BCEWithLogitsLoss()
weight_module = UncertaintyWeightedLoss(num_tasks=2)
# 计算各任务损失
seg_loss = seg_criterion(outputs, masks)
edges = canny_edge_detector(masks) # 假设已提取边界
edge_loss = edge_criterion(edge_pred, edges)
# 自动加权总损失
total_loss = weight_module(seg_loss, edge_loss)
6. 在线困难样本挖掘(OHEM, Online Hard Example Mining)
原理:在训练过程中动态筛选对模型当前最难分的样本,针对性加强学习。
PyTorch 自定义实现:
class OHEMLoss(nn.Module):
def __init__(self, criterion, hard_ratio=0.3):
super().__init__()
self.criterion = criterion # 基础损失函数
self.hard_ratio = hard_ratio
def forward(self, inputs, targets):
batch_loss = self.criterion(inputs, targets, reduction="none") # [B, H, W]
# 按像素损失排序,选择前 K% 的困难样本
num_pixels = batch_loss.numel()
k = int(num_pixels * self.hard_ratio)
hard_loss, _ = torch.topk(batch_loss.view(-1), k)
return hard_loss.mean()
# 使用示例(结合交叉熵)
base_loss = nn.CrossEntropyLoss(reduction="none")
ohem_loss = OHEMLoss(base_loss, hard_ratio=0.3)
工具库推荐总结
方法 | 推荐工具库 | 关键优势 |
---|---|---|
生成对抗增强 | 自定义实现(无需额外库) | 低成本生成逼真样本 |
自监督预训练 | lightly | 支持多种对比学习算法 |
Transformer 模型 | timm | 提供预训练 Vision Transformer 模型 |
多任务不确定性加权 | 自定义实现 | 端到端自动平衡多任务 |
在线困难样本挖掘 | 自定义实现 | 动态关注难分样本,无需额外标注 |