当前位置: 首页 > article >正文

使用TaPas进行表格问答

https://rocm.blogs.amd.com/artificial-intelligence/TaPas/README.html

2024年4月26日, 由 Phillip Dang 于 Phillip Dang 撰写。

传统上,问答任务被视为语义解析任务,即将问题翻译成可执行的完整逻辑形式,以便在表格上检索正确答案。然而,这需要大量标注数据,而这些数据的获取成本很高。

为了应对这一挑战,TaPas 选择通过选择表格单元格的一部分并可能应用聚合操作来预测简化的程序。因此,TaPas可以直接从自然语言中理解操作,而不需要明确的形式化规范。

TaPas 模型(Table Parser,表解析器)是一个基于 BERT 的弱监督问答模型,专为解答关于表格数据的问题而设计和预训练。该模型通过位置嵌入增强,以理解表格结构。输入表格被转换为一系列的单词格式,将单词分成标记,然后在表格标记之前整合问题标记。此外,还有两个分类层,以促进对表格单元格和聚合操作选择,这些操作作用于选定的单元格上。

TaPas 的预训练数据来自包含大量表格的维基百科,使模型能够理解文本与表格之间以及单个单元格与对应表头之间的多种关联。预训练过程涉及从维基百科中提取文本-表格对,生成620万个表格,重点关注最多包含500个单元格的表格,这与其最终任务数据集的结构一致,这些数据集专门包含具有表头行(包含列名)的水平表格。

要深入了解 TaPas 的内部工作原理及其性能,请参阅 Google Research 的 TaPas: Weakly Supervised Table Parsing via Pre-training。

在这篇博客中,我们使用 AMD GPU 和 ROCm 运行了一些推理,并展示了 TaPas 开箱即用的效果。

先决条件

  • 软件:

    • ROCm

    • PyTorch

    • Linux 操作系统

想了解支持的 GPU 和操作系统的列表,请参考 ROCm 的安装指南。为了方便和稳定性,我们建议您直接拉取并运行 rocm/pytorch Docker 映像到您的 Linux 系统,使用以下命令:

docker run -it --ipc=host --network=host --device=/dev/kfd --device=/dev/dri \
           --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
           --name=olmo rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 /bin/bash

  • 硬件:

确保系统能识别您的 GPU:

! rocm-smi --showproductname

================= ROCm 系统管理界面 ================
========================= 产品信息 ============================
GPU[0] : 卡系列: Instinct MI210
GPU[0] : 卡型号: 0x0c34
GPU[0] : 卡厂商: Advanced Micro Devices, Inc. [AMD/ATI]
GPU[0] : 卡 SKU: D67301
===================================================================
===================== ROCm SMI 日志结束=========================

让我们检查是否安装了正确版本的 ROCm。

!apt show rocm-libs -a

Package: rocm-libs
Version: 5.7.0.50700-63~22.04
Priority: optional
Section: devel
Maintainer: ROCm Libs Support <rocm-libs.support@amd.com>
Installed-Size: 13.3 kBA
Depends: hipblas (= 1.1.0.50700-63~22.04), hipblaslt (= 0.3.0.50700-63~22.04), hipfft (= 1.0.12.50700-63~22.04), hipsolver (= 1.8.1.50700-63~22.04), hipsparse (= 2.3.8.50700-63~22.04), miopen-hip (= 2.20.0.50700-63~22.04), rccl (= 2.17.1.50700-63~22.04), rocalution (= 2.1.11.50700-63~22.04), rocblas (= 3.1.0.50700-63~22.04), rocfft (= 1.0.23.50700-63~22.04), rocrand (= 2.10.17.50700-63~22.04), rocsolver (= 3.23.0.50700-63~22.04), rocsparse (= 2.5.4.50700-63~22.04), rocm-core (= 5.7.0.50700-63~22.04), hipblas-dev (= 1.1.0.50700-63~22.04), hipblaslt-dev (= 0.3.0.50700-63~22.04), hipcub-dev (= 2.13.1.50700-63~22.04), hipfft-dev (= 1.0.12.50700-63~22.04), hipsolver-dev (= 1.8.1.50700-63~22.04), hipsparse-dev (= 2.3.8.50700-63~22.04), miopen-hip-dev (= 2.20.0.50700-63~22.04), rccl-dev (= 2.17.1.50700-63~22.04), rocalution-dev (= 2.1.11.50700-63~22.04), rocblas-dev (= 3.1.0.50700-63~22.04), rocfft-dev (= 1.0.23.50700-63~22.04), rocprim-dev (= 2.13.1.50700-63~22.04), rocrand-dev (= 2.10.17.50700-63~22.04), rocsolver-dev (= 3.23.0.50700-63~22.04), rocsparse-dev (= 2.5.4.50700-63~22.04), rocthrust-dev (= 2.18.0.50700-63~22.04), rocwmma-dev (= 1.2.0.50700-63~22.04)
Homepage: https://github.com/RadeonOpenCompute/ROCm
Download-Size: 1012 B
APT-Manual-Installed: yes
APT-Sources: http://repo.radeon.com/rocm/apt/5.7 jammy/main amd64 Packages
Description: Radeon Open Compute (ROCm) Runtime software stack

