当前位置: 首页 > article >正文

AI大模型语音识别转文字

提取音频

本项目作用在于将常见的会议录音文件、各种语种音频文件进行转录成相应的文字,也可从特定视频中提取对应音频进行转录程文字保存在本地。最原始的从所给网址下载对应视频和音频进行处理。下载ffmpeg(https://www.gyan.dev/ffmpeg/builds/packages/ffmpeg-7.1-full_build.7z)并配置好环境变量(path.append(.\bin)),eg:

import os
import yt_dlp # 支持腾讯视频,小红书,tiktok,bilibili,youtube等
from url_cookies import cookies
from moviepy.editor import AudioFileClip

def get_available_formats(url):
    ydl_opts = {'quiet': True,'extract_flat': True} # 静默模式
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info_dict = ydl.extract_info(url, download=False) # info_dict['title']
        formats = info_dict.get('formats', [])
        if not formats:
            print("No available format found!")
            return
        results = pd.DataFrame(formats).iloc[:,:].groupby('ext').apply(lambda x: x.tail(1)).reset_index(drop=True)
        return results,info_dict['title']

def get_downloaded_filename(d):
    if d['status'] == 'processing': # callback function to capture the downloaded file name information
        print(f"download filename:{d['filename']}")

def download_video(url):
    get_best_id,filename = get_available_formats(url)
    id = (str(get_best_id[get_best_id['ext']=='mp4']['format_id'].values[0]) if len(get_best_id[get_best_id['ext']=='mp4'])>0 else '')+ \
            '+'+(str(get_best_id[get_best_id['ext']=='m4a']['format_id'].values[0]) if len(get_best_id[get_best_id['ext']=='m4a'])>0 else '')
    id = id.split('+')[0] if id.endswith('+') else id.split('+')[1] if id.startswith('+') else id
    ydl_opts = {'quiet': True,
                'format': id,
                'outtmpl': '%(title)s.%(ext)s',
                'concurrent-fragments': 10,
                'cookiefile': None, # 禁用默认的cookie文件
                'cookies':cookies['cookies_bili'],
                'progress_hooks': [get_downloaded_filename], # callback function
                }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([url])
        return filename

def get_audio(url):
    file_name = download_video(url)
    current_directory = os.getcwd()
    path = current_directory+'\\'+file_name+'.mp4'
    out_path = current_directory+'\\'+file_name+'.mp3'
    out_path_wav = current_directory+'\\'+file_name+'.wav'
    my_audio_clip = AudioFileClip(path.replace('\\','/'))
    my_audio_clip.write_audiofile(out_path.replace('\\','/'))
    my_audio_clip.write_audiofile(out_path_wav.replace('\\','/'))


if __name__ == "__main__":
    video_url = 'https://www.youtube.com/watch?v=eA0lHNZ1KCA'
    get_audio(video_url)

音频切割

 对于比较长的音频,需要运行很长时间,特别大的音频可能无法导入,可以利用pydub将音频分块处理:

from pydub import AudioSegment
song = AudioSegment.from_mp3("ROSÉ - toxic till the end (OFFICIAL MUSIC VIDEO).mp3")
t_minutes = 2 * 60 * 1000 # PyDub handles time in milliseconds
first_5_minutes = song[:t_minutes] # 前2分钟输出成单独的mp3文件
first_5_minutes.export("ROSÉ - toxic till the end (OFFICIAL MUSIC VIDEO)_2min.mp3", format="mp3")

模型准备

Whisper 是一种自动语音识别 (ASR) 系统,根据从 Web 收集的 680,000 小时的多语言和多任务监督数据进行训练。使用如此庞大且多样化的数据集可以提高对口音、背景噪声和技术语言的鲁棒性。此外,它还支持多种语言的转录,以及从这些语言翻译成英语。 whisper的好处是开源免费、支持多语种(包括中文),有不同模型可供选择,最终的效果比市面上很多音频转文字的效果好是一个典型的transformer Encoder-Decoder结构,针对语音和文本分别进行多任务(Multitask)处理。其原理介绍在(extension://ngbkcglbmlglgldjfcnhaijeecaccgfi/https://arxiv.org/pdf/2212.04356)这篇论文中。

架构:

Whisper 架构是一种简单的端到端方法,实现为编码器-解码器 Transformer。输入音频被分割成 30 秒的块,转换为 log-Mel 频谱图,然后传递到编码器中。解码器经过训练以预测相应的文本标题,并与特殊标记混合在一起,这些标记指示单个模型执行语言识别、短语级时间戳、多语言语音转录和到英语语音翻译等任务。

Transformer 序列到序列模型针对各种语音处理任务进行训练,包括多语言语音识别、语音翻译、口语识别和语音活动检测。这些任务共同表示为解码器要预测的令牌序列,从而允许单个模型替换传统语音处理管道的许多阶段。多任务训练格式使用一组特殊标记,用作任务说明符或分类目标。 whisper目前有5个模型,随着参数的变多,转文字的理解性和准确性会提高,但相应速度会变慢:

Whisper 的音频数据集中约有三分之一是非英语的,它交替被赋予了以原始语言转录或翻译成英语的任务。这种方法在学习语音到文本翻译方面特别有效,并且优于 CoVoST2 到英语翻译零样本的监督 SOTA。

 安装whisper:

cmd>>pip install -U openai-whisper  # 最新版本的 Whisper
cmd>>pip install git+https://github.com/openai/whisper.git  # 最新依赖项

如果安装成功,在cmd中输入whisper可以得到以下输出: 

安装chocolatey(Chocolatey Software | Installing Chocolatey),安装chocolatey是为了后面方便在Windows中安装ffmpeg,输入以下命令:

powershell>>Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1'))
powershell>>choco install ffmpeg
cmd>> pip install setuptools-rust # 可选项,如果没有构建轮子

可以在命令行中直接使用whisper,第一次调用模型时需要下载。模型导入后会逐步生成文字。例如导入audio音频,利用medium模型进行转录,模型会自动检测出语言是英语,然后根据断句和语气,生成每句话的文字。中途可以随时中断。

whisper audio.mp3 --model medium

whisper其他参数,可以参考帮助:所有可用语言的列表可以参阅 tokenizer.py。

whisper --help

也可以在python中调用:

import whisper
model = whisper.load_model("base")
result = model.transcribe("audio.mp3",,fp16="False") # 输出音频转录文字
print(result["text"]) # 将结果储存在变量中

音频识别

为了追求转录速度和精确度,将基于c++的Whisper 在linux中实现转录目标,进入python3.8虚拟环境:

 安装依赖:

sudo apt update
sudo apt install cmake g++ wget ffmpeg nilfs-tools
git clone https://github.com/ggerganov/whisper.cpp.git # open VPN
cd whisper.cpp


下载其中任意一个whisper模型转换成 ggml格式,如:

sh ./models/download-ggml-model.sh base

具体可选模型可参考:ggerganov/whisper.cpp at main,根据需要下载对应模型,模型越大所推理的时间越长,但是精度越好,鲁棒性越大,反之同理。

如果下载较慢可选择手动下载上传到linux目录,创建实例:

# build the project
cmake -B build
cmake --build build --config Release

完成后将会出现以下文件夹:

发现虚拟机无法联网导致无法下载文件,修改配置(本机基于WSL2配置的ubuntu,未使用VM或者clash for windows等代理软件):

sudo lshw -c Network  # 检查网络状况:disabled
ls /etc/NetworkManager/conf.d/
sudo touch /etc/NetworkManager/conf.d/10-globally-managed-devices.conf
sudo systemctl restart NetworkManager

问题解决后ping一下baidu.com:

问题已解决,若运行文件仍然出现问题,可能是未刷新,退出虚拟机重新进入即可。

转录文字

基于linux下python脚本运行,输入为文件mp3路径或者对应网址附加语种,在Linux下下暂不支持代理VPN连接外站视频:

import os
import subprocess
import pandas as pd
import yt_dlp # 支持腾讯视频,小红书,tiktok,bilibili,youtube等
from url_cookies import cookies
from moviepy.editor import AudioFileClip

class audio_to_content:
    def __init__(self, video_url, ln = 'auto', model = 'ggml-base.bin'):
        if video_url.endswith('zh'):
            self.url = video_url[:-3]
            self.l = 'zh'
        elif video_url.endswith('en'):
            self.url = video_url[:-3]
            self.l = 'en'
        else:
            self.url = video_url
            self.l = ln
        self.model = model
        self.curr_path = os.getcwd().replace('\\','/')

    def get_available_formats(self, url):
        ydl_opts = {'quiet': True,'extract_flat': True} # 静默模式
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(url, download=False) # info_dict['title']
            formats = info_dict.get('formats', [])
            if not formats:
                print("No available format found!")
                return
            results = pd.DataFrame(formats).iloc[:,:].groupby('ext').apply(lambda x: x.tail(1)).reset_index(drop=True)
            return results,info_dict['title']

    def get_downloaded_filename(self, d):
        if d['status'] == 'processing': # callback function to capture the downloaded file name information
            print(f"download filename:{d['filename']}")

    def download_video(self, url):
        get_best_id,filename = self.get_available_formats(url)
        id = (str(get_best_id[get_best_id['ext']=='mp4']['format_id'].values[0]) if len(get_best_id[get_best_id['ext']=='mp4'])>0 else '')+ \
                '+'+(str(get_best_id[get_best_id['ext']=='m4a']['format_id'].values[0]) if len(get_best_id[get_best_id['ext']=='m4a'])>0 else '')
        id = id.split('+')[0] if id.endswith('+') else id.split('+')[1] if id.startswith('+') else id
        ydl_opts = {'quiet': True,
                    'format': id,
                    'outtmpl': '%(title)s.%(ext)s',
                    'concurrent-fragments': 10,
                    'cookiefile': None, # 禁用默认的cookie文件
                    'cookies':cookies['cookies_bili'],
                    'progress_hooks': [self.get_downloaded_filename]} # callback function
                    
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url])
            return filename

    def get_audio(self, url):
        file_name = self.download_video(url)
        path = self.curr_path + '/'+file_name+'.mp4'
        my_audio_clip = AudioFileClip(path)
        my_audio_clip.write_audiofile(path.replace('mp4','mp3'))
        return path

    def convert_mp3_to_wav(self, input_file, output_file):
        if os.path.exists(output_file):
            os.remove(output_file)
        command = ['ffmpeg','-i', input_file,'-ar', '16000','-ac', '1','-c:a', 'pcm_s16le',output_file]
        subprocess.run(command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    def run(self):
        # cli_path = self.curr_path + '/build/bin/whisper-cli'
        if self.url.startswith('https'):
            mp4_path  = self.get_audio(self.url)
            self.convert_mp3_to_wav(mp4_path.replace('mp4','mp3'), mp4_path.replace('mp4','wav'))
            wav_path = mp4_path.replace('mp4','wav')
        elif self.url.endswith('mp3'):
            mp3_path = self.curr_path + '/' + self.url
            self.convert_mp3_to_wav(mp3_path, mp3_path.replace('mp3','wav'))
            wav_path = mp3_path.replace('mp3','wav')
            txt_path = mp3_path.replace('mp3','txt')
        elif self.url.endswith('aac'):
            mp3_path = self.curr_path + '/' + self.url
            self.convert_mp3_to_wav(mp3_path, mp3_path.replace('aac','wav'))
            wav_path = mp3_path.replace('aac','wav')
            txt_path = mp3_path.replace('aac','txt')
        if self.l == 'zh':
            model_path = self.curr_path + '/models/' + self.model
        elif self.l == 'en':
            model_path = self.curr_path + '/models/ggml-base.en.bin'
        try:
            command = ['whisper-cli','-f', wav_path,'-l', self.l,'-m', model_path]
        except Exception as e:
            print('enter language!')
        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        
        if result:
            if self.url.startswith('https'):
                txt_path = mp4_path.replace('mp4','txt')
                os.remove(mp4_path)
                os.remove(mp4_path.replace('mp4','mp3'))
                os.remove(mp4_path.replace('mp4','wav'))
            else:
                os.remove(mp3_path)
                os.remove(wav_path)
            with open(txt_path, 'w') as file:
                file.write(result.stdout)
            print('complete!')


if __name__ == "__main__":
    content = input("enter url or audio_file_path and language:")
    audio_convert = audio_to_content(content)
    audio_convert.run()

为文件时:

为网址时:

实时转录

依赖于stream实时监听麦克风阵列,可以加入到脚本中,当然可以直接在命令行中执行,这种情况仅支持仅推理模式,流工具每半秒对音频进行一次采样,并连续运行转录:

./build/bin/stream -m ./models/ggml-base.en.bin -t 8 --step 500 --length 5000

在cpu上运行转录一段半小时的音频文件大概需要6分钟左右,但若存在GPU加速该时间能缩短至一半甚至更快。

模型微调

以大模型large_v3为例,首先安装依赖:

pip install transformers datasets huggingface-hub accelerate evaluate tensorboard

在安装 pip 依赖库后登入 HuggingFace:

huggingface-cli login

登录需要 token. 在上面这条指令中,huggingface-cli 会提供一个网址,获取 API token。如果需要将模型上传到 HuggingFace Hub需要一个拥有 write 权限的 token。

准备数据集

Whisper 是一个监督学习的模型。因此在数据集中,需要提供音频文件以及音频对应的文字。 最简单的数据集准备方法是使用 HuggingFace AudioFolder.建立文件夹,并将文件如下摆放:

folder/train/metadata.jsonl
folder/train/first.mp3
folder/train/second.mp3
folder/train/third.mp3 # 不是所有文件都支持。例如m4a文件就无法使用

metadata.jsonl 是一个 JSON Lines 格式的文件,其格式如下:

{"file_name": "first.mp3", "transcription": "First Audio Transcription"}
{"file_name": "second.mp3", "transcription": "Second Audio Transcription"}
{"file_name": "third.mp3", "transcription": "Third Audio Transcription"}

JSONL 的意思是 JSON Lines: 每一行是一个 JSON 对象,而整个文件可以被看作是一个数组的 JSON 对象。 在每行中,file_name 的名字必须是 file_name。它提供的是音频文件的相对路径,相对这一个 metadata.jsonl文件。 其他键值(比如transcription)可以任意起名。最后这个文件将会被转成 Arrow 格式的表格(类似 pandas 的 Dataset),而每一个键值对应的表格中的一列。 可以加入任意多的其他键值,可以指明说话人、语言、来源等信息。

数据集准备完成后,使用如下命令将数据集上传至 HuggingFace Hub:

from datasets import load_dataset
audio_dataset = load_dataset("audiofolder", data_dir=".")
audio_dataset.push_to_hub("YOUR_HF_NAME/HF_DATASET_REPO") # Replace this with your Huggingface Repository

这将会读取音频文件,将整个数据集转换为 Parquet 格式,自动生成包含数据集信息的 README.md 文件,并上传到 HuggingFace Hub。

微调基于 HuggingFace 版本的 OpenAI Whisper 模型。关于微调的详细过程可以在这里找到:

# 训练过程文件夹: ./whisper-large-v3-ft-train
# 模型输出文件夹: ./whisper-large-v3-finetuned

ref(APA): metricv.MetricVoid's Blog.https://me.sakana.moe. Retrieved 2024/12/29.
# NOTE: 注意:在此处填入finetune 的基座模型。
base_model = "openai/whisper-large-v3"

# NOTE: 此处不要修改。除非你想训练 translate 模式,且你的数据集包含原音频的英文翻译。
task = "transcribe"

from datasets import load_dataset, DatasetDict

# ========== Load Dataset ==========
tl_dataset = DatasetDict()
tl_dataset["train"] = load_dataset("YOUR_HF_NAME/HF_DATASET_REPO", split="train")
# NOTE: 如果你的数据集包含 test 分区,将下一行取消注释
# tl_dataset["test"] = load_dataset("metricv/tl-whisper", "hi", split="test")

# ========== Load Whisper Preprocessor ==========

from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor

feature_extractor = WhisperFeatureExtractor.from_pretrained(base_model)
tokenizer = WhisperTokenizer.from_pretrained(base_model, task=task)
processor = WhisperProcessor.from_pretrained(base_model, task=task)

# ========== Process Dataset ==========

from datasets import Audio

tl_dataset = tl_dataset.cast_column("audio", Audio(sampling_rate=16000))

def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]
    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # encode target text to label ids
    # NOTE: 此处的键值 "transcription" 指的就是你在创建数据集的过程中,包含音频文件对应文字的键值。如果你是用的键名不是 transcription,在此处修改。
    batch["labels"] = tokenizer(batch["transcription"]).input_ids
    return batch

