当前位置: 首页 > article >正文

IJCAI-信也科技杯全球AI大赛-华东师范大学亚军队伍分享

作者:彭欣怡(找不到工作版) 华东师范大学; 马千里(搬砖版) 虾皮;
指导:闫怡搏(科研版) 华东师范大学
比赛链接:https://ai.ppdai.com/mirror/goToMirrorDetailSix?mirrorId=34

前言

这是我们首次参加语音领域的比赛,最初只是抱着试一试的心态,想借此机会打磨一下DL的基本功。凭借一点点运气,我们最终斩获了亚军。同时,我们也有幸向包括来自小米AI实验室的冠军团队在内的众多优秀团队学习,收获颇丰。
非常庆幸能在如此大规模的语音比赛中拿到名次,我们也在IJCAI workshop (2024) 中分享了方案(还见到了周志华大佬,非常激动)。接下来,我们将从语音领域初学者的视角,分享一些在比赛中的方案与心得,希望能为大家带来一些启发,也期望能激励更多人工智能爱好者勇敢参与比赛,探索更多的可能性。

赛题背景

本次比赛的赛题聚焦于语音deepfake技术的深入研究,这是一项能够生成逼真伪造语音的前沿技术,但其背后潜藏的隐患不容忽视,对个人隐私和信息安全构成了严峻挑战。此次赛题由信也科技发起,旨在通过模拟真实场景,推动对伪造语音检测技术的研究与发展。在本题中,选手需要建模预测每条语音是否伪造。

数据描述

比赛数据集由真语音和假语音构成。真语音被定义为由一次录音生成的原始语音。假语音则是利用一种或多种技术手段,对真语音进行干预后产生的语音。为了更贴近现实世界的复杂性,主办方精心设计了一个多维度的数据集,涵盖了以下几个核心要点:
1.多语音:数据集中包含了多种不同来源的语音样本(约10+语种),模拟了信也在现实世界中可能遇到的多样化语音环境。
2.多来源:伪造方法非常多样,涵盖了从物理伪造到模型伪造的广泛手段。使用了数不胜数的TTS、VC伪造模型生成数据。
3.混杂对抗:为了模拟真实世界中的攻击场景,数据集中包含了各种对抗性样本,要求模型能够识别出这些经过特殊处理的伪造语音。测试集语音往往由多条连续录音/伪造样本进行混杂构成,以提高鉴别难度。
混杂对抗

数据样例
链接:https://pan.baidu.com/s/1yYLVc1dnPJLhmm0yrDJXxw?pwd=uifi
提取码:uifi

评价指标

本次比赛采用了异常检测领域中最常用的F1指标作为评价标准。由于测试集中正负样本的比例为1:1,因此未对recall和precision进行加权。然而,这道题目存在一项特殊的挑战:选手需要通过调整推理阈值,才能获得更高的F1得分。在quantile0.5决策平面附近的样本极具挑战性,难以准确区分,进一步增加了比赛的复杂性和竞争的激烈程度。

数据理解

针对比赛数据集,我们进行了深入细致的研究。测试集的语言分布相当广泛,涵盖了来自欧洲、亚洲等国家和地区的多种语言。此外,为了检验模型的实际应用价值,主办方在数据集中掺杂了大量由大模型生成的语音数据(我们甚至在测试集中听到了生成的丁真语音)。更为复杂的是,我们发现测试集中的大部分语音样本是由多种语言组合而成,有的甚至还混入了动物叫声和纯噪声数据。
数据分析

