当前位置: 首页 > article >正文

【NLP实战】基于Bert和双向LSTM的情感分类【上篇】

文章目录

  • 前言
  • 简介
  • 数据获取与提取
  • 数据清洗
    • 读取数据,查看数据
    • 清洗训练集
      • 观察数据分布
      • 去除空数据
      • 去除重复数据
      • 关于去除停用词
      • 关于特殊符号
      • 储存清洗后的数据集
    • 清洗测试集
      • 观察数据分布
      • 去除空数据
      • 去除重复数据(并储存)
    • 清洗验证集
      • 观察数据分布
      • 去除空行
      • 去除重复数据(并储存)
  • 总结
  • 代码汇总

前言

最近自己找了个实验做,写了很多实验记录和方法,现在我将它们整理成文章,希望能对不熟悉NLP的伙伴们起到些许帮助。如有疑问请及时联系作者。

博主page:issey的博客 - 愿无岁月可回首

本系列文章中不会说明环境和包如何安装,这些应该是最基础的东西,可以自己边查边安装。

许多函数用法等在代码里有详细解释,但还是希望各位去看它们的官方文档,我的代码还有很多可以改进的方法,需要的函数等在官方文档都有说明。

简介

本系列将带领大家从数据获取、数据清洗、模型构建、训练,观察loss变化,调整超参数再次训练,并最后进行评估整一个过程。我们将获取一份公开竞赛中文数据,并一步步实验,到最后,我们的评估可以达到排行榜13位的位置。但重要的不是排名,而是我们能在其中学到很多。

本系列将分为三篇文章,分别是:

  • 上篇:数据获取,数据分割与数据清洗
  • 中篇:模型构建,改进pytorch结构,开始第一次训练
  • 下篇:测试与评估,绘图与过拟合,超参数调整

本文为该系列第一篇文章,在本文中,我们将一同观察原始数据,进行数据清洗。样本是很重要的一个部分,学会观察样本并剔除一些符合特殊条件的样本,对模型在学习时有很大的帮助。

数据获取与提取

数据来源:Weibo nCoV Data | Kaggle

竞赛官网:疫情期间网民情绪识别 竞赛 - DataFountain

关于kaggle如何下载数据,本文不再赘述。

为了把数据分割也作为我们实验的一部分,假设我们现在拿到的nCoV_100k_train.labled.csv就是我们爬取到的原始数据。

先来看看我们用到的数据长什么样。

思考:

  • 我们只需要text和情感倾向的列,其他列都不需要。
  • 分割数据时,训练集:测试集:验证集 = 6:2:2。这只是博主自己选择的比例,各位可以自行调整。

编写代码。这部分比较简单,就不一步步运行了,但是各位应该逐行运行观察变化,写文章不能像jupyter notebook那样一行行运行,为了方便起见,文章涉及的代码都将以块状给出。但是运行实际上很多都是逐行调整的。

数据获取.py

import pandas as pd
from sklearn.model_selection import train_test_split

# todo: 读取数据
df = pd.read_csv('./data/archive/nCoV_100k_train.labled.csv')
print(df)
# 只要text和标签
df = df[['微博中文内容', '情感倾向']]
df = df.rename(columns={'微博中文内容': 'text', '情感倾向': 'label'})
print(df)

# todo: 分割数据集,储存.0.6/0.2/0.2
train, test = train_test_split(df, test_size=0.2)
train, val = train_test_split(train, test_size=0.25)
print(train)
print(test)
print(val)
train.to_csv('./data/archive/train.csv', index=None)
val.to_csv('./data/archive/val.csv', index=None)
test.to_csv('./data/archive/test.csv', index=None)

运行结束后,这三个文件就是我们需要的文件。

数据清洗

我的清洗思路来源于这篇:Emotion analysis and Classification using LSTM 93% | Kaggle

该部分需要的库:

  • seaborn:一个适合数据分析的绘图库,需要matplotlib作为前置库