tl_dataset = tl_dataset.map(prepare_dataset, remove_columns=tl_dataset.column_names["train"], num_proc=8)

# ========== Load Whisper Model ==========

from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained(base_model)
model.generation_config.task = task
model.generation_config.forced_decoder_ids = None

# ========== Fine-tune model ==========

import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch
    
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

import evaluate

metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-large-v3-ft-train",  # change to a repo name of your choice
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    num_train_epochs=2.0,
    # warmup_steps=500,
    # max_steps=4000,
    gradient_checkpointing=True,
    fp16=True,
    do_eval=False,
    # eval_strategy="steps",    # NOTE: 如果你的数据集包含 test 分区,可取消注释此行
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    # save_steps=1000,
    # eval_steps=1000,
    logging_steps=5,
    report_to=["tensorboard"],
    load_best_model_at_end=False,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=tl_dataset["train"],
    # eval_dataset=tl_dataset["test"], # NOTE: 如果你的数据集包含 test 分区,可取消注释此行
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

processor.save_pretrained(training_args.output_dir)

trainer.train()

# ========== Save model ==========

trainer.save_model(output_dir="./whisper-large-v3-finetuned")
torch.save(model.state_dict(), f"{training_args.output_dir}/pytorch_model.bin")

# ========== Push model to HF hub ==========
# 如果你不想上传模型,注释掉以下行。
trainer.push_to_hub("YOUR_HF_NAME/HF_MODEL_REPO") # 修改为你的 HuggingFace 仓库名