确保 PyTorch 也能识别 GPU:

import torch
print(f"number of GPUs: {torch.cuda.device_count()}")
print([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])

number of GPUs: 1
['AMD Radeon Graphics']

让我们开始测试 TaPas。TaPas 有 3 种不同的变体,分别对应于不同的数据集,在这些数据集上,TaPas 进行了微调。在本博客中,我们将重点关注在 WTQ 数据集上进行微调的模型,它是一个弱监督的聚合任务。这个任务不是在对话环境下提问,而是针对给定表格的特定问题,可能涉及聚合操作。这称为弱监督,因为“模型必须在仅将答案作为监督的情况下,学习合适的聚合操作符(SUM/COUNT/AVERAGE/NONE)”。

关于其他变体的更多信息,请查看 该文档。

在开始之前,请确保已安装所有必要的库:

! pip install -q transformers pandas datasets tabulate

接下来导入您将在本博文中使用的模块:

from transformers import TapasTokenizer, TapasForQuestionAnswering
import pandas as pd

加载数据

我们将使用一个关于世界经济的简单数据集。

from datasets import load_dataset
data = load_dataset("ashraq/ott-qa-20k", split='train')

for doc in data:
    if doc['title'] == 'World economy':
        table = pd.DataFrame(doc["data"], columns=doc['header'])
        break 

print(table.to_markdown())
|    |   Rank | Country              | Value ( USD $ )   |   Peak year |
|---:|-------:|:---------------------|:------------------|------------:|
|  0 |      1 | Qatar                | 146,982           |        2012 |
|  1 |      2 | Macau                | 133,021           |        2013 |
|  2 |      3 | Luxembourg           | 108,951           |        2019 |
|  3 |      4 | Singapore            | 103,181           |        2019 |
|  4 |      5 | United Arab Emirates | 92,037            |        2004 |
|  5 |      6 | Brunei               | 83,785            |        2012 |
|  6 |      7 | Ireland              | 83,399            |        2019 |
|  7 |      8 | Norway               | 76,684            |        2019 |
|  8 |      9 | San Marino           | 74,664            |        2008 |
|  9 |     10 | Kuwait               | 71,036            |        2013 |
| 10 |     11 | Switzerland          | 66,196            |        2019 |
| 11 |     12 | United States        | 65,112            |        2019 |
| 12 |     13 | Hong Kong            | 64,928            |        2019 |
| 13 |     14 | Netherlands          | 58,341            |        2019 |
| 14 |     15 | Iceland              | 56,066            |        2019 |
| 15 |     16 | Saudi Arabia         | 55,730            |        2018 |
| 16 |     17 | Taiwan               | 55,078            |        2019 |
| 17 |     18 | Sweden               | 54,628            |        2019 |
| 18 |     19 | Denmark              | 53,882            |        2019 |
| 19 |     20 | Germany              | 53,567            |        2019 |

加载模型

让我们加载模型、其分词器和配置。

from transformers import TapasTokenizer, TapasForQuestionAnswering, TapasConfig
model_name = "google/tapas-base-finetuned-wtq"
model = TapasForQuestionAnswering.from_pretrained(model_name)
tokenizer = TapasTokenizer.from_pretrained(model_name)
config = TapasConfig.from_pretrained('google/tapas-base-finetuned-wtq')

print(model)