读取数据,查看数据

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# todo:读取数据
df_train = pd.read_csv('./data/archive/train.csv')
df_test = pd.read_csv('./data/archive/test.csv')
df_val = pd.read_csv('./data/archive/val.csv')

# 输出前5行
print(df_train.head())
print(df_train.shape)

print(df_test.head())
print(df_test.shape)

print(df_val.head())
print(df_val.shape)

输出:

清洗训练集

观察数据分布

# todo: 清洗Train
# 观察数据是否平衡
print(df_train.label.value_counts())
print(df_train.label.value_counts() / df_train.shape[0] * 100)
plt.figure(figsize=(8, 4))
sns.countplot(x='label', data=df_train)
plt.show()

可以发现,-2.0,9.0,10.0都只有一个样本,当作异常数据处理,我选择直接丢掉不要。

另外,这个样本分布略微存在分布不平衡imbalance的情况,至于要不要用smote等方法过采样,暂时先不进行讨论,我们暂时保持数据不变。

print(df_train[df_train.label > 5.0])
print(df_train[(df_train.label < -1.1)])
# 丢掉异常数据
df_train.drop(df_train[(df_train.label < -1.1) | (df_train.label > 5)].index, inplace=True, axis=0)
df_train.reset_index(inplace=True, drop=True)
print(df_train.label.value_counts())
sns.countplot(x='label', data=df_train)
plt.show()

去除空数据

# 观察是否有空行
print(df_train.isnull().sum())
# 删除空行数据
df_train.dropna(axis=0, how='any', inplace=True)
df_train.reset_index(inplace=True, drop=True)
print(df_train.isnull().sum())

去除重复数据

# 查看重复数据
print(df_train.duplicated().sum())
# print(df_train[df_train.duplicated()==True])
# 删除重复数据
index = df_train[df_train.duplicated() == True].index
df_train.drop(index, axis=0, inplace=True)
df_train.reset_index(inplace=True, drop=True)
print(df_train.duplicated().sum())

然后我们还需要去除text一样但是label不一样的数据。

# 我们还需要关心的重复数据是text一样但是label不一样的数据。
print(df_train['text'].duplicated().sum())
print(df_train[df_train['text'].duplicated() == True])
# 查看例子
print(df_train[df_train['text'] == df_train.iloc[856]['text']])
print(df_train[df_train['text'] == df_train.iloc[3096]['text']])

# 去掉text一样但是label不一样的数据
index = df_train[df_train['text'].duplicated() == True].index
df_train.drop(index, axis=0, inplace=True)
df_train.reset_index(inplace=True, drop=True)
# 检查
print(df_train['text'].duplicated().sum())  # 0

关于去除停用词

去不去除停用词和构建word embedding选择的方法有关,去查了一下,使用Bert构建时,不需要去除停用词处理,否则还会丢失上下文。于是这里没有进一步去除停用词。

问题解答:nlp - Is it necessary to do stopwords removal ,Stemming/Lemmatization for text classification while using Spacy,Bert? - Stack Overflow

关于特殊符号

观察我们现在的数据:

很容易发现里面有特殊字符。

待会儿用到的bert,它会用到一个中文字典,这个字典是它自己有的,如果出现字典里没有的字符,它会自动替换成[UNK],所以不用管。

储存清洗后的数据集

df_train.to_csv('./data/archive/train_clean.csv', index=None)

清洗测试集

整体步骤和清洗训练集的一样。这里为了巩固处理思路,自己还是详细做一遍吧。

观察数据分布

# 观察数据是否平衡
print(df_test.label.value_counts())
print(df_test.label.value_counts() / df_test.shape[0] * 100)
plt.figure(figsize=(8, 4))
sns.countplot(x='label', data=df_test)
plt.show()

输出就不放了,放个图。

没有特殊label,不用进行去除的操作。

哦对,执行时可以把上面清洗train的代码注释了,用不着重新跑。