在部署微调后的模型前还需要设置一些东西:

基座模型的 tokenizer.json 可能没有被复制过来。你需要手动复制一下。

在 HuggingFace Hub上找到基座模型 (such as openai/whisper-large-v3), 下载它的 tokenizer.json,并放到“模型输出文件夹”下。 如果你使用了 push_to_hub() 来上传模型,但是上传后的模型没有 tokenizer.json,你可以使用 HuggingFace 的网页界面手动上传。

如果“模型输出文件夹”中没有 tokenizer_config.json,将“训练过程文件夹”的对应文件复制过来。

如果“模型输出文件夹”中没有 preprocessor_config.json,将“训练过程文件夹”的对应文件复制过来。

微调后的模型可以用多种方式部署:

1、使用 HuggingFace 运行库

微调后的模型和原版 HuggingFace 模型一样,可以使用 from_pretrained() 来部署。此方法略快于原版 openai-whisper 包,但会占用更多 RAM。

from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("YOUR_HF_NAME/HF_MODEL_REPO") # 如果你没有上传模型,使用 from_pretrained("模型输出文件夹") 加载本地模型。

2、使用原版 PyPI 的 openai-whisper 包

微调后的模型可被转换为兼容原版 openai-whisper 包的格式。 在微调的结尾,把 PyTorch 格式的模型保存在了 "训练过程文件夹"/pytorch_model.bin. 但是,这个模型中的层和原版模型的命名方式不一样。重命名即可解决该问题。 使用以下代码转换:

