当前位置: 首页 > article >正文

【大模型实战篇】高质量数据过滤及一种BoostedBaggingFilter处理方法的介绍

1. 高质量数据过滤  

1.1 背景介绍

        数据质量对于大模型的训练至关重要,经常会听到一句话:数据决定模型的上限。模型的性能上限通常受到训练数据的质量限制。如果数据集不够好,模型可能无法学习到泛化的特征,导致其在新数据上的表现不佳。在大模型场景这个道理也通用,训练大模型依然需要喂入高质量的数据。否则造成大量计算资源的浪费,而且学习效果也差强人意。高质量的数据可以提高大型语言模型(LLMs)的最新技术水平,同时显著减少数据集的大小和训练计算量。重要的是需要较少训练的较小模型可以显著降低LLMs的环境成本【1,2】。

        训练大模型之前需要收集大量的文本数据。在收集了大量的文本数据后,为确保数据的质量和有效性,还需进行预处理,以清除低质量、冗余、无关甚至可能有害的数据【3】。目前业内出现了一些系统化的数据处理框架,比如开源库Data-Juicer 【4】来保证预训练数据的质量。       

        【3,4,6,7】中介绍了一系列常用的数据预处理方法和步骤。下图是典型的大语言模型预训练数据的预处理流程,可以参考加深对于预处理过程的了解。

        接下来,详细介绍其中的重要步骤,本文会重点关注质量过滤环节。

        收集到的原始文本数据通常包含大量低质量的内容,例如,从网页抓取的数据中可能包含机器生成的广告网页。为了提升模型学习的性能,有必要从语料库中剔除这些低质量数据。目前,研究人员主要使用两种数据清洗方法:(1)基于启发式规则的方法,和(2)基于分类器的方法。

1.2 基于启发式规则的方法

        通过设计针对性的规则来识别和剔除低质量文本数据是一种常见做法。不同类型的文本数据通常需要不同的清洗规则。例如,在处理问答数据时,可以通过过滤点赞数过少的帖子来剔除低质量内容;在处理代码语料时,可以过滤掉与代码无关的数据格式。以下是对一些常见数据集的清洗规则的总结【3】,仅供参考。

  • 基于语种的过滤:在训练以特定语言为主的大语言模型时,通常需要过滤掉其他语言的文本数据。例如,为了训练非英文主导的大语言模型(如中英双语模型),不仅要保留特定目标语言的数据,还需要保留高质量的英文数据。在训练过程中,使用语言识别器过滤非中英文的网页数据。

  • 基于简单统计指标的过滤:可以通过语料中的标点符号分布、符号与单词比率、句子长度等特征来衡量文本质量,并过滤低质量数据。例如:

    1. 过滤掉任何包含超过100个重复单词或句子的网页数据。
    2. 过滤掉符号与词元比大于0.1的网页数据。
    3. 过滤掉点赞数少于3的评论。
    4. 利用已有的语言模型计算文档的困惑度,作为过滤依据。
    5. 使用FastText分类器检测并删除有害或仇恨言论。
  • 基于关键词的过滤:预训练语料中可能存在大量重复文本模式,如HTML标签、超链接和模板等,还可能包含攻击性或冒犯性的信息。为解决这些问题,可以使用特定的关键词集对文本进行扫描和过滤。例如:

    1. 过滤掉任何少于25个UTF-8单词的页面。
    2. 过滤掉网页数据中的HTML标签。
    3. 过滤掉不含常用词汇(如the, be, to, of, and, that, have, with,的,是,到,的,和,那个,有,与。)的文档。
    4. 过滤掉所有数据中的电话号码、邮箱地址及IP地址等隐私信息。

1.3 基于分类器的方法

        除了启发式规则,还可以训练文本分类器来判断数据质量并进行清洗。具体来说,可以对部分代表性数据进行质量标注,以此训练一个精准的文本质量分类器。例如,可将高质量数据作为正样本,从网页中筛选出含有不良或低质量数据的样本作为负样本。使用这种分类器可以精确识别并过滤低质量数据,从而显著提升语料库的整体质量。当然这里面涉及到一系列工作,比如需要明确什么样的数据被认为是高质量的。这可能包括数据的完整性、准确性、一致性、相关性和时效性。根据你的数据类型和质量标准,选择合适的分类器模型。提取有助于分类器区分高质量和低质量数据的特征,这可能包括统计特征、文本特征、图像特征等。

        基于分类器的方法可能会无意中删除一些低资源但高质量的文本(如文言文数据)。为减少误筛,可以使用多个分类器进行联合过滤或召回。此外,还可根据不同的评估维度训练不同的分类器,采用类似集成的方式进行全面过滤。目前常用的分类器方法包括轻量级模型(如FastText等)、可微调的预训练语言模型(如BERT、LLaMA3等)以及闭源大语言模型API(如GPT-4o、Claude 3.5)。

