
Chinese NLP Address Element Parsing (Alibaba Cloud Tianchi Competition)

Competition page: Chinese NLP Address Element Parsing (中文NLP地址要素解析)
https://tianchi.aliyun.com/notebook/467867?spm=a2c22.12281976.0.0.654b265fTnW3lu

Long-term track:
Score: 87.7271
Rank: 56 (this submission) / 6990 (teams or individuals)

Approach: BERT-BiLSTM-CRF NER
Pretrained model: bert-base-chinese
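
bert-base-chinese tokenizes Chinese text at the character level, which is what makes character-level sequence labeling straightforward here. A minimal sketch of that behaviour, assuming the transformers library is installed and the model is reachable locally or from the Hub (the address string is made up):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
address = '浙江省杭州市余杭区文一西路969号'       # illustrative address, not taken from the competition data
tokens = tokenizer.tokenize(address)
print(tokens)                                   # one token per Chinese character; digit runs may split into wordpieces
print(tokenizer.convert_tokens_to_ids(tokens))  # the ids that become input_ids for the model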

Training results:
F1: 0.9040681554670564
accuracy: 0.9313805261730405
precision: 0.901296612724897
recall: 0.9068567961165048

Run the training script:
python run_bert_lstm_crf.py

Code walkthrough:

Model definition: bert_lstm_crf.py (BERT encoder + BiLSTM + CRF)

import torch
import torch.nn as nn
from torchcrf import CRF
from transformers import AutoModel


class NERNetwork(nn.Module):

    def __init__(self, config, n_tags: int, using_lstm: bool = True) -> None:
        """Initialize a NERDA Network
        Args:
            bert_model (nn.Module): huggingface `torch` transformers.
            device (str): Computational device.
            n_tags (int): Number of unique entity tags (incl. outside tag)
            dropout (float, optional): Dropout probability. Defaults to 0.1.
        """
        super(NERNetwork, self).__init__()
        self.bert_encoder = AutoModel.from_pretrained(config.model_name_or_path)
        self.dropout = nn.Dropout(config.dropout)
        self.using_lstm = using_lstm
        out_size = self.bert_encoder.config.hidden_size
        if self.using_lstm:
            self.lstm = nn.LSTM(self.bert_encoder.config.hidden_size, config.lstm_hidden_size, num_layers=1,
                                bidirectional=True, batch_first=True)
            out_size = config.lstm_hidden_size * 2

        self.hidden2tags = nn.Linear(out_size, n_tags)  # emission layer (BERT/BiLSTM output -> tag space)
        self.crf_layer = CRF(num_tags=n_tags, batch_first=True)

    def tag_outputs(self,
                    input_ids: torch.Tensor,
                    attention_mask: torch.Tensor,
                    token_type_ids: torch.Tensor,
                    ) -> torch.Tensor:

        bert_model_inputs = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids
        }

        outputs = self.bert_encoder(**bert_model_inputs)
        # apply drop-out
        last_hidden_state = outputs.last_hidden_state
        last_hidden_state = self.dropout(last_hidden_state)

        if self.using_lstm:
            last_hidden_state, _ = self.lstm(last_hidden_state)
        # project the hidden states to per-tag emission scores for the CRF
        emissions = self.hidden2tags(last_hidden_state)

        return emissions

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                token_type_ids: torch.Tensor,
                target_tags: torch.Tensor
                ):
        """Model Forward Iteration
        Args:
            input_ids (torch.Tensor): Input IDs.
            attention_mask (torch.Tensor): Attention attention_mask.
            token_type_ids (torch.Tensor): Token Type IDs.
        Returns:
            torch.Tensor: predicted values.
        """
        emissions = self.tag_outputs(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        loss = -1 * self.crf_layer(emissions=emissions, tags=target_tags, mask=attention_mask.byte())
        return loss

    def predict(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                token_type_ids: torch.Tensor,
                ):
        emissions = self.tag_outputs(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        return self.crf_layer.decode(emissions=emissions, mask=attention_mask.byte())
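
For orientation, a minimal usage sketch of NERNetwork with dummy tensors; the config stub and the shapes are illustrative assumptions, and in the real pipeline they come from config.py and the dataloader:

import torch
from types import SimpleNamespace
from bert_lstm_crf import NERNetwork

# Illustrative config stub; the real values come from config.py.
cfg = SimpleNamespace(model_name_or_path='bert-base-chinese', dropout=0.5, lstm_hidden_size=150)
model = NERNetwork(cfg, n_tags=5)                 # n_tags = number of entity tags incl. 'O'
model.eval()

batch_size, seq_len = 2, 16
input_ids = torch.randint(100, 1000, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
token_type_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
target_tags = torch.zeros(batch_size, seq_len, dtype=torch.long)

loss = model(input_ids, attention_mask, token_type_ids, target_tags)   # negative CRF log-likelihood
paths = model.predict(input_ids, attention_mask, token_type_ids)       # Viterbi-decoded tag ids, one list per sentence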

Training script: run_bert_lstm_crf.py

import numpy as np
import torch
import argparse
import os, json
import sys
from tqdm import tqdm
import sklearn.preprocessing
from transformers import AutoModel, AutoTokenizer, AutoConfig
from transformers import AdamW, get_linear_schedule_with_warmup
import transformers
import random
from preprocess import create_dataloader, get_semeval_data
from utils import compute_loss, get_ent_tags, batch_to_device, compute_f1, load_test_file
from bert_lstm_crf import NERNetwork
import logging
from config import args
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

logger = logging.getLogger('main')
logger.setLevel(logging.INFO)
os.makedirs('log', exist_ok=True)
fh = logging.FileHandler('log/log.txt', mode='w')
fh.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(lineno)d : %(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(ch)

seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# System based
random.seed(seed)
np.random.seed(seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info("Using device {}".format(device))


# Prediction
def predict(model, test_dataloader, tag_encoder, device, train=True):
    if model.training:
        logger.info("Evaluating the model...")
        model.eval()  # make sure dropout is disabled before decoding

    predictions = []
    for batch1 in test_dataloader:
        batch = batch_to_device(inputs=batch1, device=device)
        input_ids, attention_mask, token_type_ids = batch['input_ids'], batch['attention_mask'], batch['token_type_ids']
        with torch.no_grad():
            outputs = model.predict(input_ids=input_ids, attention_mask=attention_mask,
                                    token_type_ids=token_type_ids)  # list of best tag-id sequences, one per sentence

        for i, predict_tag_seq in enumerate(outputs):
            preds = tag_encoder.inverse_transform(predict_tag_seq)  # (with wordpiece)
            preds = [prediction for prediction, offset in zip(preds.tolist(), batch.get('offsets')[i]) if
                     offset]  # offsets = [1] + offsets + [1]
            preds = preds[1:-1]  # drop the CLS and SEP positions
            predictions.append(preds)

    return predictions


# Training
def train(args,
          train_dataloader,
          tag_encoder,
          train_conll_tags,
          test_conll_tags,
          test_dataloader):
    n_tags = tag_encoder.classes_.shape[0]
    logger.info("n_tags : {}".format(n_tags))

    print_loss_step = len(train_dataloader) // 5
    evaluation_steps = len(train_dataloader) // 2
    logger.info(
        "Within each epoch, the loss is logged every {} steps and the model is evaluated every {} steps".format(
            print_loss_step, evaluation_steps))

    model = NERNetwork(args, n_tags=n_tags)
    if args.ckpt is not None:
        load_result = model.load_state_dict(torch.load(args.ckpt, map_location='cpu'), strict=False)
        logger.info("Load ckpt to continue training !")
        logger.info("missing and unexcepted key : {}".format(str(load_result)))

    model.to(device=device)
    logger.info("Using device : {}".format(device))
    optimizer_parameters = model.parameters()
    optimizer = AdamW(optimizer_parameters, lr=args.learning_rate)
    num_train_steps = int(
        len(train_conll_tags) // args.train_batch_size // args.gradient_accumulation_steps) * args.epochs
    warmup_steps = int(num_train_steps * args.warmup_proportion)
    logger.info("num_train_steps : {}, warmup_proportion : {}, warmup_steps : {}".format(num_train_steps,
                                                                                         args.warmup_proportion,
                                                                                         warmup_steps))
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_train_steps
    )

    global_step = 0

    previous_f1 = -1
    predictions = predict(model=model, test_dataloader=test_dataloader, tag_encoder=tag_encoder, device=device)
    f1 = compute_f1(pred_tags=predictions, golden_tags=test_conll_tags)
    if f1 > previous_f1:
        logger.info("Previous f1 score is {} and current f1 score is {}".format(previous_f1, f1))
        previous_f1 = f1

    for epoch in range(args.epochs):
        model.train()
        model.zero_grad()
        training_loss = 0.0
        for iteration, batch in tqdm(enumerate(train_dataloader)):
            batch = batch_to_device(inputs=batch, device=device)
            input_ids, attention_mask, token_type_ids = batch['input_ids'], batch['attention_mask'], batch[
                'token_type_ids']
            loss = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
                         target_tags=batch['target_tags'])  # (batch_size,seq_length,num_classes)
            # target_tags assigns the outside tag 'O' to the CLS and SEP positions
            training_loss += loss.item()
            loss.backward()

            if (iteration + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

            if (iteration + 1) % print_loss_step == 0:
                training_loss /= print_loss_step
                logger.info(
                    "Epoch : {}, global_step : {}/{}, loss_value : {} ".format(epoch, global_step, num_train_steps,
                                                                               training_loss))
                training_loss = 0.0

            if (iteration + 1) % evaluation_steps == 0:
                predictions = predict(model=model, test_dataloader=test_dataloader, tag_encoder=tag_encoder,
                                      device=device)
                f1 = compute_f1(pred_tags=predictions, golden_tags=test_conll_tags)
                if f1 > previous_f1:
                    torch.save(model.state_dict(), args.best_model)
                    logger.info(
                        "Previous f1 score is {} and current f1 score is {}, best model has been saved in {}".format(
                            previous_f1, f1, args.best_model))
                    previous_f1 = f1

                else:
                    args.patience -= 1
                    logger.info("Left patience is {}".format(args.patience))
                    if args.patience == 0:
                        logger.info("Total patience is {}, run our of patience, early stop!".format(args.patience))
                        return

                model.zero_grad()
                model.train()


# Generate predictions for the unlabeled test file (submission)
def my_test(args,
            tag_encoder,
            valid_dataloader):
    n_tags = tag_encoder.classes_.shape[0]
    logger.info("n_tags : {}".format(n_tags))

    model = NERNetwork(args, n_tags=n_tags)
    if args.best_model is not None:
        load_result = model.load_state_dict(torch.load(args.best_model, map_location='cpu'), strict=False)
        logger.info("Load ckpt to continue training !")
        logger.info("missing and unexcepted key : {}".format(str(load_result)))
    model.to(device=device)

    predictions = predict(model=model, test_dataloader=valid_dataloader, tag_encoder=tag_encoder, device=device,
                          train=False)

    sentences = valid_dataloader.dataset.sentences
    # output file name
    file_name = "output_new.txt"

    # write predictions, one line per sentence
    with open(file_name, "w", encoding="utf-8") as file:
        index = 0
        for prediction in predictions:
            sentence = sentences[index]
            sentence_str = ''.join(sentence)

            prediction_str = ' '.join(prediction)
            line = f"{index + 1}\u0001{sentence_str}\u0001{prediction_str}\n"
            # logger.info(f"line={line}")
            assert len(sentence) == len(prediction)
            file.write(line)
            index += 1


def main():
    if not os.path.exists(args.save_dir):
        logger.info("save_dir not exists, created!")
        os.makedirs(args.save_dir, exist_ok=True)

    train_conll_data = get_semeval_data(split='train', dir=args.file_path, word_idx=1, entity_idx=3)
    test_conll_data = get_semeval_data(split='dev', dir=args.file_path, word_idx=1, entity_idx=3)

    valid_conll_data = load_test_file(split='valid', dir=args.file_path)
    logger.info("train sentences num : {}".format(len(train_conll_data['sentences'])))
    logger.info("test sentences num : {}".format(len(test_conll_data['sentences'])))
    logger.info("Logging some examples...")
    for _ in range(5):
        i = random.randint(0, len(test_conll_data['tags']) - 1)
        sen = test_conll_data['sentences'][i]
        ent = test_conll_data['tags'][i]
        for k in range(len(sen)):
            logger.info("{}  {}".format(sen[k], ent[k]))
        logger.info('-' * 50)

    tag_scheme = get_ent_tags(all_tags=train_conll_data.get('tags'))
    tag_outside = 'O'
    if tag_outside in tag_scheme:
        del tag_scheme[tag_scheme.index(tag_outside)]
    tag_complete = [tag_outside] + tag_scheme
    print(tag_complete, len(tag_complete))
    with open(os.path.join(args.save_dir, 'label.json'), 'w') as f:
        json.dump(obj=' '.join(tag_complete), fp=f)
    logger.info("Tag scheme : {}".format(' '.join(tag_scheme)))
    logger.info("Tag has been saved in {}".format(os.path.join(args.save_dir, 'label.json')))
    tag_encoder = sklearn.preprocessing.LabelEncoder()
    tag_encoder.fit(tag_complete)

    transformer_tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    transformer_config = AutoConfig.from_pretrained(args.model_name_or_path)

    train_dataloader = create_dataloader(sentences=train_conll_data.get('sentences'),
                                         tags=train_conll_data.get('tags'),
                                         transformer_tokenizer=transformer_tokenizer,
                                         transformer_config=transformer_config,
                                         max_len=args.max_len,
                                         tag_encoder=tag_encoder,
                                         tag_outside=tag_outside,
                                         batch_size=args.train_batch_size,
                                         num_workers=args.num_workers,
                                         take_longest_token=args.take_longest_token,
                                         is_training=True)
    test_dataloader = create_dataloader(sentences=test_conll_data.get('sentences'),
                                        tags=test_conll_data.get('tags'),
                                        transformer_tokenizer=transformer_tokenizer,
                                        transformer_config=transformer_config,
                                        max_len=args.max_len,
                                        tag_encoder=tag_encoder,
                                        tag_outside=tag_outside,
                                        batch_size=args.test_batch_size,
                                        num_workers=args.num_workers,
                                        take_longest_token=args.take_longest_token,
                                        is_training=False)
    valid_dataloader = create_dataloader(sentences=valid_conll_data.get('sentences'),
                                         tags=valid_conll_data.get('tags'),
                                         transformer_tokenizer=transformer_tokenizer,
                                         transformer_config=transformer_config,
                                         max_len=args.max_len,
                                         tag_encoder=tag_encoder,
                                         tag_outside=tag_outside,
                                         batch_size=args.test_batch_size,
                                         num_workers=args.num_workers,
                                         take_longest_token=args.take_longest_token,
                                         is_training=False)

    train(args=args, train_dataloader=train_dataloader,
          tag_encoder=tag_encoder,
          train_conll_tags=train_conll_data.get('tags'),
          test_conll_tags=test_conll_data.get('tags'),
          test_dataloader=test_dataloader)

    my_test(args=args,
            tag_encoder=tag_encoder,
            valid_dataloader=valid_dataloader)


if __name__ == "__main__":
    main()
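
The submission file written by my_test has one line per sentence in the format index\u0001sentence\u0001space-separated tags. A minimal sketch for reading it back (file name as produced above):

# Parse output_new.txt back into (character, tag) pairs.
with open('output_new.txt', encoding='utf-8') as f:
    for line in f:
        idx, sentence, tag_str = line.rstrip('\n').split('\u0001')
        tags = tag_str.split(' ')
        assert len(sentence) == len(tags)   # one tag per character, as asserted in my_test
        pairs = list(zip(sentence, tags))   # (character, tag) pairs; tag names depend on the dataset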

Configuration: config.py

import argparse

parser = argparse.ArgumentParser()
# input and output parameters
# pretrained model
parser.add_argument('--model_name_or_path', default='/data/nfs/baozhi/models/google-bert_bert-base-chinese', help='path to the pretrained BERT')
# where to save the best fine-tuned model
parser.add_argument('--best_model', default='saved_models/pytorch_model_20241031_v2.bin', help='path to save the best fine-tuned model')
# directory containing the NER training data
parser.add_argument('--file_path', default='data/com334', help='path to the ner data')
# directory for checkpoints and the exported label set
parser.add_argument('--save_dir', default='saved_models/', help='path to save checkpoints and logs')
parser.add_argument('--ckpt', default=None, help='fine-tuned checkpoint to resume training from')
# training parameters
# learning rate
parser.add_argument('--learning_rate', default=3e-5, type=float)
parser.add_argument('--weight_decay', default=1e-5, type=float)
# epochs
parser.add_argument('--epochs', default=15, type=int)
parser.add_argument('--train_batch_size', default=64, type=int)
parser.add_argument('--gradient_accumulation_steps', default=1, type=int)
parser.add_argument('--lstm_hidden_size', default=150, type=int)
parser.add_argument('--test_batch_size', default=64, type=int)
parser.add_argument('--max_grad_norm', default=1, type=int)
parser.add_argument('--warmup_proportion', default=0.1, type=float)
# maximum sequence length (tokens, including CLS and SEP)
parser.add_argument('--max_len', default=200, type=int)
parser.add_argument('--patience', default=100, type=int)
# dropout probability
parser.add_argument('--dropout', default=0.5, type=float)

# Other parameters
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--num_workers', default=1, type=int)
parser.add_argument('--take_longest_token', action='store_true', help='keep only the longest wordpiece per word (default: False)')
args = parser.parse_args()
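
Because config.py calls parse_args() at import time, every module that does from config import args shares the same namespace, and any default can be overridden on the command line. A minimal sketch (flag values are illustrative):

from config import args

print(args.model_name_or_path, args.learning_rate, args.max_len)
# Overriding defaults when launching training, e.g.:
#   python run_bert_lstm_crf.py --learning_rate 2e-5 --epochs 10 --max_len 128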

Data preprocessing: preprocess.py

import re
import warnings
import sklearn.preprocessing
import torch
import transformers
import os,json

class DataSet():
    def __init__(self,
                 sentences: list,
                 tags: list,
                 transformer_tokenizer: transformers.PreTrainedTokenizer,
                 transformer_config: transformers.PretrainedConfig,
                 max_len: int,
                 tag_encoder: sklearn.preprocessing.LabelEncoder,
                 tag_outside: str,
                 take_longest_token: bool = True,
                 pad_sequences: bool = True) -> None:
        """Initialize DataSetReader
        Initializes DataSetReader that prepares and preprocesses 
        DataSet for Named-Entity Recognition Task and training.
        Args:
            sentences (list): Sentences.
            tags (list): Named-Entity tags.
            transformer_tokenizer (transformers.PreTrainedTokenizer): 
                tokenizer for transformer.
            transformer_config (transformers.PretrainedConfig): Config
                for transformer model.
            max_len (int): Maximum length of sentences after applying
                transformer tokenizer.
            tag_encoder (sklearn.preprocessing.LabelEncoder): Encoder
                for Named-Entity tags.
            tag_outside (str): Special outside tag, e.g. 'O'.
            take_longest_token (bool): Keep only the longest wordpiece
                per word. Defaults to True.
            pad_sequences (bool): Pad sequences to max_len. Defaults
                to True.
        """
        self.sentences = sentences
        self.tags = tags
        self.transformer_tokenizer = transformer_tokenizer
        self.max_len = max_len
        self.tag_encoder = tag_encoder
        self.pad_token_id = transformer_config.pad_token_id
        self.tag_outside_transformed = tag_encoder.transform([tag_outside])[0]
        self.take_longest_token = take_longest_token
        self.pad_sequences = pad_sequences

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, item):
        sentence = self.sentences[item]
        tags = self.tags[item]
        # encode tags
        tags = self.tag_encoder.transform(tags)

        # check inputs for consistency
        assert len(sentence) == len(tags)

        input_ids = []
        target_tags = []
        tokens = []
        offsets = []

        # for debugging purposes
        # print(item)
        for i, word in enumerate(sentence):
            # bert tokenization
            wordpieces = self.transformer_tokenizer.tokenize(word)
            if self.take_longest_token:
                piece_token_lengths = [len(token) for token in wordpieces]
                word = wordpieces[piece_token_lengths.index(max(piece_token_lengths))]
                wordpieces = [word]  # keep only the longest wordpiece

            tokens.extend(wordpieces)
            # make room for CLS if there is an identified word piece
            if len(wordpieces) > 0:
                offsets.extend([1] + [0] * (len(wordpieces) - 1))
            # Extends the ner_tag if the word has been split by the wordpiece tokenizer
            target_tags.extend([tags[i]] * len(wordpieces))

        if self.take_longest_token:
            assert len(tokens) == len(sentence) == len(target_tags)
        # Make room for adding special tokens (one for both 'CLS' and 'SEP' special tokens)
        # max_len includes _all_ tokens.
        if len(tokens) > self.max_len - 2:
            msg = f'Sentence #{item} length {len(tokens)} exceeds max_len {self.max_len} and has been truncated'
            warnings.warn(msg)
        tokens = tokens[:self.max_len - 2]
        target_tags = target_tags[:self.max_len - 2]
        offsets = offsets[:self.max_len - 2]

        # encode tokens for BERT
        # TO DO: prettify this.
        input_ids = self.transformer_tokenizer.convert_tokens_to_ids(tokens)
        input_ids = [self.transformer_tokenizer.cls_token_id] + input_ids + [self.transformer_tokenizer.sep_token_id]

        # fill out other inputs for model.    
        target_tags = [self.tag_outside_transformed] + target_tags + [self.tag_outside_transformed]
        attention_mask = [1] * len(input_ids)
        # set to 0, because we are not doing NSP or QA type task (across multiple sentences)
        # token_type_ids distinguishes sentences.
        token_type_ids = [0] * len(input_ids)
        offsets = [1] + offsets + [1]

        # Padding to max length 
        # compute padding length
        if self.pad_sequences:
            padding_len = self.max_len - len(input_ids)
            input_ids = input_ids + ([self.pad_token_id] * padding_len)
            attention_mask = attention_mask + ([0] * padding_len)
            offsets = offsets + ([0] * padding_len)
            token_type_ids = token_type_ids + ([0] * padding_len)
            target_tags = target_tags + ([self.tag_outside_transformed] * padding_len)

        return {'input_ids': torch.tensor(input_ids, dtype=torch.long),
                'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
                'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
                'target_tags': torch.tensor(target_tags, dtype=torch.long),
                'offsets': torch.tensor(offsets, dtype=torch.long)}

def create_dataloader(sentences,
                      tags,
                      transformer_tokenizer,
                      transformer_config,
                      max_len,
                      tag_encoder,
                      tag_outside,
                      batch_size=1,
                      num_workers=1,
                      take_longest_token=True,
                      pad_sequences=True,
                      is_training=True):
    if not pad_sequences and batch_size > 1:
        print("setting pad_sequences to True, because batch_size is more than one.")
        pad_sequences = True

    data_reader = DataSet(
        sentences=sentences,
        tags=tags,
        transformer_tokenizer=transformer_tokenizer,
        transformer_config=transformer_config,
        max_len=max_len,
        tag_encoder=tag_encoder,
        tag_outside=tag_outside,
        take_longest_token=take_longest_token,
        pad_sequences=pad_sequences)
    # Don't pad sequences if batch size == 1. This improves performance.
    data_loader = torch.utils.data.DataLoader(
        data_reader, batch_size=batch_size, num_workers=num_workers, shuffle=is_training
    )
    return data_loader

def get_conll_data(split: str = 'train',
                   limit_length: int = 196,
                   dir: str = None) -> dict:
    assert isinstance(split, str)
    splits = ['train', 'dev', 'test']
    assert split in splits, f'Choose between the following splits: {splits}'

    # set to default directory if nothing else has been provided by user.

    assert os.path.isdir(
        dir), f'Directory {dir} does not exist. Try downloading CoNLL-2003 data with download_conll_data()'

    file_path = os.path.join(dir, f'{split}.txt')
    assert os.path.isfile(
        file_path), f'File {file_path} does not exist. Try downloading CoNLL-2003 data with download_conll_data()'

    # read data from file.
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    sentences = []
    labels = []
    sentence = []
    label = []
    pua_pattern = re.compile("[\uE000-\uF8FF]|[\u200b\u200d\u200e]")
    for line in lines:
        line = line.strip()
        if len(line) == 0:
            if len(sentence) > 0:
                sentences.append(sentence)
                labels.append(label)
            sentence = []
            label = []
        else:
            parts = line.split()
            word = parts[0]
            tag = parts[1]
            word = re.sub(pua_pattern, "", word)  # strip private-use-area and zero-width characters
            if word:
                sentence.append(word)
                label.append(tag)
    if len(sentence) > 0:
        sentences.append(sentence)
        labels.append(label)

    return {'sentences': sentences, 'tags': labels}


def get_semeval_data(split: str = 'train',
                     limit_length: int = 196,
                     dir: str = None,
                     word_idx=1,
                     entity_idx=4) -> dict:
    # Note: limit_length, word_idx and entity_idx are accepted for API compatibility but not used below.
    assert isinstance(split, str)
    splits = ['train', 'dev', 'test']
    assert split in splits, f'Choose between the following splits: {splits}'

    # set to default directory if nothing else has been provided by user.

    assert os.path.isdir(
        dir), f'Directory {dir} does not exist. Try downloading CoNLL-2003 data with download_conll_data()'

    file_path = os.path.join(dir, f'{split}.txt')
    assert os.path.isfile(
        file_path), f'File {file_path} does not exist. Try downloading CoNLL-2003 data with download_conll_data()'

    # read data from file.
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    sentences = []
    labels = []
    sentence = []
    label = []
    pua_pattern = re.compile("[\uE000-\uF8FF]|[\u200b\u200d\u200e]")
    for line in lines:
        line = line.strip()
        if len(line) == 0:
            if len(sentence) > 0:
                sentences.append(sentence)
                labels.append(label)
            sentence = []
            label = []
        else:
            parts = line.split()
            word = parts[0]
            tag = parts[1]
            word = re.sub(pua_pattern, "", word)  # strip private-use-area and zero-width characters
            if word:
                sentence.append(word)
                label.append(tag)
    if len(sentence) > 0:
        sentences.append(sentence)
        labels.append(label)

    return {'sentences': sentences, 'tags': labels}
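
Both loaders above expect one whitespace-separated character/tag pair per line, with a blank line between sentences. A minimal end-to-end sketch; the tag names and the demo_data directory are illustrative, not the official competition label set or layout:

import os
from preprocess import get_semeval_data

# Illustrative train.txt content; the real label set comes from the competition data.
sample = "浙 B-prov\n江 I-prov\n省 I-prov\n杭 B-city\n州 I-city\n市 I-city\n\n文 B-road\n一 I-road\n路 I-road\n"
os.makedirs('demo_data', exist_ok=True)
with open('demo_data/train.txt', 'w', encoding='utf-8') as f:
    f.write(sample)

data = get_semeval_data(split='train', dir='demo_data', word_idx=1, entity_idx=3)
print(data['sentences'][0])   # ['浙', '江', '省', '杭', '州', '市']
print(data['tags'][0])        # ['B-prov', 'I-prov', 'I-prov', 'B-city', 'I-city', 'I-city']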

Utilities: utils.py

import os
from io import BytesIO
from pathlib import Path
from urllib.request import urlopen
from zipfile import ZipFile
import ssl
from typing import Callable
import torch
from seqeval.metrics import accuracy_score, classification_report, f1_score, precision_score, recall_score
import logging
import re

logger = logging.getLogger('main.utils')


def load_test_file(split: str = 'train',
                   dir: str = None):
    file_path = os.path.join(dir, f'{split}.txt')
    sentences = []
    labels = []
    pua_pattern = re.compile("[\uE000-\uF8FF]|[\u200b\u200d\u200e]")
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            ids, words = line.strip().split('\001')
            # the test file has no labels, so fill in dummy 'O' tags
            words = re.sub(pua_pattern, '', words)
            label = ['O' for x in range(0, len(words))]
            sentence = []
            for c in words:
                sentence.append(c)
            sentences.append(sentence)
            labels.append(label)
    return {'sentences': sentences, 'tags': labels}
    # return sentences, labels


def download_unzip(url_zip: str,
                   dir_extract: str) -> str:
    """Download and unzip a ZIP archive to folder.
    Loads a ZIP file from URL and extracts all of the files to a 
    given folder. Does not save the ZIP file itself.
    Args:
        url_zip (str): URL to ZIP file.
        dir_extract (str): Directory where files are extracted.
    Returns:
        str: a message confirming that the archive was successfully
        extracted. The files in the ZIP archive are extracted to the
        desired directory as a side effect.
    """

    # suppress ssl certification
    ctx = ssl.create_default_context()
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE

    print(f'Reading {url_zip}')
    with urlopen(url_zip, context=ctx) as zipresp:
        with ZipFile(BytesIO(zipresp.read())) as zfile:
            zfile.extractall(dir_extract)

    return f'archive extracted to {dir_extract}'


def download_conll_data(dir: str = None) -> str:
    """Download CoNLL-2003 English data set.
    Downloads the [CoNLL-2003](https://www.clips.uantwerpen.be/conll2003/ner/) 
    English data set annotated for Named Entity Recognition.
    Args:
        dir (str, optional): Directory where the CoNLL-2003 dataset will be saved. If no directory is provided, data will be saved to a hidden folder '.conll' in your home directory.
    Returns:
        str: a message confirming that the archive was successfully
        extracted. The CoNLL datasets are extracted to the desired
        directory as a side effect.
    
    Examples:
        >>> download_conll_data()
        >>> download_conll_data(dir = 'conll')
        
    """
    # set to default directory if nothing else has been provided by user.
    if dir is None:
        dir = os.path.join(str(Path.home()), '.conll')

    return download_unzip(url_zip='https://data.deepai.org/conll2003.zip',
                          dir_extract=dir)


def match_kwargs(function: Callable, **kwargs) -> dict:
    """Matches Arguments with Function
    Match keywords arguments with the arguments of a function.
    Args:
        function (function): Function to match arguments for.
        kwargs: keyword arguments to match against.
    Returns:
        dict: dictionary with matching arguments and their
        respective values.
    """
    arg_count = function.__code__.co_argcount  # 14
    args = function.__code__.co_varnames[
           :arg_count]  # 'self', 'input_ids', 'attention_mask', 'token_type_ids', 'position_ids', 'head_mask', 'inputs_embeds'

    args_dict = {}
    for k, v in kwargs.items():
        if k in args:
            args_dict[k] = v

    return args_dict


def get_ent_tags(all_tags):
    ent_tags = set()
    for each_tag_sequence in all_tags:
        for each_tag in each_tag_sequence:
            ent_tags.add(each_tag)
    return list(ent_tags)


def batch_to_device(inputs, device):
    for key in inputs.keys():
        if type(inputs[key]) == list:
            inputs[key] = torch.LongTensor(inputs[key])
        inputs[key] = inputs[key].to(device)

    return inputs


def compute_loss(preds, target_tags, masks, device, n_tags):
    # initialize loss function.
    lfn = torch.nn.CrossEntropyLoss()

    # Compute active loss to not compute loss of paddings
    active_loss = masks.view(-1) == 1

    active_logits = preds.view(-1, n_tags)
    active_labels = torch.where(
        active_loss,
        target_tags.view(-1),
        torch.tensor(lfn.ignore_index).type_as(target_tags)
    )

    active_labels = torch.as_tensor(active_labels, device=torch.device(device), dtype=torch.long)

    # Only compute loss on actual token predictions
    loss = lfn(active_logits, active_labels)

    return loss


def compute_f1(pred_tags, golden_tags, from_test=False):
    assert len(pred_tags) == len(golden_tags)
    count = 0
    for pred, golden in zip(pred_tags, golden_tags):
        try:
            assert len(pred) == len(golden)
        except AssertionError:
            print(len(pred), len(golden))
            print(count)
            raise Exception('length is not consistent!')
        count += 1

    result = classification_report(y_pred=pred_tags, y_true=golden_tags, digits=4)
    f1 = f1_score(y_pred=pred_tags, y_true=golden_tags)
    acc = accuracy_score(y_pred=pred_tags, y_true=golden_tags)
    precision = precision_score(y_pred=pred_tags, y_true=golden_tags)
    recall = recall_score(y_pred=pred_tags, y_true=golden_tags)

    if not from_test:
        logger.info('\n' + result)
        logger.info("F1 : {}, accuracy : {}, precision : {}, recall : {}".format(f1, acc, precision, recall))
        return f1
    else:
        print(result)
        print("F1 : {}, accuracy : {}, precision : {}, recall : {}".format(f1, acc, precision, recall))
        return f1
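
compute_f1 relies on seqeval's entity-level metrics, so a prediction only counts when the whole entity span and its type match the gold annotation. A toy illustration (tag names illustrative):

from seqeval.metrics import f1_score, precision_score, recall_score

golden = [['B-prov', 'I-prov', 'O', 'B-city', 'I-city']]
pred   = [['B-prov', 'I-prov', 'O', 'B-city', 'O']]      # the city entity is truncated, so it counts as wrong
print(f1_score(y_pred=pred, y_true=golden))               # 0.5: one of the two gold entities matched exactly
print(precision_score(y_pred=pred, y_true=golden))        # 0.5
print(recall_score(y_pred=pred, y_true=golden))           # 0.5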

Appendix: source code

Source code: 中文NLP地址要素解析方案:BERT-BiLSTM-CRF-NER (CSDN 文库 resource)