根据对数据的深入分析,我们展开了系统的思考,并为四种关键情况量身设计了解决方案。
① 为了让模型更好地适应多语言环境,利用Common Voice数据集填补了空缺,特别是加入了西班牙语、法语等测试集中出现频率较高的语言,并且使用了在多语言数据上预训练过的模型,从而增强模型的泛化能力。
② 为了提升模型在新数据上的表现,引入少量的multi-tts开源数据,旨在将问题从zero-shot转化为few-shot,从而为模型提供更多学习机会。
③ 面对异常样本严重不足的挑战,参考one-class分类理论,通过加入大量正常样本,使模型能够深入理解什么才是“正常”的语音特征,从而提高异常检测的准确性。
④ 为了有效应对混杂对抗样本(最关键的一步),我们精心设计了一套复杂的数据增强策略,旨在让模型在面对高度混杂的样本时,仍能保持卓越的辨识能力。

解题技巧

整体框架

在深入探讨技术细节之前,想先与大家分享一下我们整体的策略。这一策略可以用一句古老的智慧来概括——“三个臭皮匠,赛过诸葛亮”。正是秉持这种集思广益的理念,我们构建了一个由两个主要部分组成的框架。

整体框架

首先,我们构建了whisper的支线,它由经过微调的模型组成。我们特别挑选了whisper作为主要模型,因为它的聪明才智——也就是它的能力——经过在三个不同难度级别上精心设计的数据集微调后,表现得尤为出色。我们还通过引入voting机制,进一步提升模型的表现力,具体细节将在后续内容中详细介绍。(这一支线是最大贡献者。)
接下来是 “笨皮匠”部分,由MMS和AST模型以及决策树组成。虽然这两个模型没有经过微调,但它们在理解大量正常样本方面具有独特的优势。我们直接利用它们提取数据的特征表示,并基于这些信息训练我们的决策树模型。这种方法不仅简化了工程复杂度,同时也充分发挥了这些模型的长处,有助于提升我们最终方案的性能上限。
最终,将三组经过微调的whisper模型与预训练的决策树模型相结合,通过voting机制集成它们的预测结果,得出了最终的检测结果,使得模型在识别语音deepfake方面更加稳健和可靠。

技术细节

接下来,我们将深入探讨我们的关键技术细节,这些技术是我们这些语音小白的主要上分点。值得一提的是,部分技术其实是从其他领域迁移而来的,但经过巧妙的调整和应用,在语音任务中展现出了意想不到的效果。

关键技术1:数据增强

我们首先引入了频域和时域的数据增强技术,这是在语音竞赛中广受欢迎的策略。通过应用频域遮蔽(Freq Mask)和时域遮蔽(Time Mask),我们有效地迫使模型去学习更加鲁棒的特征,而这一点在本次赛题中显得尤为关键。这些技术通过对输入数据进行遮蔽和扭曲操作,不仅丰富了训练数据的多样性,还拓展了模型的数据分布,使得模型在面对从未见过的新数据时,依然能够保持出色的表现。
mask

关键技术2:Mixup数据增强

接下来要介绍的是Mixup数据增强方式,它为我们的模型性能带来了近2%的显著提升,我们得分的关键点。Mixup通过对不同样本进行凸组合,创造出新的虚拟样本,不仅增强了模型在语言切换场景中的适应能力,还从理论上优化了对近邻风险的建模,使得模型的判别边界更加平滑,从而大大提升了抗噪能力。
Mixup是我们在此次比赛中最重要的得分点,引入Mixup让模型在判别边界上不再过于紧凑,因而在应对噪声和对抗样本时表现得更加稳健。此外,从实践角度来看,测试集中包含大量语言切换样本,Mixup的操作让模型更好地适应了这一复杂场景,从而显著提升了最终得分。

我们的数据增强代码如下,在Mixup的设计上花了许多心思。

def clip(wave, min_length = 0.7, max_length = 0.9, sr = 16000):
    if len(wave) <= 15 * sr / 8:
        return wave
    length = np.random.randint(int(len(wave) * min_length), int(len(wave) * max_length))
    start_idx = np.random.randint(0, len(wave) - length)
    return wave[start_idx:start_idx + length]

def add_back_noise(wave, sr = 16000):
    noise, _ = librosa.load(np.random.choice(noise_paths), sr=sr)
    noise = np.tile(noise, 2 + int(np.ceil(len(wave) / len(noise))))
    return wave + noise[:len(wave)]

