当前位置: 首页 > article >正文

李宏毅2023机器学习HW15-Few-shot Classification

文章目录

  • Link
  • Task: Few-shot Classification
  • Baseline
    • Simple—transfer learning
    • Medium — FO-MAML
    • Strong — MAML

Link

Kaggle

Task: Few-shot Classification

The Omniglot dataset

  • background set: 30 alphabets
  • evaluation set: 20 alphabets
  • Problem setup: 5-way 1-shot classification

Omniglot数据集

  • 背景集:30个字母
  • 评估集:20个字母
  • 问题设置:5-way 1-shot分类
    Definition of support set and query set

Baseline

Simple—transfer learning

直接把sample code运行即可

  • traing:
    对随机选择的5个任务进行正常分类训练验证/测试
  • validation / testing:
    对五个 Support Images 进行微调,并对Query Images进行推理

Slover首先从训练集中选择5个任务,然后对选择的5个任务进行正常分类训练。在推理中,模型在支持集support set图像上微调inner_train_step步骤,然后在查询集Query Set图像上进行推理。
为了与元学习Slover保持一致,基本Slover具有与元学习Slover完全相同的输入输出格式

def BaseSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False,
):
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []

    for meta_batch in x:
        # Get data
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot :]

        if train:
            """ training loop """
            # Use the support set to calculate loss
            labels = create_label(n_way, k_shot).to(device)
            logits = model.forward(support_set)
            loss = criterion(logits, labels)

            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, labels))
        else:
            """ validation / testing loop """
            # First update model with support set images for `inner_train_step` steps
            fast_weights = OrderedDict(model.named_parameters())


            for inner_step in range(inner_train_step):
                # Simply training
                train_label = create_label(n_way, k_shot).to(device)
                logits = model.functional_forward(support_set, fast_weights)
                loss = criterion(logits, train_label)

                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
                # Perform SGD
                fast_weights = OrderedDict(
                    (name, param - inner_lr * grad)
                    for ((name, param), grad) in zip(fast_weights.items(), grads)
                )

            if not return_labels:
                """ validation """
                val_label = create_label(n_way, q_query).to(device)

                logits = model.functional_forward(query_set, fast_weights)
                loss = criterion(logits, val_label)
                task_loss.append(loss)
                task_acc.append(calculate_accuracy(logits, val_label))
            else:
                """ testing """
                logits = model.functional_forward(query_set, fast_weights)
                labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    batch_loss = torch.stack(task_loss).mean()
    task_acc = np.mean(task_acc)

    if train:
        # Update model
        model.train()
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    return batch_loss, task_acc

Medium — FO-MAML

FOMAML(First-Order MAML)是MAML(Model-Agnostic Meta-Learning)的一种简化版本。MAML是一种元学习算法,旨在通过训练模型使其能够在少量新数据上快速适应新任务。FOMAML通过忽略二阶导数来简化MAML的计算过程,从而提高计算效率。它在许多情况下表现良好,尤其是在计算资源有限的情况下。然而,它也可能在某些任务上表现不如完整的MAML。

MAML的核心思想是通过在多个任务上进行训练,使得模型能够在面对新任务时,只需少量数据就能快速收敛到一个好的参数配置。具体来说,MAML的训练过程包括两个层次的优化:

  • 内层优化(Inner Loop):在每个任务上进行少量的梯度更新,以适应该任务。

  • 外层优化(Outer Loop):在所有任务上进行梯度更新,以优化模型的初始参数,使得模型在面对新任务时能够快速适应。

""" Inner Loop Update """
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=False) # create_graph=False:这个参数表示在计算梯度时不创建计算图。在FOMAML中,我们只关心一阶导数,因此不需要创建计算图
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads)
            )


""" Outer Loop Update """
        # TODO: Finish the outer loop update
        # raise NotimplementedError
        meta_batch_loss.backward()
        optimizer.step()

Strong — MAML

""" Inner Loop Update """
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads)
            )


""" Outer Loop Update """
        # TODO: Finish the outer loop update
        # raise NotimplementedError
        meta_batch_loss.backward()
        optimizer.step()

http://www.kler.cn/news/308880.html

相关文章:

  • 部分动态铜皮的孤岛无法删除。报错
  • Linux下的CAN通讯
  • 深度学习中实验、观察与思考的方法与技巧
  • JavaScript:驱动现代Web应用的关键引擎及其与HTML/CSS的集成
  • 数模原理精解【11】
  • el-table 如何实现行列转置?
  • C#读取应用配置的简单类
  • 软件测试工程师面试整理-常见面试问题
  • 后端Controller获取成功,但是前端报错404
  • etcd入门指南:分布式事务、分布式锁及核心API详解
  • 企业开发时,会使用sqlalchedmy来构建数据库 结构吗? 还是说直接写SQL 语句比较多?
  • 断电重启之后服务器都有哪些服务需要重启
  • 828华为云征文|docker部署kafka及ui搭建
  • VRRP 笔记
  • 认知小文3《打破桎梏,编程与人生的基本法则》
  • 抓机遇,创发展︱2025 第十二届广州国际汽车零部件加工技术及汽车模具展览会,零部件国产浪潮不可阻挡
  • Pillow:Python图像处理库详解
  • 计算机网络(网络层)
  • 系统架构设计师:系统质量属性与架构评估
  • 固态硬盘:量产、开卡、ROM短接是指什么?
  • 34.贪心算法1
  • 2024最新股票系统源码 附教程
  • Track 08:AIML
  • CTFHub技能树-信息泄露-HG泄漏
  • 医学数据分析实训 项目二 数据预处理作业
  • 在 React 中掌握 useImperativeHandle(使用 TypeScript)
  • visual prompt tuning和visual instruction tuning
  • 白话:大型语言模型中的幻觉(Hallucinations)
  • react hooks--useState
  • Spring Boot基础