当前位置: 首页 > article >正文

OCR经典神经网络(三)LayoutLM v2算法原理及其在发票数据集上的应用(NER及RE)

OCR经典神经网络(三)LayoutLM v2算法原理及其在发票数据集上的应用(NER及RE)

  • LayoutLM系列模型是微软发布的、文档理解多模态基础模型领域最重要和有代表性的工作:
    • LayoutLM v2:在一个单一的多模态框架中对文本(text)、布局(layout)和图像(image)之间的交互进行建模。
    • LayoutXLM:LayoutXLM是 LayoutLMv2的多语言扩展版本。
    • LayoutLM v3:借鉴了ViLT和BEIT,不需要经过预训练的视觉backbone,通过MLM、MIM和WPA进行预训练的多模态Transformer。在以视觉为中心的任务上(如文档图像分类和文档布局分析)和以文本为中心的任务上(表单理解、收据理解、文档问答)都表现很好。
  • 今天,我们来了解下LayoutLM v2模型。
    • 论文链接:https://arxiv.org/pdf/2012.14740
    • 同样,百度开源的paddleocr中,在关键信息抽取中集成了此算法。
    • paddleocr中集成的算法列表:https://github.com/PaddlePaddle/PaddleOCR/blob/main/docs/algorithm/overview.md

1 LayoutLM v2算法原理

  • LayoutLM v2是一种多模态Transformer模型,该模型在预训练阶段整合了文档文本、版式及视觉信息,实现了在一个框架内端到端地学习跨模态交互。同时,将一种空间感知的自注意力机制融入到了Transformer架构中。
  • 除了掩码视觉语言模型(MVLM)预训练策略外,LayoutLM v2还新增了文本-图像对齐(TIA)和文本-图像匹配(TIM) 作为预训练策略,以强化不同模态间的对齐。
  • LayoutLMv2不仅在传统的富视觉文档理解(VrDU)任务上取得了显著的性能提升并达到当时新的最优水平,还在文档图像的视觉问题回答(VQA)任务上实现了新突破,这证明了多模态预训练在富视觉文档理解领域的巨大潜力。