去除空数据

# 观察是否有空行
print(df_test.isnull().sum())
# 删除空行数据
df_test.dropna(axis=0, how='any', inplace=True)
df_test.reset_index(inplace=True, drop=True)
print(df_test.isnull().sum())

去除重复数据(并储存)

# 查看重复数据
print(df_test.duplicated().sum())
# print(df_test[df_test.duplicated()==True])
# 删除重复数据
index = df_test[df_test.duplicated() == True].index
df_test.drop(index, axis=0, inplace=True)
df_test.reset_index(inplace=True, drop=True)
print(df_test.duplicated().sum())

# 重复数据是text一样但是label不一样的数据。
print(df_test['text'].duplicated().sum())
print(df_test[df_test['text'].duplicated() == True])
# 查看例子
# print(df_test[df_test['text'] == df_test.iloc[2046]['text']])
# print(df_test[df_test['text'] == df_test.iloc[3132]['text']])
# 去掉text一样但是label不一样的数据
index = df_test[df_test['text'].duplicated() == True].index
df_test.drop(index, axis=0, inplace=True)
df_test.reset_index(inplace=True, drop=True)
# 检查
print(df_test['text'].duplicated().sum())  # 0
# print(df_test)
# 检查形状与编号
print(df_test.tail())
print(df_test.shape)
df_test.to_csv('./data/archive/test_clean.csv', index=None)

有的注释可以打开自己看着调。

清洗验证集

观察数据分布

# 观察数据是否平衡
print(df_val.label.value_counts())
print(df_val.label.value_counts() / df_val.shape[0] * 100)
plt.figure(figsize=(8, 4))
sns.countplot(x='label', data=df_val)
plt.show()

有三个取值我们需要剔除。

# 丢掉异常数据
df_val.drop(df_val[(df_val.label == '4') |
                   (df_val.label == '-') |
                   (df_val.label == '·')].index, inplace=True, axis=0)
df_val.reset_index(inplace=True, drop=True)
print(df_val.label.value_counts())
sns.countplot(x='label', data=df_val)
plt.show()

去除空行

# 观察是否有空行
print(df_val.isnull().sum())
# 删除空行数据
df_val.dropna(axis=0, how='any', inplace=True)
df_val.reset_index(inplace=True, drop=True)
print(df_val.isnull().sum())

去除重复数据(并储存)

# 查看重复数据
print(df_val.duplicated().sum())
# print(df_val[df_val.duplicated()==True])
# 删除重复数据
index = df_val[df_val.duplicated() == True].index
df_val.drop(index, axis=0, inplace=True)
df_val.reset_index(inplace=True, drop=True)
print(df_val.duplicated().sum())

# 重复数据是text一样但是label不一样的数据。
print(df_val['text'].duplicated().sum())
# print(df_val[df_val['text'].duplicated() == True])
# 查看例子
# print(df_val[df_val['text'] == df_val.iloc[1817]['text']])
# print(df_val[df_val['text'] == df_val.iloc[2029]['text']])
# 去掉text一样但是label不一样的数据
index = df_val[df_val['text'].duplicated() == True].index
df_val.drop(index, axis=0, inplace=True)
df_val.reset_index(inplace=True, drop=True)
# 检查
print(df_val['text'].duplicated().sum())  # 0
# print(df_val)
# 检查形状与编号
print(df_val.tail())
print(df_val.shape)
df_val.to_csv('./data/archive/val_clean.csv', index=None)

总结

到此为止,我们已经清洗好了数据。让我们来看看在本次清洗时,忽略了哪些在其他实验中可以继续改进的地方:

  • 本次清洗没有去除停用词,因为使用bert时去除停用词可能会丢失上下文。
  • 本次清洗没有去除特殊字符,因为bert会自动将未知字符转化为[UKN]
  • 本次没有对样本进行过采样/欠采样来解决imbalance问题,这个问题留到评估模型后再考虑要不要讨论。