# NOTE: Change this to the base model you fine-tuned from.
BASE_MODEL = "large-v3"

#!/bin/env python3
import whisper
import re
import torch

def hf_to_whisper_states(text):
    text = re.sub('.layers.', '.blocks.', text)
    text = re.sub('.self_attn.', '.attn.', text)
    text = re.sub('.q_proj.', '.query.', text)
    text = re.sub('.k_proj.', '.key.', text)
    text = re.sub('.v_proj.', '.value.', text)
    text = re.sub('.out_proj.', '.out.', text)
    text = re.sub('.fc1.', '.mlp.0.', text)
    text = re.sub('.fc2.', '.mlp.2.', text)
    text = re.sub('.fc3.', '.mlp.3.', text)
    text = re.sub('.fc3.', '.mlp.3.', text)
    text = re.sub('.encoder_attn.', '.cross_attn.', text)
    text = re.sub('.cross_attn.ln.', '.cross_attn_ln.', text)
    text = re.sub('.embed_positions.weight', '.positional_embedding', text)
    text = re.sub('.embed_tokens.', '.token_embedding.', text)
    text = re.sub('model.', '', text)
    text = re.sub('attn.layer_norm.', 'attn_ln.', text)
    text = re.sub('.final_layer_norm.', '.mlp_ln.', text)
    text = re.sub('encoder.layer_norm.', 'encoder.ln_post.', text)
    text = re.sub('decoder.layer_norm.', 'decoder.ln.', text)
    text = re.sub('proj_out.weight', 'decoder.token_embedding.weight', text)
    return text