print("Aggregations: ", config.aggregation_labels)
TapasForQuestionAnswering(
  (tapas): TapasModel(
    (embeddings): TapasEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(1024, 768)
      (token_type_embeddings_0): Embedding(3, 768)
      (token_type_embeddings_1): Embedding(256, 768)
      (token_type_embeddings_2): Embedding(256, 768)
      (token_type_embeddings_3): Embedding(2, 768)
      (token_type_embeddings_4): Embedding(256, 768)
      (token_type_embeddings_5): Embedding(256, 768)
      (token_type_embeddings_6): Embedding(10, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): TapasEncoder(
      (layer): ModuleList(
        (0-11): 12 x TapasLayer(
          (attention): TapasAttention(
            (self): TapasSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): TapasSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): TapasIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): TapasOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): TapasPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (aggregation_classifier): Linear(in_features=768, out_features=4, bias=True)
)

Aggregation:  {0: 'NONE', 1: 'SUM', 2: 'AVERAGE', 3: 'COUNT'}

运行推断

我们已经准备好通过在数据上运行一些查询来测试模型了。首先,让我们创建一个函数,该函数将查询列表、DataFrame 作为输入,并输出答案。我们改编了来自 HuggingFace 教程 中的代码。关键部分是 convert_logits_to_predictions 方法,它将模型的输出(或 logits)转换为预测坐标,即关注表中的哪个单元格,以及聚合索引,即根据问题应执行的聚合操作。

def qa(queries, table):    
    inputs = tokenizer(table=table, queries=queries, padding=True, truncation=True, return_tensors="pt") 
    outputs = model(**inputs)
    predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
        inputs, outputs.logits.detach(), outputs.logits_aggregation.detach()
    )

    # 打印出结果:
    id2aggregation = config.aggregation_labels
    aggregation_predictions_string = [id2aggregation[x] for x in predicted_aggregation_indices]

    answers = []
    for coordinates in predicted_answer_coordinates:
        if len(coordinates) == 1:
            # 只有一个单元格:
            answers.append(table.iat[coordinates[0]])
        else:
            # 多个单元格 
            cell_values = []
            for coordinate in coordinates:
                cell_values.append(table.iat[coordinate])
            answers.append(", ".join(cell_values))

    print("")
    for query, answer, predicted_agg in zip(queries, answers, aggregation_predictions_string):
        print(query)
        if predicted_agg == "NONE":
            print("Predicted answer: " + answer)
        else:
            print("Predicted answer: " + predicted_agg + " > " + answer)
        print()

让我们试试这个模型:

queries = ["What is the value of Norway?",
           "What is the total value of all countries in 2013?",
           "What is the average value of all countries in 2019?",
           "How many countries are in the data in 2012?",
           "What is the combined value of Sweden and Denmark?"
          ]
qa(queries, table)

输出应如下所示:

What is the value of Norway?
Predicted answer: AVERAGE > 76,684

What is the total value of all countries in 2013?
Predicted answer: SUM > 133,021, 71,036

What is the average value of all countries in 2019?
Predicted answer: AVERAGE > 108,951, 83,399, 76,684, 66,196, 65,112, 64,928, 58,341, 56,066, 55,078, 54,628, 53,882, 53,567

How many countries are in the data in 2012?
Predicted answer: COUNT > Qatar, Brunei

What is the combined value of Sweden and Denmark?
Predicted answer: SUM > 54,628, 53,882

该模型能够根据问题准确选择数据中的相关单元格和聚合函数。我们鼓励读者探索 TaPas 的其他变种,并在自己的数据集上进行微调。


http://www.kler.cn/a/392758.html

相关文章:

  • 动态规划<八> 完全背包问题及其余背包问题
  • CMake配置区分Debug和Release模式
  • 目标检测入门指南:从原理到实践
  • gitlab 还原合并请求
  • 【Java回顾】Day2 正则表达式----异常处理
  • flink cdc oceanbase
  • 一文理解:结构化数据、非结构化数据、半结构化数据和元数据
  • 使用 start-local 脚本在本地运行 Elasticsearch
  • Pycharm打开终端时报错:Cannot open Local,Failed to start[powershell.exe]
  • 【论文复现】STM32设计的物联网智能鱼缸
  • 快速排序法
  • Macos mysql实现命令自动补全的方法
  • 7天用Go从零实现分布式缓存GeeCache(总结)
  • 目录树文件名映射深度1分组计数,tree(映射(目录A))
  • Mysql用户权限与账号管理
  • Conda环境、Ubuntu环境移植
  • Scala 的List
  • 【蓝桥等考C++真题】蓝桥杯等级考试C++组第13级L13真题原题(含答案)-成绩排序ABCDE
  • 3DTiles之使用customShader调整风格
  • 图像处理实验一(Matlab Exercises and Image Fundamentals)
  • Unity使用PS合并贴图
  • 「IDE」PyCharm 之 安装与卸载
  • Python 数据库操作教程
  • python购物计算 2024年6月青少年电子学会等级考试 中小学生python编程等级考试一级真题答案解析
  • 51c自动驾驶~合集21
  • python,dataclasses模块介绍及示例