下一篇文章中,我们将会使用Pytorch搭建Bert和双向LSTM实现多分类。

代码汇总

数据清洗.py

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# todo:读取数据
df_train = pd.read_csv('./data/archive/train.csv')
df_test = pd.read_csv('./data/archive/test.csv')
df_val = pd.read_csv('./data/archive/val.csv')

# 输出前5行
# print(df_train.head())
# print(df_train.shape)

# print(df_test.head())
# print(df_test.shape)

# print(df_val.head())
# print(df_val.shape)

# todo: 清洗Train
# 观察数据是否平衡
# print(df_train.label.value_counts())
# print(df_train.label.value_counts() / df_train.shape[0] * 100)
# plt.figure(figsize=(8, 4))
# sns.countplot(x='label', data=df_train)
# plt.show()
# print(df_train[df_train.label > 5.0])
# print(df_train[(df_train.label < -1.1)])
# 丢掉异常数据
df_train.drop(df_train[(df_train.label < -1.1) | (df_train.label > 5)].index, inplace=True, axis=0)
df_train.reset_index(inplace=True, drop=True)
# print(df_train.label.value_counts())
# sns.countplot(x='label', data=df_train)
# plt.show()

# 观察是否有空行
# print(df_train.isnull().sum())
# 删除空行数据
df_train.dropna(axis=0, how='any', inplace=True)
df_train.reset_index(inplace=True, drop=True)
# print(df_train.isnull().sum())

# 查看重复数据
# print(df_train.duplicated().sum())
# print(df_train[df_train.duplicated()==True])
# 删除重复数据
index = df_train[df_train.duplicated() == True].index
df_train.drop(index, axis=0, inplace=True)
df_train.reset_index(inplace=True, drop=True)
# print(df_train.duplicated().sum())

# 我们还需要关心的重复数据是text一样但是label不一样的数据。
# print(df_train['text'].duplicated().sum())
# print(df_train[df_train['text'].duplicated() == True])
# 查看例子
# print(df_train[df_train['text'] == df_train.iloc[856]['text']])
# print(df_train[df_train['text'] == df_train.iloc[3096]['text']])
# 去掉text一样但是label不一样的数据
index = df_train[df_train['text'].duplicated() == True].index
df_train.drop(index, axis=0, inplace=True)
df_train.reset_index(inplace=True, drop=True)
# 检查
# print(df_train['text'].duplicated().sum())  # 0
# print(df_train)
# 检查形状与编号
print("======train-clean======")
print(df_train.tail())
print(df_train.shape)
df_train.to_csv('./data/archive/train_clean.csv', index=None)

# todo: 清洗test
# 观察数据是否平衡
# print(df_test.label.value_counts())
# print(df_test.label.value_counts() / df_test.shape[0] * 100)
# plt.figure(figsize=(8, 4))
# sns.countplot(x='label', data=df_test)
# plt.show()
# 观察是否有空行
# print(df_test.isnull().sum())
# 删除空行数据
df_test.dropna(axis=0, how='any', inplace=True)
df_test.reset_index(inplace=True, drop=True)
# print(df_test.isnull().sum())
# 查看重复数据
# print(df_test.duplicated().sum())
# print(df_test[df_test.duplicated()==True])
# 删除重复数据
index = df_test[df_test.duplicated() == True].index
df_test.drop(index, axis=0, inplace=True)
df_test.reset_index(inplace=True, drop=True)
# print(df_test.duplicated().sum())
# 重复数据是text一样但是label不一样的数据。
# print(df_test['text'].duplicated().sum())
# print(df_test[df_test['text'].duplicated() == True])
# 查看例子
# print(df_test[df_test['text'] == df_test.iloc[2046]['text']])
# print(df_test[df_test['text'] == df_test.iloc[3132]['text']])
# 去掉text一样但是label不一样的数据
index = df_test[df_test['text'].duplicated() == True].index
df_test.drop(index, axis=0, inplace=True)
df_test.reset_index(inplace=True, drop=True)
# 检查
# print(df_test['text'].duplicated().sum())  # 0
# print(df_test)
# 检查形状与编号
print("======test-clean======")
print(df_test.tail())
print(df_test.shape)
df_test.to_csv('./data/archive/test_clean.csv', index=None)

