
Hung-yi Lee Machine Learning 2023 - HW13 - Network Compression

Table of Contents

  • Task
  • Link
  • Baseline
    • Simple Baseline
    • Medium Baseline
    • Strong Baseline
    • Boss Baseline
      • FitNet Knowledge Distillation
      • Relational Knowledge Distillation (RKD)
      • Distance Metric (DM) Knowledge Distillation

Task

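Perform image classification through network compression. The dataset is the same 11-class food classification dataset used in hw3. You need to design a small student model with fewer than 60k parameters and train it to approach the accuracy of the teacher model, test-Acc ≅ 0.902.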
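To check a candidate student against the 60k budget, the parameters can simply be counted. The following is a minimal sketch: `count_parameters` and `toy_student` are hypothetical names used only for illustration, and the toy network is not the student used in this article.

```python
import torch.nn as nn

def count_parameters(model: nn.Module) -> int:
    """Return the total number of parameters in a model."""
    return sum(p.numel() for p in model.parameters())

# Hypothetical toy student, used only to illustrate the check.
toy_student = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),  # 16*3*3*3 weights + 16 biases = 448
    nn.BatchNorm2d(16),                          # 16 scales + 16 shifts = 32
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 11),                           # 16*11 weights + 11 biases = 187
)

n_params = count_parameters(toy_student)
print(f"student parameters: {n_params}")
assert n_params < 60_000, "student exceeds the 60k parameter budget"
```

BatchNorm running statistics are buffers rather than parameters, so they are not included in this count.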

Link

kaggle

Baseline

Simple Baseline

Just run the sample code

Medium Baseline

The loss function is defined with KL divergence, as follows:

$$\text{Loss} = \alpha T^2 \times \mathrm{KL}(p \,\|\, q) + (1 - \alpha)\,(\text{Original Cross Entropy Loss}), \quad \text{where } p = \mathrm{softmax}\left(\frac{\text{student's logits}}{T}\right) \text{ and } q = \mathrm{softmax}\left(\frac{\text{teacher's logits}}{T}\right)$$

The number of epochs can also be increased to 50, and the other places marked with # medium need to be modified as well.

# medium
import torch.nn.functional as F

def loss_fn_kd(student_logits, labels, teacher_logits, alpha=0.5, temperature=5.0):
    # Knowledge distillation loss: alpha * T^2 * KL + (1 - alpha) * CE.
    # Soften both outputs with the temperature; the student side is kept as log-probabilities.
    student_log_prob = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_prob = F.softmax(teacher_logits / temperature, dim=-1)
    # F.kl_div takes log-probabilities as input and probabilities as target:
    # sum(teacher_prob * (log(teacher_prob) - student_log_prob)), averaged over the batch.
    KL_loss = F.kl_div(student_log_prob, teacher_prob, reduction="batchmean")
    # Ordinary cross-entropy on the unsoftened student logits
    CE_loss = F.cross_entropy(student_logits, labels)
    loss = alpha * temperature**2 * KL_loss + (1 - alpha) * CE_loss
    return loss
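The effect of the temperature and of the two loss terms can be checked on dummy tensors. This is a minimal sketch under stated assumptions: `kd_loss` is a self-contained stand-in for the distillation loss, and the random logits stand in for real teacher outputs over 11 classes.

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, labels, teacher_logits, alpha=0.5, temperature=5.0):
    """Sketch of the distillation loss: alpha * T^2 * KL + (1 - alpha) * CE."""
    student_log_prob = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_prob = F.softmax(teacher_logits / temperature, dim=-1)
    kl = F.kl_div(student_log_prob, teacher_prob, reduction="batchmean")
    ce = F.cross_entropy(student_logits, labels)
    return alpha * temperature**2 * kl + (1 - alpha) * ce

torch.manual_seed(0)
teacher_logits = torch.randn(4, 11) * 5   # stand-in for teacher outputs (4 samples, 11 classes)
labels = torch.randint(0, 11, (4,))

# A higher temperature flattens the teacher distribution (smaller peak probability).
sharp_peak = F.softmax(teacher_logits, dim=-1).max(dim=-1).values
soft_peak = F.softmax(teacher_logits / 5.0, dim=-1).max(dim=-1).values
print(sharp_peak, soft_peak)

# If the student reproduces the teacher's logits exactly, the KL term vanishes
# and only the weighted cross-entropy term remains.
loss_same = kd_loss(teacher_logits, labels, teacher_logits, alpha=0.5)
ce_only = 0.5 * F.cross_entropy(teacher_logits, labels)
print(loss_same.item(), ce_only.item())
```

In an actual training step the teacher logits would normally be computed under `torch.no_grad()`, so that only the student receives gradients.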

Strong Baseline

Modify the model architecture with depthwise and pointwise convolutions, and train for more epochs.

  • other useful techniques
    • group convolution (Actually, depthwise convolution is a specific type of group convolution)
    • SqueezeNet
    • MobileNet
    • ShuffleNet
    • Xception
    • GhostNet
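The remark that depthwise convolution is a special case of group convolution can be illustrated directly: with `groups` equal to the number of channels, each output channel depends on exactly one input channel. This is a small sketch with an arbitrary channel count of 8, chosen only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
channels = 8
# groups == in_channels turns a group convolution into a depthwise convolution
depthwise = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels)
print(depthwise.weight.shape)  # one single-channel 3x3 filter per output channel

x = torch.randn(1, channels, 16, 16)
x_perturbed = x.clone()
x_perturbed[:, 0] += 1.0  # change only input channel 0

with torch.no_grad():
    # Largest absolute change per output channel
    delta = (depthwise(x_perturbed) - depthwise(x)).abs().amax(dim=(0, 2, 3))
print(delta)  # only the entry for channel 0 should be non-zero
```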

Intermediate-layer feature learning can also be applied; I did not do that here.

# strong
def dwpw_conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    return nn.Sequential(
        nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, padding=padding, groups=in_channels), #depthwise convolution
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels, out_channels, 1), # pointwise convolution
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    )
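The reason this block helps with the parameter budget is visible from a direct count. The sketch below compares a standard 3x3 convolution with its depthwise plus pointwise counterpart; the channel sizes (64 to 128) are arbitrary, and BatchNorm and ReLU are left out so that only the convolution weights are compared.

```python
import torch
import torch.nn as nn

def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

# Standard 3x3 convolution: 128*64*3*3 weights + 128 biases
standard = nn.Conv2d(64, 128, kernel_size=3, padding=1)

# Depthwise (64*1*3*3 + 64) followed by pointwise (128*64*1*1 + 128)
separable = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1, groups=64),
    nn.Conv2d(64, 128, kernel_size=1),
)

x = torch.randn(2, 64, 32, 32)
out_standard = standard(x)
out_separable = separable(x)
print(out_standard.shape, out_separable.shape)   # same output shape
print(n_params(standard), n_params(separable))   # 73856 vs 8960
```

For these sizes the separable version uses roughly 8x fewer parameters while producing an output of the same shape.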

Boss Baseline

Other advanced knowledge distillation methods (FitNet/RKD/DM) + more epochs + depthwise and pointwise conv layers (depthwise separable convolution)

Of course, intermediate-layer feature learning can also be applied.

FitNet Knowledge Distillation

FitNet focuses on transferring knowledge from intermediate feature representations (hidden layers) instead of just using the output logits. The student model is trained to mimic the feature maps from certain layers of the teacher model.

#boss
def loss_fn_fitnet(teacher_feature, student_feature, student_logits, labels, alpha=0.5):
    """
    FitNet Knowledge Distillation Loss Function.

    Args:
    - teacher_feature: The feature maps from a hidden layer of the teacher model.
    - student_feature: The feature maps from the corresponding hidden layer of the student model
      (must have the same shape as teacher_feature).
    - student_logits: The classification outputs (logits) of the student model.
    - labels: Ground truth labels for the task.
    - alpha: Weighting factor for the hard-label loss; (1 - alpha) weights the feature loss.

    Returns:
    - loss: Combined loss with cross-entropy and feature map alignment.
    """
    # Mean squared error loss to align feature maps of teacher and student
    feature_loss = F.mse_loss(student_feature, teacher_feature)

    # Hard label cross-entropy loss, computed on the student's logits (not on its feature maps)
    hard_loss = F.cross_entropy(student_logits, labels)
    
    # Combine both losses
    loss = alpha * hard_loss + (1 - alpha) * feature_loss
    return loss
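The feature maps themselves have to be collected from hidden layers, and the student's channel count usually differs from the teacher's. One common approach is sketched below, assuming forward hooks for capturing activations and a 1x1 convolution as a regressor that maps student channels to teacher channels. The toy `teacher` and `student` networks, the `save_to` helper and the layer choice are all hypothetical; the real models come from the homework code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical toy networks, for illustration only.
teacher = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
                        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, 11))
student = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 11))

features = {}

def save_to(name):
    """Build a forward hook that stores a layer's output under the given name."""
    def hook(module, inputs, output):
        features[name] = output
    return hook

teacher[1].register_forward_hook(save_to("teacher"))
student[1].register_forward_hook(save_to("student"))

# 1x1 convolution mapping 16 student channels to 64 teacher channels
regressor = nn.Conv2d(16, 64, kernel_size=1)

x = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 11, (4,))

with torch.no_grad():          # the teacher is frozen
    teacher(x)
student_logits = student(x)

alpha = 0.5
feature_loss = F.mse_loss(regressor(features["student"]), features["teacher"])
hard_loss = F.cross_entropy(student_logits, labels)
loss = alpha * hard_loss + (1 - alpha) * feature_loss
loss.backward()                # gradients flow into the student and the regressor
```

The regressor is trained together with the student and discarded afterwards, so it does not count toward the student's parameter budget at inference time.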

Relational Knowledge Distillation (RKD)

Relational Knowledge Distillation focuses on transferring the relationships (distances and angles) between data samples as learned by the teacher. The student model is trained to match these relationships instead of just focusing on output probabilities.

# boss
def pairwise_distance(x):
    """Calculate pairwise distance between batch samples."""
    return torch.cdist(x, x, p=2)

def angle_between_pairs(x):
    """Calculate angles between all pairs of points in batch."""
    diff = x.unsqueeze(1) - x.unsqueeze(0)
    norm = diff.norm(dim=-1, p=2, keepdim=True)
    normalized_diff = diff / (norm + 1e-8)
    angles = torch.bmm(normalized_diff, normalized_diff.transpose(1, 2))
    return angles

def loss_fn_rkd(teacher_feature, student_feature, student_logits, labels, alpha=0.5):
    """
    Relational Knowledge Distillation Loss Function.

    Args:
    - teacher_feature: Teacher model feature embeddings, shape (batch, dim_t).
    - student_feature: Student model feature embeddings, shape (batch, dim_s).
    - student_logits: The classification outputs (logits) of the student model.
    - labels: Ground truth labels.
    - alpha: Weighting factor for the hard-label loss; (1 - alpha) weights the relational loss.

    Returns:
    - loss: Combined relational knowledge and hard label loss.
    """

    # Pairwise distances between features in the teacher and student model
    teacher_dist = pairwise_distance(teacher_feature)
    student_dist = pairwise_distance(student_feature)

    # Distillation loss: MSE between the two pairwise distance matrices
    distance_loss = F.mse_loss(student_dist, teacher_dist)

    # Angle-based loss between teacher and student feature vectors
    teacher_angle = angle_between_pairs(teacher_feature)
    student_angle = angle_between_pairs(student_feature)
    angle_loss = F.mse_loss(student_angle, teacher_angle)

    # Hard label cross-entropy loss, computed on the student's logits (not on its embeddings)
    hard_loss = F.cross_entropy(student_logits, labels)
    
    # Combine the losses
    loss = alpha * hard_loss + (1 - alpha) * (distance_loss + angle_loss)
    return loss
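Because the relational loss compares batch-by-batch matrices rather than the embeddings themselves, the teacher and student embeddings may have different dimensions, and any transformation that preserves pairwise distances goes unpenalized. A small sketch with random embeddings (the sizes 128 and 32 are arbitrary assumptions):

```python
import torch
import torch.nn.functional as F

def pairwise_distance(x):
    """Euclidean distance between every pair of samples in the batch."""
    return torch.cdist(x, x, p=2)

torch.manual_seed(0)
teacher_emb = torch.randn(8, 128)   # batch of 8, 128-d teacher embeddings
student_emb = torch.randn(8, 32)    # 32-d student embeddings

t_dist = pairwise_distance(teacher_emb)
s_dist = pairwise_distance(student_emb)
print(t_dist.shape, s_dist.shape)   # both are (8, 8), so they can be compared directly

# Shifting every embedding by the same constant keeps all pairwise distances,
# so the distance part of the loss stays (numerically) at zero.
shifted = teacher_emb + 3.0
shift_loss = F.mse_loss(pairwise_distance(shifted), t_dist)
print(shift_loss.item())
```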

Distance Metric (DM) Knowledge Distillation

Distance Metric distillation focuses on transferring the distance metric (such as Euclidean distance or cosine similarity) between instances in the teacher’s feature space to the student model.

def loss_fn_dm(teacher_feature, student_feature, student_logits, labels, alpha=0.5):
    """
    Distance Metric (DM) Knowledge Distillation Loss Function.

    Args:
    - teacher_feature: The feature representations from the teacher model.
    - student_feature: The feature representations from the student model.
    - student_logits: The classification outputs (logits) of the student model.
    - labels: Ground truth labels for the task.
    - alpha: Weighting factor for the hard-label loss; (1 - alpha) weights the distance metric loss.

    Returns:
    - loss: Combined distance metric loss and cross-entropy loss.
    """
    # Pairwise distances within the teacher embeddings and within the student embeddings
    teacher_dist = pairwise_distance(teacher_feature)
    student_dist = pairwise_distance(student_feature)

    # Distance metric loss using Mean Squared Error (MSE) loss
    dist_loss = F.mse_loss(student_dist, teacher_dist)

    # Hard label cross-entropy loss, computed on the student's logits (not on its features)
    hard_loss = F.cross_entropy(student_logits, labels)
    
    # Combine the losses
    loss = alpha * hard_loss + (1 - alpha) * dist_loss
    return loss
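The description above also mentions cosine similarity as a possible metric. A variant along those lines could look like the following sketch; `pairwise_cosine_similarity` and `cosine_dm_loss` are hypothetical helpers, not part of the homework code, and they cover only the distillation term (the cross-entropy term would be added as before).

```python
import torch
import torch.nn.functional as F

def pairwise_cosine_similarity(x):
    """Cosine similarity between every pair of samples in the batch."""
    x = F.normalize(x, p=2, dim=1)   # unit-length rows
    return x @ x.t()

def cosine_dm_loss(teacher_feature, student_feature):
    """MSE between the student's and the teacher's similarity matrices."""
    return F.mse_loss(pairwise_cosine_similarity(student_feature),
                      pairwise_cosine_similarity(teacher_feature))

torch.manual_seed(0)
teacher_emb = torch.randn(8, 128)
student_emb = torch.randn(8, 32, requires_grad=True)

loss = cosine_dm_loss(teacher_emb, student_emb)
loss.backward()
print(loss.item(), student_emb.grad.shape)

# Cosine similarity ignores the scale of the embeddings.
scaled_loss = cosine_dm_loss(teacher_emb, 2.0 * teacher_emb)
print(scaled_loss.item())
```

Unlike the Euclidean version, this variant does not penalize the student for producing embeddings with a different overall magnitude than the teacher's.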