# Load HF Model
# NOTE: Change the following line to point to "Training Data Directory"/pytorch_model.bin
hf_state_dict = torch.load("Training Data Directory/pytorch_model.bin", map_location=torch.device('cpu'))

# Rename layers
for key in list(hf_state_dict.keys())[:]:
    new_key = hf_to_whisper_states(key)
    hf_state_dict[new_key] = hf_state_dict.pop(key)

model = whisper.load_model(BASE_MODEL)
dims = model.dims

# Save it
# NOTE: This will save file to whisper-model.bin. Change the path as you wish.
torch.save({
    "dims": model.dims.__dict__,
    "model_state_dict": hf_state_dict
}, "whisper-model.bin")

然后,你就可以使用原版的 whisper.load("whisper-model.bin") 来加载模型。

3、Faster-Whisper (CTranslate 2)

最有效率的部署方式是使用 faster-whisper 运行库,但需要再转换一次格式。 首先,安装faster-whisper的转换器:

git clone --depth=1 https://github.com/SYSTRAN/faster-whisper
cd faster-whisper
pip install -e .[convert] # In zsh, quote ".[convert]"

然后使用以下命令进行转换

ct2-transformers-converter \
    --model YOUR_HF_NAME/HF_MODEL_REPO \
    --output_dir whisper-largve-v3-ft-ct2-f16 \
    --copy_files tokenizer.json preprocessor_config.json \
    --quantization float16

