当前位置: 首页 > article >正文

多标签损失之Hamming Loss(PyTorch和sklearn)、Focal Loss、交叉熵和ASL损失

多标签损失

  • 多标签评价指标之Hamming Loss
  • 多标签评价指标之Focal Loss
  • 多标签分类的交叉熵
  • Asymmetric Loss (ASL) 损失

各个损失函数的计算公式,网上有很多文章了,此处就不一一介绍了。

多标签评价指标之Hamming Loss

PyTorch实现的Hamming Loss和sklearn实现的Hamming Loss代码实现对比

import torch
from sklearn.metrics import hamming_loss


def multi_label_classification_hamming_loss(preds, targets):
    """
    计算多标签分类Hamming Loss的函数。
    :param preds: 预测的概率值,大小为 [batch_size, num_classes]
    :param targets: 目标标签值,大小为 [batch_size, num_classes]
    :return: 多标签分类Hamming Loss的值,大小为 [1]
    """
    # 将概率值转换为二进制标签(0或1)
    binary_preds = torch.round(torch.sigmoid(preds))
    # 计算Hamming Loss
    hamming_loss = 1 - (binary_preds == targets).float().mean()
    return hamming_loss


# 定义预测值和目标标签值
preds = torch.tensor([[0.1, 0.9, 0.3],
                      [0.8, 0.2, 0.6],
                      [0.4, 0.5, 0.7]])
targets = torch.tensor([[0, 1, 0],
                        [1, 0, 1],
                        [0, 1, 1]])

# 计算多标签分类Hamming Loss
loss = multi_label_classification_hamming_loss(preds, targets)

# 对比sklearn中的Hamming Loss计算结果
sklearn_loss = hamming_loss(targets.numpy(), torch.round(torch.sigmoid(preds)).numpy(), sample_weight=None)

print("PyTorch实现的Hamming Loss:", loss.item())
print("sklearn实现的Hamming Loss:", sklearn_loss)

输出结果:

PyTorch实现的Hamming Loss: 0.4444444179534912
sklearn实现的Hamming Loss: 0.4444444444444444

使用PyTorch中的torch.sigmoid将预测概率值转换为二进制标签,然后通过比较预测标签与目标标签的不一致情况来计算Hamming Loss。最后,输出PyTorch实现的Hamming Loss和sklearn实现的Hamming Loss两个指标的结果。

多标签评价指标之Focal Loss

定义了一个FocalLoss的类,其中gamma是调节因子,alpha是类别权重。在前向传播时,我们先计算出二元交叉熵损失,并根据该损失计算出每个样本的焦点因子(pt)。然后,我们将pt和交叉熵损失的权重调整后,计算最终的Focal Loss。

以下代码实现Focal Loss的多标签分类

import torch
import torch.nn.functional as F


class FocalLoss(torch.nn.Module):
    def __init__(self, gamma=2, alpha=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, input, target):
        ce_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean()

        if self.alpha is not None:
            alpha_t = self.alpha[target]
            focal_loss = alpha_t * focal_loss

        return focal_loss


# 定义预测值和目标标签值
preds = torch.tensor([[0.1, 0.9, 0.3],
                      [0.8, 0.2, 0.6],
                      [0.4, 0.5, 0.7]])
targets = torch.tensor([[0, 1, 0],
                        [1, 0, 1],
                        [0, 1, 1]], dtype=torch.float32)

focal_loss = FocalLoss(gamma=2, alpha=None)
loss = focal_loss(preds, targets)

print(loss)

输出结果:

tensor(0.1430)

多标签分类的交叉熵

为了解决这样的数据不平衡主要有两种方法,一种是数据层面上(数据增强,样本的过采样欠采样等)。另一种是从算法层名上对loss操作可以选择加权loss,focalLoss()等,这里面引入到苏剑林大神的文章——将“softmax+交叉熵”推广到多标签分类问题——可以缓解这个数据不平衡问题。

数学原理见详解:将“softmax+交叉熵”推广到多标签分类问题

以下代码实现多标签分类的交叉熵

import torch
import torch.nn as nn


def multilabel_categorical_crossentropy(y_true, y_pred):
    """多标签分类的交叉熵
    说明:y_true和y_pred的shape一致,y_true的元素非0即1,
         1表示对应的类为目标类,0表示对应的类为非目标类。
    警告:请保证y_pred的值域是全体实数,换言之一般情况下y_pred
         不用加激活函数,尤其是不能加sigmoid或者softmax!预测
         阶段则输出y_pred大于0的类。如有疑问,请仔细阅读并理解
         本文。
    """
    y_pred = (1 - 2 * y_true) * y_pred
    y_pred_neg = y_pred - y_true * 1e12
    y_pred_pos = y_pred - (1 - y_true) * 1e12
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], axis=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], axis=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return neg_loss + pos_loss


