当前位置: 首页 > article >正文

多任务学习MTL+多任务损失

多任务学习(MTL)是一种机器学习范式,它的目标是 同时优化多个相关任务,通过共享信息提高模型的泛化能力。MTL 在计算机视觉、自然语言处理(NLP)、医疗诊断等领域有广泛应用。


为什么要使用多任务学习?

相比于单任务学习,多任务学习有以下 优势

  1. 信息共享:不同任务之间共享特征,避免过拟合,提高泛化能力。
  2. 数据利用率提升:某些任务可能数据较少,通过共享训练提升效果。
  3. 模型参数减少:多个任务共享部分参数,减少计算资源占用。
  4. 优化收敛加快:不同任务的梯度信息可以互补,使优化过程更稳定。
  5. 鲁棒性增强:主任务容易受到噪声影响,辅助任务可以起到正则化作用。

示例:

  • 医疗领域,一个神经网络可以同时学习 ECG 波形恢复疾病分类,利用波形恢复信息来提升分类效果。
  • 计算机视觉中,一个网络可以同时学习 目标检测语义分割,因为目标的边界信息可以增强分类效果。

多任务学习的基本架构

多任务学习的网络通常有以下几种结构:

1 硬共享(Hard Parameter Sharing)

架构特点

  • 共享 底层特征提取网络,任务分支在后端分开。
  • 适用于任务间高度相关的情况。

示意图

输入数据  →  共享底层特征提取  →  任务1头部
                           ↘  任务2头部

示例
ECG 任务 中:

  • 共享 CNN/Transformer 提取 ECG 特征。
  • 任务 1:ECG 波形恢复(自编码器)。
  • 任务 2:疾病分类(全连接层)。

2 软共享(Soft Parameter Sharing)

架构特点

  • 每个任务有自己的 独立网络,但参数之间可以有一定的约束,比如 L2 正则化 让参数保持相似。
  • 适用于任务较为独立,但仍有部分共享信息的情况。

3 多头(Multi-Head)与多层(Multi-Layer)

架构特点

  • 前面共享一部分层,后面各个任务有自己的 多头输出
  • 适用于任务关联性较强的情况。

多任务学习的损失函数设计

在多任务学习中,需要 加权求和不同任务的损失
L = λ 1 L 1 + λ 2 L 2 + ⋯ + λ n L n L = \lambda_1 L_1 + \lambda_2 L_2 + \dots + \lambda_n L_n L=λ1L1+λ2L2++λnLn
其中:

  • L 1 , L 2 , … , L n L_1, L_2, \dots, L_n L1,L2,,Ln 是各任务的损失。
  • λ 1 , λ 2 , … , λ n \lambda_1, \lambda_2, \dots, \lambda_n λ1,λ2,,λn 是任务的权重。

1 固定权重

最简单的方法是 人为设定权重,如:
L = 0.1 L reconstruction + 1.0 L classification L = 0.1 L_{\text{reconstruction}} + 1.0 L_{\text{classification}} L=0.1Lreconstruction+1.0Lclassification
适用于知道哪个任务更重要的情况。

2 动态调整权重

在某些情况下,固定权重不够灵活,因此可以 自动调整任务权重

(1) 不确定性加权(Uncertainty Weighting)

利用任务的不确定性 自动调整权重
L = 1 σ 1 2 L 1 + 1 σ 2 2 L 2 + log ⁡ ( σ 1 σ 2 ) L = \frac{1}{\sigma_1^2} L_1 + \frac{1}{\sigma_2^2} L_2 + \log (\sigma_1 \sigma_2) L=σ121L1+σ221L2+log(σ1σ2)
其中:

  • σ 1 , σ 2 \sigma_1, \sigma_2 σ1,σ2 是可训练参数。
  • 训练过程中模型会动态调整任务权重。

(2) GradNorm

GradNorm 通过控制 梯度的范数 来动态调整不同任务的学习速率,使得所有任务的梯度变化更加均衡。

(3) 任务不平衡策略

  • 任务 A 的 Loss 过大:降低其权重
  • 任务 B 的 Loss 下降很慢:提高其权重

训练策略

在 MTL 训练时,可以采用不同的策略:

1 交替训练

  • 适用于难度不同的任务
  • 例子:先训练 ECG 波形恢复(易学),再训练分类任务(主任务)。

2 先预训练,再联合训练

  • 适用于一个任务需要先学会基础特征
  • 例子:先用自监督学习让 ECG 波形恢复学好特征,再用于分类任务。

3 任务逐步衰减

  • 逐渐降低某个任务的权重,使其影响减小。

代码示例

1 传统多任务 Loss

import torch

# 任务 1(波形恢复) 和 任务 2(分类)
lambda_1 = 0.1
lambda_2 = 1.0

loss = lambda_1 * loss_reconstruction + lambda_2 * loss_classification
loss.backward()

2 不确定性加权

import torch.nn as nn

class MultiTaskLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.sigma1 = nn.Parameter(torch.tensor(1.0))
        self.sigma2 = nn.Parameter(torch.tensor(1.0))

    def forward(self, loss1, loss2):
        loss = (loss1 / (2 * self.sigma1**2)) + (loss2 / (2 * self.sigma2**2)) + torch.log(self.sigma1) + torch.log(self.sigma2)
        return loss

criterion = MultiTaskLoss()
loss = criterion(loss_reconstruction, loss_classification)


http://www.kler.cn/a/570069.html

相关文章:

  • 机械视觉组成模块-相机选型
  • 从零开始用react + tailwindcss + express + mongodb实现一个聊天程序(九) 消息接口
  • 【含文档+PPT+源码】基于SpringBoot电脑DIY装机教程网站的设计与实现
  • 面试常问的压力测试问题
  • 实时金融信息搜索的新突破:基于大型语言模型的智能代理框架
  • 前端项目打包生成 JS 文件的核心步骤
  • OpenCV计算摄影学(12)色调映射(Tone Mapping)的一个类cv::TonemapMantiuk
  • CELLO : Causal Evaluation of Large Vision-Language Models
  • 蓝桥与力扣刷题(蓝桥 k倍区间)
  • 互联网时代如何保证数字足迹的安全,以防个人信息泄露?
  • 图像采集卡的技术概述
  • PPT 小黑课堂第38套
  • 使用vue3+element plus 的table自制的穿梭框(支持多列数据)
  • Kotlin函数式编程与Lambda表达式
  • 【Go原生】项目入口配置及路由配置写法
  • QEMU源码全解析 —— 内存虚拟化(24)
  • 华为飞腾D2000芯片(基于ARM架构)的欧拉操作系统(openEuler)上部署MySQL
  • springboot gradle 多项目创建
  • Python----线性代数(线性代数基础:标量,向量,矩阵,张量)
  • 【每日学点HarmonyOS Next知识】web内存异常问题、Web高度问题、图片按压效果、加载沙盒中html、感知路由返回传参