1.4 效率与准确性的平衡

        在数据清洗过程中,过滤效率也是需要考虑的重要因素。例如,基于启发式规则的方法由于其规则设计相对简单,可以迅速过滤大量文档集;而基于分类器的方法尽管能提供更高的精确度,但需要消耗更多的计算资源。因此,可以结合使用这两种方法以达到平衡——先通过启发式规则进行初筛,快速排除不符合要求的文档,再使用分类器进一步精细过滤,以确保最终语料的高质量。

2. BoostedBaggingFilter处理方法

2.1 项目目标

        关于基于分类器的数据过滤方法,想起很多年前自己做的一个项目中也涉及到类似的任务,需要从用户标注的数据中筛选出可靠性强的样本数据,过滤掉噪声数据,虽然时间久远,但思路也许还是可以参考,这里做下简单的介绍。该方法其实是基于多层次集成分类器的思想,称为BoostedBaggingFilter。

        当时项目的目标是增加对用户标注数据的自学习,提高算法自适应能力,使算法能够应对新数据的分类任务,并且能够有效提升现有算法的分类准确性。基本思路如下:

 (1)增加与原有数据相似的样本数据,即使用样本数据训练产生的算法模型,去识别更多的能够支持样本数据分类准确度的用户标注数据。也就是同类型但可能表达形式差异的数据。


 (2)增加原有样本数据不存在的新的分类数据。比如原有数据中不存在"Uber#交通出行",现有分类算法会使得Uber被判别为其他。用户对该流水Uber人为改为Uber#交通出行,那么通过识别算法可以将该正确的数据识别出来,将Uber#交通出行这条新的信息补充到训练样本中,使得现有分类算法具备对新数据的识别能力。

2.2 算法思路描述

总体算法框架示意图如下:
        

说明:

解决的问题:
用户标注数据(反馈信息)的有效利用。

方法:
采用两层标注数据过滤算法,识别和利用正确的用户标注数据。

  1. 第一层过滤算法:旨在识别与现有样本数据存在交叉信息的标注数据,有助于提高分类的准确性。
  2. 第二层过滤算法:旨在识别全新的、未在样本数据中出现过的标注数据,帮助分类算法更好地应对全新的数据来源情况。
  • 第一层过滤算法(橙色椭圆):使用 Biterm 特征结合多分类逻辑回归(OvsR),以样本数据作为训练集,识别出与样本数据存在交叉信息的数据来源。
  • 第二层过滤算法(红色椭圆):采用 Boosted Bagging Filter 算法,针对第一层算法未能识别的标注数据。这些数据可能包括明显存在分类误差的样本以及在现有样本数据中没有对应信息的全新数据,从而有效识别出用户标注的新数据。

算法测试结果:

  1. 对于备注特征与样本特征存在一定交集的用户标注数据,通过交叉验证测试,过滤后的用户标注数据准确率平均达到了 99%。
  2. 对于备注特征与样本特征没有交集的全新用户标注数据,过滤后的类别准确率为 99.3%,错误数据的召回率为 97.8%。

2.3 标注数据过滤算法详述

目的:通过算法对用户标注数据进行清洗,去除噪声数据。

2.3.1 第一层过滤算法流程

目标: 使用样本集训练一个基本的分类器 f,这里选择逻辑回归(Logistic Regression)作为分类器,策略为“一对多”(One-vs-Rest)。

流程:

  1. 对于每一个新的用户标注数据实例 x_i​,使用分类器f(x_i)进行预测,获得预测概率的最大值 \text{max}(f(x_i)),对应的类别为 y'_i
  2. 获取用户实际标注的类别 y_i。
  3. 如果\text{max}(f(x_i)) > ty'_i = y_i,则将该用户标注数据添加到训练集中,继续用于分类器训练;否则,将该数据放入候选标注数据集,交由第二层过滤算法处理。

测试结果:过滤算法对测试集中数据的预测类别与实际类别的平均匹配度达到了 99%。匹配准确率是指类别正确标注的比例。

2.3.2 第二层过滤算法