1.1 模型结构

  • 模型结构如下图所示,可以看到LayoutLM v2接收文本、视觉及版式信息作为输入,以建立深度的跨模态交互。另外,将spatial-aware的自注意力机制整合到了transformer中。

  • 这里,我们主要看下Embedding层:

    • 文本嵌入

      • 文本嵌入包含三种嵌入:词嵌入代表词本身,一维位置嵌入表示词的位置索引,而段落嵌入用于区分不同的文本段落。

      t i = T o k E m b ( w i ) + P o s E m b 1 D ( i ) + S e g E m b ( s i ) t_i= TokEmb(w_i)+PosEmb1D(i)+SegEmb(s_i) ti=TokEmb(wi)+PosEmb1D(i)+SegEmb(si)

      • 使用WordPiece对OCR文本序列进行分词,并将每个分词(token)分配给特定的段落。接着,在序列的开始添加[CLS]标记,在每个文本段落的末尾添加[SEP]标记。为了使最终序列的长度恰好等于最大序列长度L,在序列末尾额外添加[PAD]填充符。
    • 视觉嵌入

      • 给定一个文档页面图像I,将其调整大小至224×224像素后输入到视觉主干网络中。之后,输出的特征图通过平均池化到固定尺寸,宽度为W,高度为H。接下来,它被展平为长度为W×H(例如:7×7)的视觉嵌入序列,此序列被称为VisTokEmb(I)。然后对每个视觉token嵌入应用线性投影层,以统一其维度与文本嵌入的维度
      • 由于基于CNN的视觉主干无法捕获位置信息,因此添加一维位置嵌入。
      • 对于段落嵌入,将所有视觉令牌附属于视觉段[C]。

      v i = P r o j ( V i s T o k E m b ( I ) i + P o s E m b 1 D ( i ) + S e g E m b ( [ C ] ) v_i= Proj(VisTokEmb(I)_i+PosEmb1D(i)+SegEmb([C]) vi=Proj(VisTokEmb(I)i+PosEmb1D(i)+SegEmb([C])

    • 布局嵌入(2D Position Embeddings)

      • 将所有的坐标标准化并离散化为[0, 1000]范围内的整数,并使用两个嵌入层分别嵌入x轴特征和y轴特征
      • 给定第i个( 0 ≤ i < W × H + L 0 ≤ i < W×H + L 0i<W×H+L)文本/视觉token的标准化边界框 b o x i = ( x m i n , x m a x , y m i n , y m a x , w i d t h , h e i g h t ) box_i = (x_{min}, x_{max}, y_{min}, y_{max}, width, height) boxi=(xmin,xmax,ymin,ymax,width,height)布局嵌入层将这六个边界框特征连接起来构建一个token级的2D位置嵌入,即布局嵌入

在这里插入图片描述
在这里插入图片描述

  • 由于卷积神经网络(CNNs)执行局部变换,因此视觉token嵌入可以一一映射回图像区域,既没有重叠也没有遗漏。
    • 在计算边界框时,视觉token可以被视为均匀划分的网格。
    • 对于特殊token [CLS]、[SEP]和[PAD],会附加一个空边界框boxPAD = (0, 0, 0, 0, 0, 0)。这意味着这些特殊符号在空间布局上不占用实际区域,但通过这样的空边界框嵌入,模型能够将它们整合到序列中的相应位置上,同时保持空间信息的一致性

1.2 预训练目标及数据

1.2.1 MVLM

  • 采用了掩码视觉-语言建模(Masked Visual-Language Modeling, MVLM)方法,以便模型在跨模态线索的帮助下更好地学习语言方面。
    • 随机掩蔽一些文本token,并要求模型恢复这些被掩蔽的token。
    • 与此同时,布局信息保持不变,这意味着模型了解每个被掩蔽token在页面上的位置。
    • 为了避免视觉线索泄露,在将原始页面图像输入到视觉编码器之前,会先对应掩蔽掉与被掩蔽文本token相对应的图像区域。

1.2.2 TIA

  • Text-Image Alignment(TIA):随机遮盖图像,然后识别文本对应图像是否被遮盖了。
    • 为了帮助模型学习图像与边界框坐标的空間位置对应关系,提出了细粒度的跨模态对齐任务——文本-图像对齐(Text-Image Alignment, TIA)。
    • 在TIA任务中,随机选择一些文本行,并在其文档图像上的对应图像区域进行遮盖, 称此操作为“遮盖”,以避免与MVLM中的“掩码”操作混淆。
    • 预训练期间,在编码器输出之上构建了一个分类层。该层根据文本令牌是否被遮盖(即,[Covered]或[Not Covered])预测每个文本令牌的标签,并计算二元交叉熵损失
    • 考虑到输入图像的分辨率有限,且某些文档元素(如图表中的符号和线条)可能看起来像被遮盖的文本区域,寻找单词大小的遮盖图像区域的任务可能会存在噪声。因此,遮盖操作是在行级别进行的
    • 当MVLM和TIA同时执行时,MVLM中被掩蔽的令牌的TIA损失不予考虑。这防止了模型学习从[MASK]到[Covered]这种无用但直观的对应关系。

1.2.3 TIM

  • Text-Image Matching(TIM):使用[CLS]来判断给出的图片特征与文本特征是否属于同一个页面。
  • 为了帮助模型学习文档图像与文本内容之间的对应关系,采用了较为粗粒度的跨模态对齐任务,即文本-图像匹配(Text-Image Matching, TIM)。
  • 将[CLS]位置的输出表示送入一个分类器,以预测图像和文本是否来自同一文档页面。正常的配对输入被视为正样本
  • 为了构建负样本,图像要么被另一文档的页面图像替换,要么被移除。
  • 为防止模型通过寻找任务特定特征来作弊,对负面样本中的图像也执行相同的掩码和遮盖操作。在负面样本中,TIA的目标标签全部设置为[Covered]

1.2.4 预训练数据

  • 为了预训练和评估LayoutLMv2模型,作者从富含视觉元素的文档理解领域中选择了广泛的数据集。

  • 使用IIT-CDIP作为预训练数据集。

1.3 模型微调

  • 文档级别分类任务RVL-CDIP中,使用[CLS]输出以及池化的视觉令牌表示作为全局特征
  • 对于提取式问答任务DocVQA及其他四个实体提取任务,在LayoutLMv2输出的文本部分上构建特定任务的头部层。在DocVQA论文中,实验结果显示,在SQuAD数据集上微调过的BERT模型比原始BERT模型表现更优。受此启发,增加了一个额外的设置:首先在问题生成(Question Generation, QG)数据集上微调LayoutLMv2,随后再在DocVQA数据集上微调。这个QG数据集包含近百万对由训练于SQuAD数据集的生成模型产生的问题-答案对。

1.4 LayoutXLM模型结构

  • LayoutXLM是 LayoutLMv2的多语言扩展版本。为了准确评估LayoutXLM,论文中还引入了一个多语言表单理解基准数据集,名为XFUND,该数据集包含了7种语言(中文、日语、西班牙语、法语、意大利语、德语、葡萄牙语)的表单理解样本,并为每种语言的手工标注了键值对。
  • 论文链接:https://arxiv.org/pdf/2104.08836
  • LayoutXLM预训练策略,同LayoutLMv2
  • 该框架如下图所示:
    • 模型接收来自三种不同模态的信息,即文本、布局和图像,分别使用文本嵌入、布局嵌入和视觉嵌入层进行编码。文本和图像嵌入被连接在一起,然后加上布局嵌入以获得输入嵌入。
    • 输入嵌入通过带有空间感知自注意力机制的多模态Transformer进行编码。
    • 最后,输出的上下文表示可以用于后续的任务特定层。

在这里插入图片描述

1.5 VI-LayoutXLM

  • 百度在PP-StructureV2中,针对 LayoutXLM 进行改进,得到了VI-LayoutXLM。

  • 论文链接:https://arxiv.org/pdf/2210.05391

  • 模型部分改进如下:

    在这里插入图片描述

    • LayoutLMv2 以及 LayoutXLM 中引入视觉骨干网络,用于提取视觉特征,并与后续的 text embedding 进行联合,作为多模态的输入 embedding。但是该模块为基于 ResNet_x101_64x4d 的特征提取网络,特征抽取阶段耗时严重。
    • 因此,移除视觉特征提取模块,同时仍然保留文本、位置以及布局等信息,最终发现针对 LayoutXLM 进行改进,下游 SER 任务精度无损,针对 LayoutLMv2 进行改进,下游 SER 任务精度仅降低2.1%,而模型大小减小了约340M。

在这里插入图片描述

2 VI-LayoutXLM在发票数据集上的应用

  • 关键信息抽取 (Key Information Extraction, KIE)指的是是从文本或者图像中,抽取出关键的信息。

    • 针对文档图像的关键信息抽取任务作为OCR的下游任务,存在非常多的实际应用场景,如表单识别、车票信息抽取、身份证信息抽取等。
    • 文档图像中的KIE一般包含2个子任务,示意图如下图所示。
      • SER: 语义实体识别 (Semantic Entity Recognition),对每一个检测到的文本进行分类,如将其分为姓名,身份证。如下图中的黑色框和红色框。
      • RE: 关系抽取 (Relation Extraction),对每一个检测到的文本进行分类,如将其分为问题 (key) 和答案 (value) 。然后对每一个问题找到对应的答案,相当于完成key-value的匹配过程。如下图中的红色框和黑色框分别代表问题和答案,黄色线代表问题和答案之间的对应关系。

    在这里插入图片描述

  • 除了视觉特征无关的多模态预训练模型结构,paddleocr中在KIE任务上,还有两个主要的优化策略:

    • TB-YX:考虑阅读顺序的文本行排序逻辑
      • 文本阅读顺序对于信息抽取与文本理解等任务至关重要,传统多模态模型中,没有考虑不同 OCR 工具可能产生的不正确阅读顺序,而模型输入中包含位置编码,阅读顺序会直接影响预测结果
      • 在预处理中,对文本行按照从上到下,从左到右(YX)的顺序进行排序,为防止文本行位置轻微干扰带来的排序结果不稳定问题,在排序的过程中,引入位置偏移阈值 Th,对于 Y 方向距离小于 Th 的2个文本内容,使用 X 方向的位置从左到右进行排序。
    • UDML:联合互学习知识蒸馏策略
      • UDML(Unified-Deep Mutual Learning)联合互学习是 PP-OCRv2 与 PP-OCRv3 中采用的对于文本识别非常有效的提升模型效果的策略。
      • 在训练时,引入2个完全相同的模型进行互学习,计算2个模型之间的互蒸馏损失函数(DML loss),同时对 transformer 中间层的输出结果计算距离损失函数(L2 loss)。
      • 使用该策略,最终 XFUND 数据集上,SER 任务 F1 指标提升0.6%,RE 任务 F1 指标提升5.01%。

    在这里插入图片描述

  • KIE常用思路有如下两种:

    • 一种是SER:

      • 直接使用SER,获取关键信息的类别;常用于关键信息类别固定的场景。
      • 以身份证场景为例, 关键信息一般包含姓名性别民族等,我们直接将对应的字段标注为特定的类别即可,如下图所示:

      在这里插入图片描述

      • 注意:

        • 标注过程中,对于无关于KIE关键信息的文本内容,均需要将其标注为other类别,相当于背景信息。如在身份证场景中,如果我们不关注性别信息,那么可以将“性别”与“男”这2个字段的类别均标注为other
        • 标注过程中,需要以文本行为单位进行标注,无需标注单个字符的位置信息。

        数据量方面,一般来说,对于比较固定的场景,50张左右的训练图片即可达到可以接受的效果,可以使用PPOCRLabel完成KIE的标注过程。

    • 一种是SER+RE:

      • 联合使用SER+RE,先利用SER找到key和value,然后再利用RE进行匹配;常用于关系类别不固定的场景。
      • 以身份证场景为例, 关键信息一般包含姓名性别民族等关键信息。在SER阶段,我们需要识别所有的question (key) 与answer (value) 。每个字段的类别信息(label字段)可以是question、answer或者other(与待抽取的关键信息无关的字段)

      在这里插入图片描述

      • 在RE阶段,需要标注每个字段的的id与连接信息,如下图所示:
        • 标注过程中,如果value是多个字符,那么linking中可以新增一个key-value对,如[[0, 1], [0, 2]]
        • 数据量方面,一般来说,对于比较固定的场景,50张左右的训练图片即可达到可以接受的效果,可以使用PPOCRLabel完成KIE的标注过程。

      在这里插入图片描述

    • 我们参考案例:https://aistudio.baidu.com/projectdetail/4823162(项目里提供了发票数据集),来对VI-LayoutXLM模型有更深的认识。

2.1 语义实体识别 (SER)

2.1.1 模型构建

  • 我这里不用命令行执行,在paddleocr\tests目录下创建一个py文件执行训练过程

  • 我们复制一份paddleocr\configs\kie\vi_layoutxlm\ser_vi_layoutxlm_xfund_zh_udml.yml文件到paddleocr\tests\configs进行修改(参考上面项目链接进行修改),发票数据集在上面项目中已提供,模型部分的配置文件如下:

    Architecture:
      model_type: &model_type "kie"
      name: DistillationModel
      algorithm: Distillation
      Models:
        Teacher:
          pretrained:
          freeze_params: false
          return_all_feats: true
          model_type: *model_type
          algorithm: &algorithm "LayoutXLM"
          Transform:
          Backbone:
            name: LayoutXLMForSer
            pretrained: True             # 会利用paddle-nlp加载预训练模型
            # one of base or vi
            mode: vi
            checkpoints:
            num_classes: &num_classes 5  # 采用BIO的标注,训练需要修改
        Student:
          pretrained:
          freeze_params: false
          return_all_feats: true
          model_type: *model_type
          algorithm: *algorithm
          Transform:
          Backbone:
            name: LayoutXLMForSer
            pretrained: True
            # one of base or vi
            mode: vi
            checkpoints:
            num_classes: *num_classes
    
  • 通过下面的py文件,我们就可以愉快的查看源码了。

def train_kie_token_ser_demo():
    from tools.train import program, set_seed, main
    # 配置文件的源地址地址: paddleocr\configs\kie\vi_layoutxlm\ser_vi_layoutxlm_xfund_zh_udml.yml
    config, device, logger, vdl_writer = program.preprocess(is_train=True)

    ###############修改配置(也可在yml文件中修改)##################
    # 评估频率
    config["Global"]["eval_batch_step"] = [0, 200]
    # log的打印频率
    config["Global"]["print_batch_step"] = 50
    # 训练的epochs
    config["Global"]["epoch_num"] = 1
    # 随机种子
    seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024
    set_seed(seed)

    ###############模型训练##################
    main(config, device, logger, vdl_writer, seed)


def train_kie_token_re_demo():
    from tools.train import program, set_seed, main
    # 配置文件的源地址地址: paddleocr\configs\kie\vi_layoutxlm\re_vi_layoutxlm_xfund_zh_udml.yml
    config, device, logger, vdl_writer = program.preprocess(is_train=True)

    ###############修改配置(也可在yml文件中修改)##################
    # 评估频率
    config["Global"]["eval_batch_step"] = [0, 200]
    # log的打印频率
    config["Global"]["print_batch_step"] = 50
    # 训练的epochs
    config["Global"]["epoch_num"] = 1
    # 随机种子
    seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024
    set_seed(seed)

    ###############模型训练##################
    main(config, device, logger, vdl_writer, seed)

if __name__ == '__main__':
    train_kie_token_ser_demo()
    # train_kie_token_re_demo()

LayoutXLMForTokenClassification

  • 首先,利用LayoutXLMModel提取特征(文本、布局信息)
  • 然后,利用文本部分的特征进行BIO多分类
# paddleocr.ppocr.modeling.backbones.vqa_layoutlm.py
class LayoutXLMForTokenClassification(LayoutXLMPretrainedModel):
    def __init__(self, config: LayoutXLMConfig):
        super(LayoutXLMForTokenClassification, self).__init__(config)
        self.num_classes = config.num_labels
        self.layoutxlm = LayoutXLMModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_classes)

    ......

    def forward(
        self,
        input_ids=None,
        bbox=None,
        image=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
    ):
        # 1、经过12层的Transformer Block Encoder
        outputs = self.layoutxlm(
            input_ids=input_ids,
            bbox=bbox,
            image=image,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        seq_length = input_ids.shape[1]
        
        # sequence out and image out
        # 2、进行BIO多分类
        # sequence_output: (bs, 561, 768) -> (bs, 512, 768) -> (bs, 512, 5)
        sequence_output = outputs[0][:, :seq_length]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        hidden_states = {
            f"hidden_states_{idx}": outputs[2][f"{idx}_data"] for idx in range(self.layoutxlm.config.num_hidden_layers)
        }
        if self.training:
            outputs = (logits, hidden_states)
        else:
            outputs = (logits,)

        ......

        return outputs

LayoutXLMModel

这里我们主要看下LayoutXLMModel模型中,文本的embedding和视觉部分的embedding。

  • 文本的embedding:

    • word_embeddings:对tokenizer后的input_ids进行word_embeddings,shape变化:(bs, 512) -> (bs, 512, 768)

    • position_embeddings(1D position embedding):对文本部分的position_ids进行embeding,shape变化:(bs, 512) -> (bs, 512, 768)。这里,文本和视觉的position_embeddings是共享的。

    • spatial_position_embeddings:这里shape变化为(bs, 512, 4) -> (bs, 512, 768),是将每一个bbox信息的(x_min, y_min, x_max, y_max, h, w)编码,然后concat得到,代码如下所示。注意:如果一个bbox内的文字,被切分为多个token,那么这些token的bbox信息是一致的。

          # paddlenlp.transformers.layoutxlm.modeling.py
          def _cal_spatial_position_embeddings(self, bbox):
              try:
                  # (bs, embdedding_dim) -> (bs, embdedding_dim, 128)
                  left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
                  # (bs, embdedding_dim) -> (bs, embdedding_dim, 128)
                  upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
                  # (bs, embdedding_dim) -> (bs, embdedding_dim, 128)
                  right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
                  # (bs, embdedding_dim) -> (bs, embdedding_dim, 128)
                  lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
              except IndexError as e:
                  raise IndexError("The :obj:`bbox`coordinate values should be within 0-1000 range.") from e
              # (bs, embdedding_dim) -> (bs, embdedding_dim, 128)
              h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
              # (bs, embdedding_dim) -> (bs, embdedding_dim, 128)
              w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
      
              #  [x_min, y_min, x_max, y_max, h, w] concat -> (bs, embdedding_dim, 128*6)
              spatial_position_embeddings = paddle.concat(
                  [
                      left_position_embeddings,
                      upper_position_embeddings,
                      right_position_embeddings,
                      lower_position_embeddings,
                      h_position_embeddings,
                      w_position_embeddings,
                  ],
                  axis=-1,
              )
              return spatial_position_embeddings
      
    • token_type_embeddings:这里的token_type_ids全为0,shape变化为(bs, 512) -> (bs, 512, 768)

  • 视觉部分的embedding:

    • position_embeddings(1D position embedding):shape变化为(bs, 49) -> (bs, 49, 768)。视觉部分的position ids为:[0, 1, 2, …, 48] -> (bs, 49)。这里虽然去除了视觉提取,但是position ids按照图像224×224经过降采样32倍后的feature map:7×7进行生成。这里,文本和视觉的position_embeddings是共享的;
    • spatial_position_embeddings:视觉部分布局信息,即bbox的生成的核心逻辑是:7×7网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉token。shape变化为(bs, 49, 4) -> (bs, 49, 768)
    • visual_segment_embedding
  • 最终,将文本的embedding和视觉部分的embedding送入到12层的Transformer Encoder Block提取特征。

# paddlenlp.transformers.layoutxlm.modeling.py
@register_base_model
class LayoutXLMModel(LayoutXLMPretrainedModel):

    def __init__(self, config: LayoutXLMConfig):
        super(LayoutXLMModel, self).__init__(config)
        self.config = config
        self.use_visual_backbone = config.use_visual_backbone
        self.has_visual_segment_embedding = config.has_visual_segment_embedding
        self.embeddings = LayoutXLMEmbeddings(config)

        if self.use_visual_backbone is True:
            self.visual = VisualBackbone(config)
            self.visual.stop_gradient = True
            self.visual_proj = nn.Linear(config.image_feature_pool_shape[-1], config.hidden_size)

        if self.has_visual_segment_embedding:
            self.visual_segment_embedding = self.create_parameter(
                shape=[
                    config.hidden_size,
                ],
                dtype=paddle.float32,
            )
        self.visual_LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
        self.visual_dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = LayoutXLMEncoder(config)
        self.pooler = LayoutXLMPooler(config)



    def _calc_visual_bbox(self, image_feature_pool_shape, bbox, visual_shape):
        """
           视觉部分布局信息,即bbox的生成:
                 - image_feature_pool_shape:(7, 7, 256)
                 - 文字token的bbox信息:(bs, 512, 4)
                 - visual_shape:[bs, 49]
        """
        # 首先,生成一个序列[0, 1000, 2000, 3000, 4000, 5000, 6000, 7000]
        # 然后,离散化为[0, 1000],即[0, 142, 285, 428, 571, 714, 857, 1000]
        visual_bbox_x = (
            paddle.arange(
                0,
                1000 * (image_feature_pool_shape[1] + 1),
                1000,
                dtype=bbox.dtype,
            )
            // image_feature_pool_shape[1]
        )
        visual_bbox_y = (
            paddle.arange(
                0,
                1000 * (image_feature_pool_shape[0] + 1),
                1000,
                dtype=bbox.dtype,
            )
            // image_feature_pool_shape[0]
        )

        expand_shape = image_feature_pool_shape[0:2] # (7, 7)
        # 7×7网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉token
        # visual_bbox shape = (7×7, 4)
        visual_bbox = paddle.stack(
            [
                visual_bbox_x[:-1].expand(expand_shape),
                visual_bbox_y[:-1].expand(expand_shape[::-1]).transpose([1, 0]),
                visual_bbox_x[1:].expand(expand_shape),
                visual_bbox_y[1:].expand(expand_shape[::-1]).transpose([1, 0]),
            ],
            axis=-1,
        ).reshape([expand_shape[0] * expand_shape[1], paddle.shape(bbox)[-1]])
        # 扩展到bs个样本, (7×7, 4) -> (bs, 7×7, 4)
        visual_bbox = visual_bbox.expand([visual_shape[0], visual_bbox.shape[0], visual_bbox.shape[1]])
        return visual_bbox

    def _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids):
        """
          文本部分进行embeddings:
                  word_embeddings
                + position_embeddings(文本和视觉的position_embeddings是共享的)
                + spatial_position_embeddings
                + token_type_embeddings
        """
        # (bs, 512) -> (bs, 512, 768)
        words_embeddings = self.embeddings.word_embeddings(input_ids)
        # (bs, 512) -> (bs, 512, 768)
        position_embeddings = self.embeddings.position_embeddings(position_ids)
        # (bs, 512, 4) -> (bs, 512, 768)
        spatial_position_embeddings = self.embeddings._cal_spatial_position_embeddings(bbox)
        # (bs, 512) -> (bs, 512, 768)
        token_type_embeddings = self.embeddings.token_type_embeddings(token_type_ids)
        # 4种embedding相加
        embeddings = words_embeddings + position_embeddings + spatial_position_embeddings + token_type_embeddings
        # LayerNorm + dropout
        embeddings = self.embeddings.LayerNorm(embeddings)
        embeddings = self.embeddings.dropout(embeddings)
        return embeddings


    def _calc_img_embeddings(self, image, bbox, position_ids):
        """
            视觉部分进行embedding:
                    position_embeddings(文本和视觉的position_embeddings是共享的)
                +   spatial_position_embeddings
                +   visual_segment_embedding
        """
        use_image_info = self.use_visual_backbone and image is not None
        # (bs, 49) -> (bs, 49, 768)
        position_embeddings = self.embeddings.position_embeddings(position_ids)
        # (bs, 49, 4) -> (bs, 49, 768)
        spatial_position_embeddings = self.embeddings._cal_spatial_position_embeddings(bbox)
        if use_image_info is True:
            visual_embeddings = self.visual_proj(self.visual(image.astype(paddle.float32)))
            embeddings = visual_embeddings + position_embeddings + spatial_position_embeddings
        else:
            # embedding相加
            embeddings = position_embeddings + spatial_position_embeddings

        if self.has_visual_segment_embedding:
            # self.visual_segment_embedding shape = (768)
            embeddings += self.visual_segment_embedding

        #  visual_LayerNorm + visual_dropout
        embeddings = self.visual_LayerNorm(embeddings)
        embeddings = self.visual_dropout(embeddings)
        return embeddings

    
    def forward(
        self,
        input_ids=None,
        bbox=None,
        image=None,
        token_type_ids=None,
        position_ids=None,
        attention_mask=None,
        head_mask=None,
        output_hidden_states=False,
        output_attentions=False,
    ):
        input_shape = paddle.shape(input_ids)
        visual_shape = list(input_shape)
        visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]

        # 视觉部分的bbox的生成
        # 视觉token被视为均匀划分的网格
        # 生成的bbox信息:feature_map(7×7)网格中,每一个小的正方形的坐标(x_min, y_min, x_max, y_max)即为一个视觉token
        visual_bbox = self._calc_visual_bbox(self.config.image_feature_pool_shape, bbox, visual_shape)

        # 1、2D position embedding(文本部分bbox+视觉部分bbox)
        # (bs, 512, 4) + (bs, 49, 4) -> (bs, 561, 4)
        final_bbox = paddle.concat([bbox, visual_bbox], axis=1)
        if attention_mask is None:
            attention_mask = paddle.ones(input_shape)

        if self.use_visual_backbone is True:
            # 使用视觉部分的backbone
            visual_attention_mask = paddle.ones(visual_shape)
        else:
            # 移除视觉特征提取模块,mask全设置为0
            visual_attention_mask = paddle.zeros(visual_shape)

        attention_mask = attention_mask.astype(visual_attention_mask.dtype)
        # concat后attention_mask:(bs, 512) + (bs, 49) -> (bs, 561)
        final_attention_mask = paddle.concat([attention_mask, visual_attention_mask], axis=1)

        if token_type_ids is None:
            token_type_ids = paddle.zeros(input_shape, dtype=paddle.int64)


        # 2、1D position embedding(文本部分+视觉部分) (bs, 512) + (bs, 49) -> (bs, 561)
        if position_ids is None:
            # 文本部分的position embedding
            seq_length = input_shape[1]
            position_ids = self.embeddings.position_ids[:, :seq_length]
            position_ids = position_ids.expand(input_shape)

        # 视觉部分的position embedding
        # [0, 1, 2, ..., 48] -> (bs, 49)
        visual_position_ids = paddle.arange(0, visual_shape[1]).expand([input_shape[0], visual_shape[1]])
        final_position_ids = paddle.concat([position_ids, visual_position_ids], axis=1)

        if bbox is None:
            bbox = paddle.zeros(input_shape + [4])

        # 3、 text embedding & visual  (bs, 512, 768) + (bs, 49, 768) -> (bs, 561, 768)
        # 文本部分进行embdedding (bs, 512) -> (bs, 512, 768)
        text_layout_emb = self._calc_text_embeddings(
            input_ids=input_ids,
            bbox=bbox,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )
        # 视觉部分进行embedding(注意此时没有image,仅有视觉的bbox以及position_ids)
        visual_emb = self._calc_img_embeddings(
            image=image,
            bbox=visual_bbox,
            position_ids=visual_position_ids,
        )
        final_emb = paddle.concat([text_layout_emb, visual_emb], axis=1)
        # (bs, 561) -> (bs, 1, 1, 561)
        extended_attention_mask = final_attention_mask.unsqueeze(1).unsqueeze(2)

        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        else:
            head_mask = [None] * self.config.num_hidden_layers

        # 经过Transformer Encoder Block(12层)
        encoder_outputs = self.encoder(
            final_emb,                        # 文本&视觉部分的embedding , shape=(bs, 561, 768)
            extended_attention_mask,          # attention_mask        , shape=(bs, 1, 1, 561)
            bbox=final_bbox,                  # 2D position embedding【如果需要相对位置位置编码,加在attention_score上,这里为False】, shape=(bs, 561, 4)
            position_ids=final_position_ids,  # 1D position embedding【如果需要相对位置位置编码,加在attention_score上,这里为False】, shape=(bs, 561)
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # sequence_output shape = (bs, 561, 768)
        sequence_output = encoder_outputs[0]
        # pooled_output shape = (bs, 768)
        pooled_output = self.pooler(sequence_output)
        return sequence_output, pooled_output, encoder_outputs[1]

2.1.2 损失计算

  • 由于使用了UDML:联合互学习知识蒸馏策略,损失计算的配置如下:

在这里插入图片描述

Loss:
  name: CombinedLoss                     # ppocr.losses.combined_loss.CombinedLoss
  loss_config_list:
  - DistillationVQASerTokenLayoutLMLoss: # GT loss   ppocr.losses.distillation_loss.DistillationVQASerTokenLayoutLMLoss
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: backbone_out
      num_classes: *num_classes
  - DistillationSERDMLLoss:              # DML loss  ppocr.losses.distillation_loss.DistillationSERDMLLoss
      weight: 1.0
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: backbone_out
  - DistillationVQADistanceLoss:         # S5 loss  ppocr.losses.distillation_loss.DistillationVQADistanceLoss
      weight: 0.5
      mode: "l2"
      model_name_pairs:
        - ["Student", "Teacher"]
      key: hidden_states_5
      name: "loss_5"
  - DistillationVQADistanceLoss:         # S8 loss  ppocr.losses.distillation_loss.DistillationVQADistanceLoss
      weight: 0.5
      mode: "l2"
      model_name_pairs:
        - ["Student", "Teacher"]
      key: hidden_states_8
      name: "loss_8"
  • 如下所示,在DistillationModel中,Teacher和Student模型分别进行前向过程
# paddleocr.ppocr.modeling.architectures.distillation_model.py
class DistillationModel(nn.Layer):
    def __init__(self, config):
        """
        the module for OCR distillation.
        args:
            config (dict): the super parameters for module.
        """
        super().__init__()
        self.model_list = []
        self.model_name_list = []
        for key in config["Models"]:
            model_config = config["Models"][key]
            freeze_params = False
            pretrained = None
            if "freeze_params" in model_config:
                freeze_params = model_config.pop("freeze_params")
            if "pretrained" in model_config:
                pretrained = model_config.pop("pretrained")
            model = BaseModel(model_config)
            if pretrained is not None:
                load_pretrained_params(model, pretrained)
            if freeze_params:
                for param in model.parameters():
                    param.trainable = False
            self.model_list.append(self.add_sublayer(key, model))
            self.model_name_list.append(key)

    def forward(self, x, data=None):
        result_dict = dict()
        # 执行所有模型的前向过程, 例如:Teacher和Student模型
        for idx, model_name in enumerate(self.model_name_list):
            result_dict[model_name] = self.model_list[idx](x, data)
        return result_dict
  • 在CombinedLoss中遍历配置的损失函数,分别计算损失,最后相加最为总损失
# paddleocr.ppocr.losses.combined_loss.py
class CombinedLoss(nn.Layer):
    """
    CombinedLoss:
        a combionation of loss function
    """

    def __init__(self, loss_config_list=None):
        super().__init__()
        self.loss_func = []
        self.loss_weight = []
        assert isinstance(loss_config_list, list), "operator config should be a list"
        ......

    def forward(self, input, batch, **kargs):
        # input包含Teacher模型以及Student模型的输出结果
        # batch是批次数据,里面包含label
        loss_dict = {}
        loss_all = 0.0
        # 遍历配置的所有的损失函数,计算损失
        for idx, loss_func in enumerate(self.loss_func):
            loss = loss_func(input, batch, **kargs)
            if isinstance(loss, paddle.Tensor):
                loss = {"loss_{}_{}".format(str(loss), idx): loss}

            weight = self.loss_weight[idx]

            loss = {key: loss[key] * weight for key in loss}

            if "loss" in loss:
                loss_all += loss["loss"]
            else:
                loss_all += paddle.add_n(list(loss.values()))
            loss_dict.update(loss)
        loss_dict["loss"] = loss_all
        return loss_dict
  • 我们看下具体配置的损失函数:

    • DistillationVQASerTokenLayoutLMLoss的实质就是每个模型分别计算NER任务的CrossEntropyLoss,即GT loss:

      class DistillationVQASerTokenLayoutLMLoss(VQASerTokenLayoutLMLoss):
          def __init__(self, num_classes, model_name_list=[], key=None, name="loss_ser"):
              super().__init__(num_classes=num_classes)
              self.model_name_list = model_name_list
              self.key = key
              self.name = name
      
          def forward(self, predicts, batch):
              loss_dict = dict()
              # 遍历Teacher模型、Student模型
              for idx, model_name in enumerate(self.model_name_list):
                  # 先从predicts取出相关模型的预测结果
                  out = predicts[model_name]
                  # 然后,从out中取出key(即配置文件中配置的backbone_out)的值
                  if self.key is not None:
                      out = out[self.key]
                  # 调用父类,计算损失
                  loss = super().forward(out, batch)
                  loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
              return loss_dict
      
      # DistillationVQASerTokenLayoutLMLoss的父类
      class VQASerTokenLayoutLMLoss(nn.Layer):
          def __init__(self, num_classes, key=None):
              super().__init__()
              self.loss_class = nn.CrossEntropyLoss()
              self.num_classes = num_classes
              self.ignore_index = self.loss_class.ignore_index
              self.key = key
      
          def forward(self, predicts, batch):
              if isinstance(predicts, dict) and self.key is not None:
                  predicts = predicts[self.key]
              labels = batch[5]           # (bs, 512)
              attention_mask = batch[2]   # (bs, 512)
              if attention_mask is not None:
                  active_loss = (
                      attention_mask.reshape(
                          [
                              -1,
                          ]
                      )
                      == 1
                  )
                  # active_output_shape = (bs, 512, 5) -> (bs*512, 5)
                  active_output = predicts.reshape([-1, self.num_classes])[active_loss]
                  # active_label_shape = bs*512
                  active_label = labels.reshape(
                      [
                          -1,
                      ]
                  )[active_loss]
                  # 交叉熵损失函数
                  loss = self.loss_class(active_output, active_label)
              else:
                  loss = self.loss_class(
                      predicts.reshape([-1, self.num_classes]),
                      labels.reshape(
                          [
                              -1,
                          ]
                      ),
                  )
              return {"loss": loss}
      
    • DistillationSERDMLLoss实质是计算Techaer和Student模型之间的互蒸馏损失函数,即KL散度。

      class DistillationSERDMLLoss(DMLLoss):
          """ """
      
          def __init__(
              self,
              act="softmax",
              use_log=True,
              num_classes=7,
              model_name_pairs=[],
              key=None,
              name="loss_dml_ser",
          ):
              super().__init__(act=act, use_log=use_log)
              assert isinstance(model_name_pairs, list)
              self.key = key
              self.name = name
              self.num_classes = num_classes
              self.model_name_pairs = model_name_pairs
      
          def forward(self, predicts, batch):
              loss_dict = dict()
              # 遍历Teacher模型、Student模型
              for idx, pair in enumerate(self.model_name_pairs):
                  # 取出Teacher模型以及Student模型中的结果
                  out1 = predicts[pair[0]]
                  out2 = predicts[pair[1]]
                  if self.key is not None:
                      # 取出backbone_out
                      out1 = out1[self.key]
                      out2 = out2[self.key]
                  out1 = out1.reshape([-1, out1.shape[-1]])
                  out2 = out2.reshape([-1, out2.shape[-1]])
      
                  attention_mask = batch[2]
                  if attention_mask is not None:
                      active_output = (
                          attention_mask.reshape(
                              [
                                  -1,
                              ]
                          )
                          == 1
                      )
                      out1 = out1[active_output]
                      out2 = out2[active_output]
                  # 调用父类的方法
                  loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1, out2)
      
              return loss_dict
      
      # DistillationSERDMLLoss的父类   
      class DMLLoss(nn.Layer):
          """
          DMLLoss
          """
      
          def __init__(self, act=None, use_log=False):
              super().__init__()
              if act is not None:
                  assert act in ["softmax", "sigmoid"]
              if act == "softmax":
                  self.act = nn.Softmax(axis=-1)
              elif act == "sigmoid":
                  self.act = nn.Sigmoid()
              else:
                  self.act = None
      
              self.use_log = use_log
              self.jskl_loss = KLJSLoss(mode="kl")
      
          def _kldiv(self, x, target):
              """
                  计算两个概率分布之间的KL散度:
                      KL散度的公式是 KL(P||Q) = ΣP(x) * log(P(x)/Q(x)),这里将其重写为ΣP(x)*(log(P(x)) - log(Q(x)))
                      即target * (paddle.log(target + eps) - x)
              """
              eps = 1.0e-10
              loss = target * (paddle.log(target + eps) - x)
              # batch mean loss
              loss = paddle.sum(loss) / loss.shape[0]
              return loss
      
          def forward(self, out1, out2):
              if self.act is not None:
                  out1 = self.act(out1) + 1e-10
                  out2 = self.act(out2) + 1e-10
              if self.use_log:
                  # 计算KL散度
                  # for recognition distillation, log is needed for feature map
                  log_out1 = paddle.log(out1)
                  log_out2 = paddle.log(out2)
                  loss = (self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
              else:
                  # for detection distillation log is not needed
                  loss = self.jskl_loss(out1, out2)
              return loss    
      
    • DistillationVQADistanceLoss,本质是对 transformer 中间层的输出结果计算距离损失函数(L2 loss)

      # DistillationVQADistanceLoss的父类
      class DistanceLoss(nn.Layer):
          """
          DistanceLoss:
              mode: loss mode
          """
      
          def __init__(self, mode="l2", **kargs):
              super().__init__()
              assert mode in ["l1", "l2", "smooth_l1"]
              if mode == "l1":
                  self.loss_func = nn.L1Loss(**kargs)
              elif mode == "l2":
                  self.loss_func = nn.MSELoss(**kargs)
              elif mode == "smooth_l1":
                  self.loss_func = nn.SmoothL1Loss(**kargs)
      
          def forward(self, x, y):
              return self.loss_func(x, y)
      

      其他部分,诸如数据集的加载、构建优化器、创建评估函数、加载预训练模型、模型训练等,大家可以查看源码,不再赘述。

2.2 关系抽取(RE)

  • 我们这里,看下模型的构建部分代码,其他代码,大家可以查看源码,不再赘述。
# paddlenlp.transformers.layoutxlm.modeling.py
class LayoutXLMForRelationExtraction(LayoutXLMPretrainedModel):
    def __init__(self, config: LayoutXLMConfig):
        super(LayoutXLMForRelationExtraction, self).__init__(config)

        self.layoutxlm = LayoutXLMModel(config)

        self.extractor = REDecoder(config.hidden_size, config.hidden_dropout_prob)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    
    ......
    def forward(
        self,
        input_ids,
        bbox,
        image=None,
        attention_mask=None,
        entities=None,
        relations=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
    ):
        # 1、经过12层的Transformer Block Encoder
        outputs = self.layoutxlm(
            input_ids=input_ids,            # (bs, 512)
            bbox=bbox,                      # (bs, 512, 4)
            image=image,                    # None
            attention_mask=attention_mask,  # (bs, 512)
            token_type_ids=token_type_ids,  # (bs. 512)
            position_ids=position_ids,      # None
            head_mask=head_mask,            # None
        )
        seq_length = input_ids.shape[1]
        # 最后一层输出
        # sequence_output_shape = (bs, 512, 768)
        sequence_output = outputs[0][:, :seq_length]
        sequence_output = self.dropout(sequence_output)

        # 2、计算loss和预测关系
        loss, pred_relations = self.extractor(sequence_output, entities, relations)

        hidden_states = [outputs[2][f"{idx}_data"] for idx in range(self.layoutxlm.config.num_hidden_layers)]
        hidden_states = paddle.stack(hidden_states, axis=1)
        # 3、返回结果
        res = dict(loss=loss, pred_relations=pred_relations, hidden_states=hidden_states)
        return res
  • 主要代码在REDecoder中

    • 首先,构建构建关系对的正负样本
    • 然后,获取关系头(question)、关系尾(answer)对应的特征信息
      • 获取关系头(即question)在input_ids中开始的索引对应token的hidden_states(shape=(100, 768))和关系头(question)经过Embedding后的结果(shape=(100, 768))进行concat
      • 获取关系尾(即answer)在input_ids中开始的索引对应token的hidden_states(shape=(100, 768))和关系尾(answer)经过Embedding后的结果(shape=(100, 768))进行concat
    • 利用提取到的head_repr、tail_repr特征信息进行关系分类
    • 最后,利用预测结果,计算交叉熵损失等
    • 下面,给出一个relations和entities示例,方便理解。

    在这里插入图片描述

class REDecoder(nn.Layer):
    def __init__(self, hidden_size=768, hidden_dropout_prob=0.1):
        super(REDecoder, self).__init__()
        self.entity_emb = nn.Embedding(3, hidden_size)
        # 100代表:100个关系对
        # (100, 1536) -> (100, 768) -> (100, 384)
        projection = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.ReLU(),
            nn.Dropout(hidden_dropout_prob),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(hidden_dropout_prob),
        )
        self.ffnn_head = copy.deepcopy(projection)
        self.ffnn_tail = copy.deepcopy(projection)
        # (100, 384) -> (100, 2)
        self.rel_classifier = BiaffineAttention(hidden_size // 2, 2)
        self.loss_fct = CrossEntropyLoss()

    def build_relation(self, relations, entities):
        """
            relations_shape = (bs, 262145, 2)
            entities_shape  = (bs, 513, 3)
            注:
                relations第1个数组代表实际长度,例如:[10, 10],代表:关系对(QUESTION-ANSWER)只有10个,其他为填充
                entities第1个数组代表实际长度,例如:[20, 20, 20],代表:实例(QUESTION或ANSWER)只有20个,其他为填充
        """
        batch_size, max_seq_len = paddle.shape(entities)[:2]
        # new_relations_shape = (bs, 513*513, 3), 初始化为-1
        new_relations = paddle.full(
            shape=[batch_size, max_seq_len * max_seq_len, 3], fill_value=-1, dtype=relations.dtype
        )
        for b in range(batch_size):
            if entities[b, 0, 0] <= 2:
                entitie_new = paddle.full(shape=[512, 3], fill_value=-1, dtype=entities.dtype)
                entitie_new[0, :] = 2
                entitie_new[1:3, 0] = 0  # start
                entitie_new[1:3, 1] = 1  # end
                entitie_new[1:3, 2] = 0  # label
                entities[b] = entitie_new
            # 实体label_shape为: [2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 2, 2, 2, 1]
            # all_possible_relations1为: [1 , 2 , 4 , 6 , 8 , 10, 12, 13, 14, 19]  QUESTION
            # all_possible_relations2为: [0 , 3 , 5 , 7 , 9 , 11, 15, 16, 17, 18]  ANSWER
            entitie_label = entities[b, 1 : entities[b, 0, 2] + 1, 2]
            all_possible_relations1 = paddle.arange(0, entities[b, 0, 2], dtype=entities.dtype)
            all_possible_relations1 = all_possible_relations1[entitie_label == 1]
            all_possible_relations2 = paddle.arange(0, entities[b, 0, 2], dtype=entities.dtype)
            all_possible_relations2 = all_possible_relations2[entitie_label == 2]

            # 所有可能的关系:all_possible_relations_shape:(100, 2)
            # [
            #   [1, 0],  [1, 3], ... , [1, 18],
            #   [2, 0],  [2, 3], ... , [2, 18],
            #           ......
            #   [19, 0], [19, 3], ... , [19, 18]
            # ]
            all_possible_relations = paddle.stack(
                paddle.meshgrid(all_possible_relations1, all_possible_relations2), axis=2
            ).reshape([-1, 2])
            if len(all_possible_relations) == 0:
                all_possible_relations = paddle.full_like(all_possible_relations, fill_value=-1, dtype=entities.dtype)
                all_possible_relations[0, 0] = 0
                all_possible_relations[0, 1] = 1
            # relation_head: [1 , 2 , 4 , 6 , 8 , 10, 12, 13, 14, 19]
            # relation_tail: [0 , 3 , 5 , 7 , 9 , 11, 17, 15, 16, 18]
            relation_head = relations[b, 1 : relations[b, 0, 0] + 1, 0]
            relation_tail = relations[b, 1 : relations[b, 0, 1] + 1, 1]
            # positive_relations_shape: (10, 2)
            positive_relations = paddle.stack([relation_head, relation_tail], axis=1)
            # (100, 2) -> (100, 10, 2)
            all_possible_relations_repeat = all_possible_relations.unsqueeze(axis=1).tile(
                [1, len(positive_relations), 1]
            )
            # (100, 2) -> (100, 10, 2)
            positive_relations_repeat = positive_relations.unsqueeze(axis=0).tile([len(all_possible_relations), 1, 1])
            # mask shape = (100, 10)
            mask = paddle.all(all_possible_relations_repeat == positive_relations_repeat, axis=2)
            # 获取关系对负样本
            # negative_mask = paddle.any(mask, axis=1) is False
            negative_mask = ~paddle.any(mask, axis=1)
            negative_relations = all_possible_relations[negative_mask]

            # 获取关系对正样本
            # positive_mask = paddle.any(mask, axis=0) is True
            positive_mask = paddle.any(mask, axis=0)
            positive_relations = positive_relations[positive_mask]
            if negative_mask.sum() > 0:
                # positive_relations_shape = (10, 2)
                # negative_relations_shape = (90, 2)
                # reordered_relations_shape = (100, 2)
                reordered_relations = paddle.concat([positive_relations, negative_relations])
            else:
                reordered_relations = positive_relations

            relation_per_doc_label = paddle.zeros([len(reordered_relations), 1], dtype=reordered_relations.dtype)
            relation_per_doc_label[: len(positive_relations)] = 1
            # relation_per_doc shape: (100, 3)
            """
            relation_per_doc = 
            [[1 , 0 , 1 ],# 正样本
             [2 , 3 , 1 ],
             [4 , 5 , 1 ],
             ......
             [19, 18, 1 ],
             [1 , 3 , 0 ],# 负样本
             [1 , 5 , 0 ],
             ......
                        ]
            """
            relation_per_doc = paddle.concat([reordered_relations, relation_per_doc_label], axis=1)
            assert len(relation_per_doc[:, 0]) != 0
            # 第1个元素记录正负样本的长度信息,例如:[100, 100, 100]
            new_relations[b, 0] = paddle.shape(relation_per_doc)[0].astype(new_relations.dtype)
            # 将正负样本放到new_relations中
            new_relations[b, 1 : len(relation_per_doc) + 1] = relation_per_doc
            # new_relations.append(relation_per_doc)
        return new_relations, entities

    def get_predicted_relations(self, logits, relations, entities):
        """
            logits: 预测得到的关系概率, 例如:shape = (100, 2)
            relations: shape = (100, 3)
            entities:  shape = (513, 3)
        """
        pred_relations = []
        for i, pred_label in enumerate(logits.argmax(-1)):
            if pred_label != 1:
                continue
            rel = paddle.full(shape=[7, 2], fill_value=-1, dtype=relations.dtype)
            rel[0, 0] = relations[:, 0][i]
            rel[1, 0] = entities[:, 0][relations[:, 0][i] + 1]
            rel[1, 1] = entities[:, 1][relations[:, 0][i] + 1]
            rel[2, 0] = entities[:, 2][relations[:, 0][i] + 1]
            rel[3, 0] = relations[:, 1][i]
            rel[4, 0] = entities[:, 0][relations[:, 1][i] + 1]
            rel[4, 1] = entities[:, 1][relations[:, 1][i] + 1]
            rel[5, 0] = entities[:, 2][relations[:, 1][i] + 1]
            rel[6, 0] = 1
            pred_relations.append(rel)
        return pred_relations

    def forward(self, hidden_states, entities, relations):
        """
            hidden_states_shape:(bs, 512, 768)
            entities_shape: (bs, 513, 3)    , 其中:513 = 512 + 1,第一个元素记录长度信息
            relations_shape: (bs, 262145, 2),其中:262145 = 512*512 + 1,第一个元素记录长度信息
        """
        batch_size, max_length, _ = paddle.shape(entities)
        # 1、构建关系的正负样本
        # relations_shape: (bs, 263169, 3) , 其中: 263169 = 513 * 513
        # entities_shape: (bs, 513, 3)
        relations, entities = self.build_relation(relations, entities)
        loss = 0
        # 所有预测关系结果
        all_pred_relations = paddle.full(
            shape=[batch_size, max_length * max_length, 7, 2], fill_value=-1, dtype=entities.dtype
        )
        for b in range(batch_size):
            # 2、获取关系头(question)、关系尾(answer)对应的特征信息
            # 取出正负样本关系对, relation_shape = (100, 3)
            relation = relations[b, 1 : relations[b, 0, 0] + 1]
            # 获取关系头(question)、关系尾(answer)、以及关系标签(1表示question和answer是一对,即正样本, 0表示负样本)
            head_entities = relation[:, 0]
            tail_entities = relation[:, 1]
            relation_labels = relation[:, 2]
            # 每一个实体(question或answer)在input_ids中开始的索引
            # 例如:  [0  , 3  , 4  , 8  , 14 , 16 , 23 , 29 , 34 , 37 , 60 , 65 , 82 , 84 ,
            #         87 , 90 , 91 , 96 , 102, 106]
            entities_start_index = paddle.to_tensor(entities[b, 1 : entities[b, 0, 0] + 1, 0])
            # 获取每个实体类型编号,1表示question,2表示answer
            # 例如:[2, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1, 1, 2, 2, 2, 2, 1]
            entities_labels = paddle.to_tensor(entities[b, 1 : entities[b, 0, 2] + 1, 2])
            # 获取关系头(即question)在input_ids中开始的索引,为了后面获取对应token的hidden_states
            head_index = entities_start_index[head_entities]
            # 获取关系头(question)对应的实体类型编号
            head_label = entities_labels[head_entities]
            # 关系头(question)经过Embedding, head_label_repr_shape = (100, 768)
            head_label_repr = self.entity_emb(head_label)

            # 获取关系尾(即answer)在input_ids中开始的索引,为了后面获取对应token的hidden_states
            tail_index = entities_start_index[tail_entities]
            # 获取关系尾(answer)对应的实体类型编号
            tail_label = entities_labels[tail_entities]
            # 关系尾(answer)经过Embedding, tail_label_repr_shape = (100, 768)
            tail_label_repr = self.entity_emb(tail_label)

            # 获取关系头(question)开始token的hidden_states, tmp_hidden_states shape: (100, 768)
            tmp_hidden_states = hidden_states[b][head_index]
            if len(tmp_hidden_states.shape) == 1:
                tmp_hidden_states = paddle.unsqueeze(tmp_hidden_states, axis=0)
            #  concat, head_repr_shape = (100, 1536)
            head_repr = paddle.concat((tmp_hidden_states, head_label_repr), axis=-1)

            # 获取关系尾(answer)开始token的hidden_states, tmp_hidden_states shape: (100, 768)
            tmp_hidden_states = hidden_states[b][tail_index]
            if len(tmp_hidden_states.shape) == 1:
                tmp_hidden_states = paddle.unsqueeze(tmp_hidden_states, axis=0)
            #  concat, tail_repr_shape = (100, 1536)
            tail_repr = paddle.concat((tmp_hidden_states, tail_label_repr), axis=-1)

            # 3、利用提取到的head_repr、tail_repr进行关系分类
            # heads_shape = (100, 1536) -> (100, 384)
            # tails_shape = (100, 1536) -> (100, 384)
            heads = self.ffnn_head(head_repr)
            tails = self.ffnn_tail(tail_repr)
            # 结合双线性层和线性层,实现对两个输入向量的复杂交互建模
            # logits_shape = (100, 2)
            logits = self.rel_classifier(heads, tails)

            # 4、计算交叉熵损失
            loss += self.loss_fct(logits, relation_labels)
            pred_relations = self.get_predicted_relations(logits, relation, entities[b])
            if len(pred_relations) > 0:
                pred_relations = paddle.stack(pred_relations)
                all_pred_relations[b, 0, :, :] = paddle.shape(pred_relations)[0].astype(all_pred_relations.dtype)
                all_pred_relations[b, 1 : len(pred_relations) + 1, :, :] = pred_relations
        return loss, all_pred_relations
  • 关于模型的预测代码(使用OCR结果进行预测等),可以参考https://aistudio.baidu.com/projectdetail/4823162。

http://www.kler.cn/news/367856.html

相关文章:

  • 智能工厂的软件设计 专有名词(juncture/relation/selection):意识形态及认知计算机科学的架构、系统和运用
  • springboot083基于springboot的个人理财系统--论文pf(论文+源码)_kaic
  • 时间序列预测(九)——门控循环单元网络(GRU)
  • 传奇开服教程之新GOM引擎登录器配置教程
  • GPS/北斗时空安全隔离装置(卫星时空防护装置)使用手册
  • Vue学习笔记(四)
  • 【Android】Kotlin教程(6)
  • 算法设计与分析:贪心算法思想的应用
  • Redisson(三)应用场景及demo
  • HTML+CSS实现超酷超炫的3D立方体相册
  • Spring-SpringMVC-SpringBoot理解
  • Java基础-JVM
  • 【宝塔面板】轻松使用docker搭建lobe-chat项目(AI对话)
  • js纯操作dom版购物车(实现购物车功能)
  • Cannot read property ‘prototype’ of undefined 表单
  • 云资源管理与优化:提升效率的技术指南
  • 【数据集】NCEP辐射数据-用于计算漫射天窗比(diffuse skylight ration)
  • ELK之路第二步——可视化界面Kibana
  • Hadoop:yarn的Rust API接口
  • 面向对象思想和面向过程思想分析
  • 【LeetCode】每日一题 2024_10_27 冗余连接(并查集)
  • JavaWeb的小结08
  • 前端聊天室页面开发(赛博朋克科技风,内含源码)
  • Axure随机验证码高级交互
  • 血量更新逻辑的实现
  • Windows AD 域的深度解析 第一篇:AD 域原理与多系统联动