def add_white_noise(wave):
    white_noise = np.random.randn(len(wave)) * np.random.uniform(0.001, 0.02)
    return wave + white_noise

def add_noise(wave):
    return add_back_noise(wave) if np.random.rand() < 0.8 else add_white_noise(wave)

def pitch_shift(wave, sr = 16000):
    steps = 0
    while -1e-3 <= steps <= 1e-3:
        steps = np.random.uniform(-4, 4)
    return librosa.effects.pitch_shift(wave, sr=sr, n_steps=steps)

def echo(wave, sr = 16000):
    delay = np.random.uniform(0.1, 0.5)
    attenuation = np.random.uniform(0.1, 0.5)
    delay_samples = int(delay * sr)
    echo_filter = np.zeros(delay_samples + 1)
    echo_filter[0] = 1
    echo_filter[-1] = attenuation
    return lfilter(echo_filter, 1, wave)

def mixup(wave, other):
    if other is None:
        return wave
    other = clip(other)
    return np.concatenate((wave, other)) if np.random.rand() < 0.5 else np.concatenate((other, wave))

def speedup(wave):
    rate = 1
    while (1 - 1e-3) < rate < (1 + 1e-3):
        rate = np.random.uniform(0.9, 1.1)
    return librosa.effects.time_stretch(wave, rate=rate)

def resample(wave, orig_sr = 16000, mid_sr = 8000):
    wave = librosa.resample(wave, orig_sr=orig_sr, target_sr=mid_sr)
    wave = librosa.resample(wave, orig_sr=mid_sr, target_sr=orig_sr)
    return wave

class Collation():
    def __init__(self, train=True):
        self.train = train
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(dirs)
    def collate_fn(self, batch):
        sr = 16000
        labels, inputs, mixups = [], [], []
        for item in batch:
            wave, path, label, mixcnt = item['audio']['array'], item['audio']['path'], item['label'], 0
            if self.train:
                while len(wave) >= 15 * sr / 2:
                    wave = clip(wave)
                # 单样本数据增强
                if np.random.rand() < 0.1:
                    wave = clip(wave)
                if np.random.rand() < 0.1:
                    wave = echo(wave)
                if np.random.rand() < 0.2:
                    wave = add_noise(wave)
                if np.random.rand() < 0.2:
                    wave = resample(wave)
                if np.random.rand() < 0.1:
                    wave = pitch_shift(wave) if np.random.rand() < 0.5 else speedup(wave)
                    label, mixcnt = 0, 1
                # mixup增强
                if len(wave) <= 15 * sr / 4 and mixcnt == 0 and np.random.rand() < 0.3:
                    mixcnt = np.random.randint(1,4)
                    for _ in range(mixcnt):
                        if len(wave) >= 15 * sr / 2:
                            break
                        if np.random.rand() < 0.9:
                            pos = get_data(inputs, labels, mixups, 1)
                            wave = mixup(wave, pos)
                        else:
                            neg = get_data(inputs, labels, mixups, 0)
                            wave = mixup(wave, neg)
                            label = 0
                while len(wave) >= 15 * sr / 2:
                    wave = clip(wave)
            labels.append(label)
            inputs.append(wave)
            mixups.append(mixcnt)
        inputs = self.feature_extractor(inputs, sampling_rate=sr, return_tensors='pt')
        inputs['labels'] = torch.tensor(labels)
        return inputs

关键技术3:注意力增强


注意力增强的策略是我们在比赛中灵光一闪的发现。经典的语言模型如whisper通常处理长达30秒的输入,而在分析后发现,训练数据的长度往往较短。基于这一观察,我们大胆地将模型的输入长度从30秒压缩至15秒。简单看,这一操作将计算复杂度从30的平方降至15的平方,极大地提升了计算效率。更为重要的是,当切除输入的后半段内容时,模型的注意力得以更加集中地作用于前半段,使得这一部分的参数得到了更充分的训练。这一调整最终为我们带来了关键的1个百分点精度提升。在实际操作中,我们采用了简(tou)单(le)直(ge)接(lan)的方式——将训练样本的长度截断至15秒,从而实现了这一“注意力增强”的策略。