目标:基于带噪声的标注数据集 D = \{(x_i, y_i) \mid 1 \leq i \leq n\},通过迭代和加权抽样的方法识别并过滤掉噪声数据。实验参数设置为 Bootstrap 样本数 m = 10,迭代次数 T = 10,并使用逻辑回归(Logistic Regression)作为基本分类器。

流程:

  1. 初始化错误分类计数 TNC

    \text{TNC}(i) 表示第 i 个实例的错误分类次数,初始状态为 0:\text{TNC}(i) = 0, \, i = 1, \ldots, n
  2. 初始化实例选择权重 W

    W(i) 表示第 i 个实例的选择权重,初始状态为均匀分布:W(i) = \frac{1}{n}, \, i = 1, \ldots, n
  3. 迭代过程(从 t = 1 到 T):

    • 对每次迭代:
      • 对每次 Bootstrap 采样(从 j = 1 到 m):
        • 从数据集中根据权重 W 抽样生成样本集 S_j = \text{ResampleWithWeight}(D, W)
        • 用样本集 S_j 训练分类器 h_j = \text{buildClassifier}(S_j)
      • 重置局部错误计数 NC\text{NC}(i) = 0
      • 对每个实例 x_i​:
        • 检查所有分类器 hjh_jhj​ 的预测结果:
          • 如果 h_j(x_i) \neq y_i,则增加相应的错误计数:\text{NC}(i)++
        • 更新全局错误计数:\text{TNC}(i) = \text{TNC}(i) + \text{NC}(i)
      • 更新实例权重:
        • W(i) \leftarrow W(i) \times e^{-\text{NC}(i)}
        • 归一化权重:W(i) \leftarrow \frac{W(i)}{\sum W(i)}​。
  4. 返回按 TNC 升序排列的实例。

  5. 根据 TNC 结果:

    • TNC 为 0 的实例标记为正确分类。
    • TNC 大于等于 10 的实例标记为错误分类。

算法测试方式:

  • 从样本集中随机抽取 5000 个实例,设置噪声比例(Noise Rate)来模拟带噪数据。例如,噪声比例为 0.2 表示 5000 个样本中有 20% 的数据被标记为错误类别。

  • 计算准确率(Precision)和召回率(Recall):这里的准确率表示TNC为0的实例中被正确标注的实例个数与TNC为0的实例总个数的比例。而召回率是指TNC 非 0 的实例中包含的噪声实例个数与所有噪声实例总个数的比例。

算法测试结果:

噪音率准确率(Precision)召回率(Recall)
0.10.9980.986
0.20.9970.989
0.30.9920.985
0.60.9600.982
0.80.8750.985

代码示例(java),仅供参考:

package com.yuanquan.userlabelcleaning;

import com.yuanquan.liblinear.FeatureNode;
import com.yuanquan.liblinear.Model;
import com.yuanquan.utils.Utils;
import com.yuanquan.utils.SysCfgUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.*;
import java.util.*;

/**
 * <p>
 * Title:BoostedBaggingFilter
 * </p>
 * <p>
 * Description:
 * 第二层过滤算法BBF, 提取出新的用户标注信息
 * 通过bootstrap抽样,训练多个子分类器, 并通过调整每一个实例的权重使得最正确的数据可以被选择.
 * 由于用户标注数据经过第一层算法过滤后, 会更新到第二层算法数据池, 因此每次使用都需要重新训练,针对不同的现有数据池.
 * </p>
 *
 */
public class BoostedBaggingFilter {

    private static final Logger LOGGER = LoggerFactory.getLogger(BoostedBaggingFilter.class);

    private static int T;  //BBF模型迭代次数
    private static int NUM_BOOTSTRAP_SAMPLE; //bootstrap抽样次数
    private static int THRESHOLD_TRUE; //BBF算法中错误数小于该值,则被认为用户分类正确的可能性极大
    private static int THRESHOLD_FALSE; //BBF算法中错误数大于该值,则被认为用户分类错误的可能性极大

    private static String PRE_PATH;
    private static String BASIC_INFO;
    private HashMap<String, String> BasicInfo;

    public BoostedBaggingFilter(){
        try {
            PRE_PATH = SysCfgUtils.getInstance().getValue("PRE_PATH");
            BASIC_INFO = PRE_PATH + "basicInfo.txt";
            BasicInfo = loadBasicInfo();
            T = Integer.parseInt(SysCfgUtils.getInstance().getValue("T"));
            NUM_BOOTSTRAP_SAMPLE = Integer.parseInt(SysCfgUtils.getInstance().getValue("NUM_BOOTSTRAP_SAMPLE"));
            THRESHOLD_TRUE = Integer.parseInt(SysCfgUtils.getInstance().getValue("THRESHOLD_TRUE"));
            THRESHOLD_FALSE = Integer.parseInt(SysCfgUtils.getInstance().getValue("THRESHOLD_FALSE"));

        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public HashMap<String, String> loadBasicInfo() throws IOException {
        BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(BASIC_INFO)));
        try {
            return loadBasicInfo(inputReader);
        } finally {
            inputReader.close();
        }
    }