# todo: 清洗验证集
# 观察数据是否平衡
# print(df_val.label.value_counts())
# print(df_val.label.value_counts() / df_val.shape[0] * 100)
# plt.figure(figsize=(8, 4))
# sns.countplot(x='label', data=df_val)
# plt.show()
# 丢掉异常数据
df_val.drop(df_val[(df_val.label == '4') |
                   (df_val.label == '-') |
                   (df_val.label == '·')].index, inplace=True, axis=0)
df_val.reset_index(inplace=True, drop=True)
# print(df_val.label.value_counts())
# sns.countplot(x='label', data=df_val)
# plt.show()

# 观察是否有空行
# print(df_val.isnull().sum())
# 删除空行数据
df_val.dropna(axis=0, how='any', inplace=True)
df_val.reset_index(inplace=True, drop=True)
# print(df_val.isnull().sum())

# 查看重复数据
# print(df_val.duplicated().sum())
# print(df_val[df_val.duplicated()==True])
# 删除重复数据
index = df_val[df_val.duplicated() == True].index
df_val.drop(index, axis=0, inplace=True)
df_val.reset_index(inplace=True, drop=True)
# print(df_val.duplicated().sum())

# 重复数据是text一样但是label不一样的数据。
print(df_val['text'].duplicated().sum())
# print(df_val[df_val['text'].duplicated() == True])
# 查看例子
# print(df_val[df_val['text'] == df_val.iloc[1817]['text']])
# print(df_val[df_val['text'] == df_val.iloc[2029]['text']])
# 去掉text一样但是label不一样的数据
index = df_val[df_val['text'].duplicated() == True].index
df_val.drop(index, axis=0, inplace=True)
df_val.reset_index(inplace=True, drop=True)
# 检查
print(df_val['text'].duplicated().sum())  # 0
# print(df_val)
# 检查形状与编号
print("======val-clean======")
print(df_val.tail())
print(df_val.shape)
df_val.to_csv('./data/archive/val_clean.csv', index=None)

http://www.kler.cn/a/15336.html

相关文章:

  • 网页作业9
  • python+Django+MySQL+echarts+bootstrap制作的教学质量评价系统,包括学生、老师、管理员三种角色
  • 深度学习:计算卷积神经网络中输出特征图尺寸的关键公式
  • 物联网——UNIX时间戳、BKP备份寄存器、RTC时钟
  • Java连接MySQL(测试build path功能)
  • 分布式cap理论学习
  • C#简单向:textbox添加提示内容
  • APK文件结构
  • 成功上岸国防科大!
  • 【22-23 春学期】人工智能基础--AI作业6-误差反向传播
  • 大数据管理中心规划设计方案(ppt可编辑)
  • Centos编译安装python3.9.0
  • 【IT 常识】安装(Install)、部署(Deploy)、搭建(Setup)和配置(Configure、Set)的区别
  • Selenium:WebDriver简介及元素定位
  • 人人都能用,3s学会加密你的网址变成ooo
  • LVS+Keepalived群集
  • 第三节:支持向量机分类预测
  • 【论文精度(李沐老师)】Deep Residual Learning for Image Recognition
  • 【opencv】几何变换——仿射变换(4 计算仿射矩阵)
  • 【具体到每一步】uniapp打包app生成apk包
  • MySql数据库的初步安装与数据表结构数据管理
  • 进程与线程:同步和互斥
  • WiFi电子标签简介
  • 关于GeoServer发布服务时数据源设置的避坑指南
  • 如何本地mySQL迁移到另一台电脑上?
  • 【Linux基本指令和权限(1)】