CTranslate2 模型会保存到一个文件夹,而并不是单一文件。将 whisper-largve-v3-ft-ct2-f16 改为目标文件夹。 Quantization 不是必要的。训练时使用的就是f16,所以此处的 quantization 其实并没做任何量化。

然后,微调后的模型就可以像任何其他模型一样,被 faster-whisper 加载:

from faster_whisper import WhisperModel

model = WhisperModel("/path/to/model/directory", device="cuda", compute_type="float16")

制作GUI界面交互(demo)

测试安装以下 Python 库:

  • Flask: 用于构建 Web 服务。
  • google-cloud-speech: 用于调用 Google Cloud Speech-to-Text API。
  • werkzeug: 用于处理文件上传
pip install Flask google-cloud-speech werkzeug

设置 Google Cloud API 的认证。可以在 Google Cloud Console 创建项目并启用 Speech-to-Text API,下载服务账号的 JSON 密钥文件,并设置环境变量 GOOGLE_APPLICATION_CREDENTIALS:

export GOOGLE_APPLICATION_CREDENTIALS="path_to_your_service_account_file.json"

function:

  1. 文件上传:前端通过 POST 请求将音频文件发送到 /upload 路由,后端接收文件并保存到服务器本地。
  2. 转录音频:文件上传后,transcribe_audio 函数会调用 Google Cloud Speech-to-Text API 来转录音频文件。
  3. 返回转录结果:转录完成后,后端将转录文本以 JSON 格式返回给前端。 