    //基于集成bagging过滤算法, 筛选出能够可信度大的正确及错误数据
    public  ArrayList<ArrayList<Integer>> ClassifyByBoostedBaggingFilter(List<CategoryAndFeatures> noiseInstances) throws IOException {
        int lenOfNoiseInstances = noiseInstances.size();
        int[] TNC = new int[lenOfNoiseInstances];
        double[] Weights = new double[lenOfNoiseInstances];
        int[] instances = new int[lenOfNoiseInstances];
        for(int i=0;i<lenOfNoiseInstances;i++){
            TNC[i] = 0;
            instances[i] = i;
            Weights[i] = 1.0/lenOfNoiseInstances;
        }

        for(int t=0;t<T;t++){
            Model[] LRmodels = new Model[NUM_BOOTSTRAP_SAMPLE];
            LOGGER.info("T:"+t);
            for(int j=0;j<NUM_BOOTSTRAP_SAMPLE;j++){
                LOGGER.info("NUM_BOOTSTRAP_SAMPLE:"+t+"-"+j);
                int[] newInstances = resampleWithWeights(Weights, instances);
                String modelFilePath = "BBF_"+j+".txt";
                TrainingDataFormat TDF = constructTrainingFeature(noiseInstances, newInstances);
                Model model = BaseOvsRLogisticRegression.trainingByLR(TDF.getFeatures(),
                        Integer.parseInt(BasicInfo.get("NumberOfFeature")), TDF.getGROUPS_ARRAY(), modelFilePath, false);
                LRmodels[j] = model;
            }
            int[] NC = new int[lenOfNoiseInstances];
            //noise data classification
            for(int i=0;i<lenOfNoiseInstances;i++){
                CategoryAndFeatures currTestInstance = noiseInstances.get(i);
                for(int j=0;j<NUM_BOOTSTRAP_SAMPLE;j++){
                    Model model = LRmodels[j];
                    double[] prediction = BaseOvsRLogisticRegression.predict(model, currTestInstance.getFeatures());
                    double preClass = prediction[0];
                    if(preClass != currTestInstance.getCategory()){
                        NC[i] = NC[i] + 1;
                    }
                }
                TNC[i] = TNC[i] + NC[i];
            }

            for(int i=0;i<lenOfNoiseInstances;i++){
                Weights[i] = Weights[i] * Math.exp(-NC[i]);
            }
            Weights = normalize(Weights, sum(Weights));
        }

        // 按照TNC值升序返回noise Instances
        Map<Integer, Integer> TNC_instance = new HashMap<Integer, Integer>();
        for(int i=0;i<lenOfNoiseInstances;i++){
            TNC_instance.put(i, TNC[i]);
        }

        TNC_instance = Utils.sortByValue(TNC_instance);

        ArrayList<Integer> TrueInstance = new ArrayList<Integer>();
        ArrayList<Integer> FalseInstance = new ArrayList<Integer>();

        for(Map.Entry<Integer, Integer> entry : TNC_instance.entrySet()) {
            LOGGER.info(entry.getKey()+"\t"+TNC_instance.get(entry.getKey()));

            if(entry.getValue() < THRESHOLD_TRUE){
                TrueInstance.add(entry.getKey());
            }

            if(entry.getValue() > THRESHOLD_FALSE){
                FalseInstance.add(entry.getKey());
            }

        }

        //返回最正确的,和最错误的instance index
        ArrayList<ArrayList<Integer>> TrueFalseInstanceIndex = new ArrayList<ArrayList<Integer>>();
        TrueFalseInstanceIndex.add(TrueInstance);
        TrueFalseInstanceIndex.add(FalseInstance);

        return TrueFalseInstanceIndex;


    }

    public HashMap<String, String> loadBasicInfo(Reader inputReader) throws IOException {
        BufferedReader reader;
        if (inputReader instanceof BufferedReader) {
            reader = (BufferedReader)inputReader;
        } else {
            reader = new BufferedReader(inputReader);
        }

        HashMap<String, String> infos = new HashMap<String, String>();

        String line;
        while ((line = reader.readLine()) != null) {
            String[] split = line.split(SysCfgUtils.getInstance().getValue("SEPARATOR_sys"));
            infos.put(split[0],split[1]);
        }
        reader.close();

        return infos;

    }