# 定义y_true和y_pred
y_pred = torch.tensor([[0.9, 0.1, 0.8, 0.3], [0.3, 0.2, 0.5, 0.7], [0.7, 0.6, 0.1, 0.3]])
y_true = torch.tensor([[1, 0, 1, 0], [0, 0, 1, 1], [1, 0, 0, 0]])

# 计算多标签分类的交叉熵
loss = multilabel_categorical_crossentropy(y_true, y_pred)

print(loss) 

输出结果:

tensor([1.4417, 1.6109, 1.5661])

Asymmetric Loss (ASL) 损失

定义了一个 AsymmetricLoss 类,它包含了 Asymmetric Loss 损失函数的实现。在 init 方法中,我们定义了一些超参数,包括 gamma_pos、gamma_neg 和 eps,以及指定了损失函数的归一化方式 reduction。在 forward 方法中,我们首先根据目标值 targets 来计算正类和负类的权重 pos_weight 和 neg_weight,然后根据公式计算损失值 loss。最后,我们根据 reduction 参数来决定损失值的归一化方式。

PyTorch 实现 Asymmetric Loss 损失函数的多标签分类代码:

import torch
import torch.nn as nn


class AsymmetricLoss(nn.Module):
    def __init__(self, gamma_pos=0.5, gamma_neg=3.0, eps=0.1, reduction='mean'):
        super(AsymmetricLoss, self).__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.eps = eps
        self.reduction = reduction

    def forward(self, inputs, targets):
        pos_weight = targets * (1 - self.gamma_pos) + self.gamma_pos
        neg_weight = (1 - targets) * (1 - self.gamma_neg) + self.gamma_neg
        loss = -pos_weight * targets * torch.log(inputs + self.eps) - neg_weight * (1 - targets) * torch.log(
            1 - inputs + self.eps)

        if self.reduction == 'mean':
            return torch.mean(loss)
        elif self.reduction == 'sum':
            return torch.sum(loss)
        else:
            return loss


# Example usage
num_classes = 4
batch_size = 4

# Generate random inputs and targets
predicted = torch.tensor([[0.9, 0.1, 0.8, 0.3], [0.3, 0.2, 0.5, 0.7], [0.7, 0.6, 0.1, 0.3]], requires_grad=True)
target = torch.tensor([[1, 0, 1, 0], [0, 0, 1, 1], [1, 0, 0, 0]], dtype=torch.float32)

# Define the model and loss function
model = nn.Linear(num_classes, num_classes)  # 两个线性层和ReLU和Sigmoid激活函数的简单多标签分类模型
loss_fn = AsymmetricLoss()

# Compute the loss
outputs = model(predicted)
loss = loss_fn(torch.sigmoid(outputs), target)
print(loss)

输出结果:

tensor(0.4989, grad_fn=<MeanBackward0>)

http://www.kler.cn/a/2305.html

相关文章:

  • 实战指南:理解 ThreadLocal 原理并用于Java 多线程上下文管理
  • Python网络爬虫与数据采集实战——什么是网络爬虫
  • 【Elasticsearch入门到落地】1、初识Elasticsearch
  • 从0开始学PHP面向对象内容之(常用魔术方法续一)
  • Linux——gcc编译过程详解与ACM时间和进度条的制作
  • old-cms(原生PHP开发的企业网站管理系统)
  • iOS中.podspec文件中source_files参数怎么设置
  • Markdown如何使用详细教程
  • Vue2源码-初始化
  • 算法练习-堆
  • 【数据结构与算法】用栈实现队列
  • STM32学习(五)
  • SpringBoot基础教程
  • 数据结构——栈和队列(2)
  • 基于SpringBoot的学生成绩管理系统
  • 第十四届蓝桥杯三月真题刷题训练——第 18 天
  • 记录一次很坑的报错:java.lang.Exception: The class is not public.
  • 【沐风老师】3DMAX交通流插件TrafficFlow使用方法详解
  • albedo开源框架配置多数据源
  • 乐观锁和悲观锁 面试题
  • vue使用split()将字符串分割数组join()将数组转字符串reverse()将数组反转
  • Linux 总结9个最危险的命令,一定要牢记在心!
  • 通过DNS数据包解释DNS协议各个字段含义
  • Java中 ==和equals的区别是什么?
  • 流量分析-Wireshark -操作手册(不能说最全,只能说更全)
  • Golang每日一练(leetDay0012)