关键技术4:模型集成

最后要介绍的是比赛中最常用的策略——模型集成。我们采用了三个不同构造的whisper模型,每个模型各有侧重:第一个模型专注于异常样本,旨在深入理解伪造细节;第二个模型则以均衡的输入样本为基础,追求稳健的中庸表现;第三个模型则专注于学习大量正常样本,以确保对正样本的准确识别。我们将这三个whisper模型与决策树模型进行集成,利用bagging和voting等集成方法的优势。通过加权平均有效地降低了基模型的方差,进而提升了整体的精度。

总结

作为语音领域的初学者,我们主要比赛中做了四方面的努力,得到最终成果

  1. 深入分析:我们对音频的基础统计量进行了细致的分析,包括音频长度、语言分布等。这些深入的分析帮助我们设计出针对性的解决方案,奠定了成功的基础。
  2. 模型应用:我们尝试并使用了三种多语言预训练模型,以whisper为核心,采用voting算法将不同严格程度的模型进行融合,达到了“三个臭皮匠顶一个诸葛亮”的效果。
  3. 技巧创新:针对测试集中复杂多样的数据结构,我们精心设计了复杂的数据增强策略。这不仅丰富了数据的多样性,还有效缓解了分布偏移问题,使模型表现更加稳健。
  4. 数据选择:我们选取了大量多语言伪造数据和真实数据,特别是最新的伪造样本,以确保模型能够更贴近测试集的分布,提升了模型的实际表现。
    所以说,即使是“小白”,只要敢于尝试,把自己的想法付诸实践,就有可能取得意想不到的收获。
    最后,要特别感谢信也科技提供的这个平台,让我们有机会在实践中打磨技术、验证思路。同时,衷心感谢比赛中的答疑大佬强哥,让我们感受到了比赛背后的温度与关怀。
    希望我们的分享能为大家带来一些见解和启发。

http://www.kler.cn/a/288183.html

相关文章:

  • python类和对象
  • 基于华为ENSP的OSPF状态机、工作过程、配置保姆级别详解(2)
  • 一个基于Spring Boot的智慧养老平台
  • Flink概念知识讲解之:Restart重启策略配置
  • 以下是一些常见的浏览器倒计时测试方法:
  • 从误删到重生:2024年数据恢复软件市场新趋势与精选工具
  • VirtualBox 中 Ubuntu 系统在桥连模式下网络适配器启动过慢或连接失败
  • 如何本地搭建Whisper语音识别模型
  • MySQL5.6迁移到DM8
  • FastAPI 进阶:使用 Pydantic 验证器增强 Query 参数验证
  • 数据结构-二叉树的遍历和线索二叉树
  • 《C++打造高效网络爬虫:突破数据壁垒》
  • CentOS全面停服,国产化提速,央国企信创即时通讯/协同门户如何选型?
  • 技术指南:5分钟零成本实现本地AI知识库搭建
  • 论文笔记: Boosting Object Detection with Zero-Shot Day-Night Domain Adaptation
  • 力扣229题详解:求众数 II 的多种解法与模拟面试问答
  • ELK日志服务收集SpringBoot日志案例
  • 【每日刷题】Day106
  • CentOS 安装 NVIDIA 相关软件包时出现依赖问题
  • 四层神经网络,反向传播计算过程;四层神经网络中:y的函数公式是什么
  • MySQL的事务认识
  • 传输层(TCP、UDP、RDT详解)
  • 视频智能分析打手机检测算法安防监控打手机检测算法应用场景、算法源码、算法模型介绍
  • 计算机网络(一) —— 网络基础入门