    public TrainingDataFormat constructTrainingFeature(List<CategoryAndFeatures> noiseInstances, int[] newInstances) {
        int lenOfInstances = newInstances.length;
        FeatureNode[][] trainingFeature = new FeatureNode[lenOfInstances][];
        double[] targetArray = new double[lenOfInstances];
        for (int i = 0; i < lenOfInstances; i++) {
            int index = newInstances[i];
            trainingFeature[i] = noiseInstances.get(index).getFeatures();
            targetArray[i] = noiseInstances.get(index).getCategory();
        }
        TrainingDataFormat TDF = new TrainingDataFormat();
        TDF.setFeatures(trainingFeature);
        TDF.setGROUPS_ARRAY(targetArray);
        return TDF;
    }

    public int[] resampleWithWeights(double[] weights, int[] instances) {

        if (weights.length != instances.length) {
            throw new IllegalArgumentException("weights.length != numOfInstances.");
        }

        int numInstances = instances.length;

        int[] newData = new int[numInstances];
        if (numInstances == 0) {
            return newData;
        }
        Random random = new Random();
        double[] probabilities = new double[numInstances];
        double sumProbs = 0, sumOfWeights = sum(weights);
        for (int i = 0; i < numInstances; i++) {
            sumProbs += random.nextDouble();
            probabilities[i] = sumProbs;
        }
        normalize(probabilities, sumProbs / sumOfWeights);
        probabilities[numInstances - 1] = sumOfWeights;
        int k = 0; int l = 0;
        sumProbs = 0;
        int count = -1;
        while ((k < numInstances && (l < numInstances))) {
            if (weights[l] < 0) {
                throw new IllegalArgumentException("Weights have to be positive.");
            }
            sumProbs += weights[l];
            while ((k < numInstances) && (probabilities[k] <= sumProbs)) {
                count++;
                newData[count] = instances[l];
                k++;
            }
            l++;
        }
        return newData;
    }

    public double sum(double[] weights) {
        double summation = 0;
        for (double weight : weights) {
            summation = summation + weight;
        }
        return summation;
    }

    public double[] normalize(double[] doubles, double summation) {
        for (int i = 0, len = doubles.length; i < len; i++) {
            doubles[i] = doubles[i] * 1.0 / summation;
        }
        return doubles;
    }
}

3. 参考材料

【1】Textbooks Are All You Need

【2】LLM and Dataset Quality

【3】Data Collection and Preprocessing for Large Language Models

【4】RUC AI box-《大语言模型》

【5】Data-Juicer: A One-Stop Data Processing System for Large Language Models

【6】Towards Trustable Language Models: Investigating Information Quality of Large Language Models

【7】A Survey of Large Language Models


http://www.kler.cn/a/308364.html

相关文章:

  • UDP协议和TCP协议之间有什么具体区别?
  • C获取程序名称的方法
  • 在 Service Worker 中caches.put() 和 caches.add()/caches.addAll() 方法他们之间的区别
  • STM32单片机WIFI语音识别智能衣柜除湿消毒照明
  • 设计模式:工厂方法模式和策略模式
  • 【金融风控】特征评估与筛选详解
  • JDK的选择安装和下载
  • 软考 -- 软件设计师 -- 二轮复习(3) -- 数据结构(持续更新)
  • 24.Redis实现全局唯一ID
  • 电脑信息安全:挑战与应对策略
  • PAT甲级-1055 The World‘s Richest
  • 【C++学习入门】6.左值右值
  • 软件测试方法及其应用概述
  • JZ2440开发板——S3C2440的时钟体系
  • RFID射频模块(MFRC522 STM32)
  • Linux 之父 Linus Torvalds:低调的神话创造者
  • 网络协议全景:Linux环境下的TCP/IP、UDP
  • chattr:修改文件的特殊属性
  • Linux抢占调度
  • kali——tshark的使用
  • 2024 批量下载知乎回答/文章/想法/专栏/视频/收藏夹,导出 excel 和 pdf
  • Django_Vue3_ElementUI_Release_004_使用nginx部署
  • idea插件开发的第四天-完善JSON工具
  • 算法:76.最小覆盖子串
  • 文章-深入GPU硬件架构及运行机制 学习后记
  • 撤回仓库的提交