多任务学习MTL+多任务损失
多任务学习(MTL)是一种机器学习范式,它的目标是 同时优化多个相关任务,通过共享信息提高模型的泛化能力。MTL 在计算机视觉、自然语言处理(NLP)、医疗诊断等领域有广泛应用。
为什么要使用多任务学习?
相比于单任务学习,多任务学习有以下 优势:
- 信息共享:不同任务之间共享特征,避免过拟合,提高泛化能力。
- 数据利用率提升:某些任务可能数据较少,通过共享训练提升效果。
- 模型参数减少:多个任务共享部分参数,减少计算资源占用。
- 优化收敛加快:不同任务的梯度信息可以互补,使优化过程更稳定。
- 鲁棒性增强:主任务容易受到噪声影响,辅助任务可以起到正则化作用。
示例:
- 在医疗领域,一个神经网络可以同时学习 ECG 波形恢复 和 疾病分类,利用波形恢复信息来提升分类效果。
- 在计算机视觉中,一个网络可以同时学习 目标检测 和 语义分割,因为目标的边界信息可以增强分类效果。
多任务学习的基本架构
多任务学习的网络通常有以下几种结构:
1 硬共享(Hard Parameter Sharing)
架构特点:
- 共享 底层特征提取网络,任务分支在后端分开。
- 适用于任务间高度相关的情况。
示意图:
输入数据 → 共享底层特征提取 → 任务1头部
↘ 任务2头部
示例:
在 ECG 任务 中:
- 共享 CNN/Transformer 提取 ECG 特征。
- 任务 1:ECG 波形恢复(自编码器)。
- 任务 2:疾病分类(全连接层)。
2 软共享(Soft Parameter Sharing)
架构特点:
- 每个任务有自己的 独立网络,但参数之间可以有一定的约束,比如 L2 正则化 让参数保持相似。
- 适用于任务较为独立,但仍有部分共享信息的情况。
3 多头(Multi-Head)与多层(Multi-Layer)
架构特点:
- 前面共享一部分层,后面各个任务有自己的 多头输出。
- 适用于任务关联性较强的情况。
多任务学习的损失函数设计
在多任务学习中,需要 加权求和不同任务的损失:
L
=
λ
1
L
1
+
λ
2
L
2
+
⋯
+
λ
n
L
n
L = \lambda_1 L_1 + \lambda_2 L_2 + \dots + \lambda_n L_n
L=λ1L1+λ2L2+⋯+λnLn
其中:
- L 1 , L 2 , … , L n L_1, L_2, \dots, L_n L1,L2,…,Ln 是各任务的损失。
- λ 1 , λ 2 , … , λ n \lambda_1, \lambda_2, \dots, \lambda_n λ1,λ2,…,λn 是任务的权重。
1 固定权重
最简单的方法是 人为设定权重,如:
L
=
0.1
L
reconstruction
+
1.0
L
classification
L = 0.1 L_{\text{reconstruction}} + 1.0 L_{\text{classification}}
L=0.1Lreconstruction+1.0Lclassification
适用于知道哪个任务更重要的情况。
2 动态调整权重
在某些情况下,固定权重不够灵活,因此可以 自动调整任务权重:
(1) 不确定性加权(Uncertainty Weighting)
利用任务的不确定性 自动调整权重:
L
=
1
σ
1
2
L
1
+
1
σ
2
2
L
2
+
log
(
σ
1
σ
2
)
L = \frac{1}{\sigma_1^2} L_1 + \frac{1}{\sigma_2^2} L_2 + \log (\sigma_1 \sigma_2)
L=σ121L1+σ221L2+log(σ1σ2)
其中:
- σ 1 , σ 2 \sigma_1, \sigma_2 σ1,σ2 是可训练参数。
- 训练过程中模型会动态调整任务权重。
(2) GradNorm
GradNorm 通过控制 梯度的范数 来动态调整不同任务的学习速率,使得所有任务的梯度变化更加均衡。
(3) 任务不平衡策略
- 任务 A 的 Loss 过大:降低其权重。
- 任务 B 的 Loss 下降很慢:提高其权重。
训练策略
在 MTL 训练时,可以采用不同的策略:
1 交替训练
- 适用于难度不同的任务。
- 例子:先训练 ECG 波形恢复(易学),再训练分类任务(主任务)。
2 先预训练,再联合训练
- 适用于一个任务需要先学会基础特征。
- 例子:先用自监督学习让 ECG 波形恢复学好特征,再用于分类任务。
3 任务逐步衰减
- 逐渐降低某个任务的权重,使其影响减小。
代码示例
1 传统多任务 Loss
import torch
# 任务 1(波形恢复) 和 任务 2(分类)
lambda_1 = 0.1
lambda_2 = 1.0
loss = lambda_1 * loss_reconstruction + lambda_2 * loss_classification
loss.backward()
2 不确定性加权
import torch.nn as nn
class MultiTaskLoss(nn.Module):
def __init__(self):
super().__init__()
self.sigma1 = nn.Parameter(torch.tensor(1.0))
self.sigma2 = nn.Parameter(torch.tensor(1.0))
def forward(self, loss1, loss2):
loss = (loss1 / (2 * self.sigma1**2)) + (loss2 / (2 * self.sigma2**2)) + torch.log(self.sigma1) + torch.log(self.sigma2)
return loss
criterion = MultiTaskLoss()
loss = criterion(loss_reconstruction, loss_classification)