from flask import Flask, request, jsonify
from google.cloud import speech
import os
from werkzeug.utils import secure_filename

# 初始化 Flask 应用
app = Flask(__name__)

# 配置上传文件的限制
app.config['UPLOAD_FOLDER'] = 'uploads/'
app.config['ALLOWED_EXTENSIONS'] = {'wav', 'mp3', 'flac'}

# 检查文件类型
def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS']

# 创建上传目录(如果不存在)
if not os.path.exists(app.config['UPLOAD_FOLDER']):
    os.makedirs(app.config['UPLOAD_FOLDER'])

# 语音转文本函数
def transcribe_audio(file_path):
    client = speech.SpeechClient()

    # 读取音频文件
    with open(file_path, "rb") as audio_file:
        content = audio_file.read()

    # 音频配置
    audio = speech.RecognitionAudio(content=content)
    config = speech.RecognitionConfig(
        encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
        sample_rate_hertz=16000,
        language_code="en-US",
    )

    # 调用 Google Cloud Speech API 转录
    response = client.recognize(config=config, audio=audio)

    # 提取转录文本
    transcription = ""
    for result in response.results:
        transcription += result.alternatives[0].transcript

    return transcription

# 上传并转录音频文件的路由
@app.route('/upload', methods=['POST'])
def upload_file():
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files['file']
    
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400
    
    if file and allowed_file(file.filename):
        filename = secure_filename(file.filename)
        file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(file_path)
        
        # 转录音频文件
        try:
            transcription = transcribe_audio(file_path)
            return jsonify({"transcription": transcription})
        except Exception as e:
            return jsonify({"error": f"Error transcribing audio: {str(e)}"}), 500
    else:
        return jsonify({"error": "Invalid file type"}), 400

if __name__ == '__main__':
    app.run(debug=True)

运行服务:

启动 Flask 后端服务:

python app.py

打开浏览器,访问 http://127.0.0.1:5000,选择一个音频文件并上传,服务将返回转录的文本。


http://www.kler.cn/a/457852.html

相关文章:

  • 数据结构与算法之动态规划: LeetCode 674. 最长连续递增序列 (Ts版)
  • 如何在IDEA一个窗口中导入多个项目
  • centos7 免安装mysql5.7及配置(支持多个mysql)
  • 2024年中国新能源汽车用车发展怎么样 PaperGPT(二)
  • 7-58 输出不重复的数组元素
  • SQL-Server链接服务器访问Oracle数据
  • 在 CentOS 7 上安装 Node.js 20 并升级 GCC、make 和 glibc
  • 图像处理-Ch7-快速小波变换和小波包
  • redis cluster实验详解
  • 蓝桥杯速成教程{三}(adc,i2c,uart)
  • vulhub-wordpress靶场
  • 区块链安全常见的攻击合约和简单复现,附带详细分析——不安全调用漏洞 (Unsafe Call Vulnerability)【6】
  • 【513. 找树左下角的值 中等】
  • 【Leetcode刷题随笔】977 有序数组的平方
  • google广告 google分析
  • wordpress woodmark max_input_vars = 1000 限制问题
  • 使用proxysql代理mysql连接
  • 【Raven1靶场渗透】
  • 钱币找零.
  • 秒鲨后端之MyBatis【1】环境的搭建和核心配置文件详解(重置)
  • 智能工厂的设计软件 应用场景的一个例子:为AI聊天工具添加一个知识系统 之5
  • vue.js普通组件的注册-全局注册
  • 7-Gin 中自定义控制器 --[Gin 框架入门精讲与实战案例]
  • CPU性能优化--后端优化
  • upload-labs关卡记录5
  • 【论文笔记】Contrastive Learning for Sign Language Recognition and Translation