当前位置: 首页 > article >正文

清华大学开源视频转文本模型——CogVLM2-Llama3-Caption

在这里插入图片描述
通常情况下,大多数视频数据并不附带相应的描述性文本,因此有必要将视频数据转换为文本描述,为文本到视频模型提供必要的训练数据。 CogVLM2-Caption 是一个视频字幕模型,用于为 CogVideoX 模型生成训练数据。

在这里插入图片描述
文件

在这里插入图片描述

使用

import io

import argparse
import numpy as np
import torch
from decord import cpu, VideoReader, bridge
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "THUDM/cogvlm2-llama3-caption"

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[
    0] >= 8 else torch.float16

parser = argparse.ArgumentParser(description="CogVLM2-Video CLI Demo")
parser.add_argument('--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=0)
args = parser.parse_args([])


def load_video(video_data, strategy='chat'):
    bridge.set_bridge('torch')
    mp4_stream = video_data
    num_frames = 24
    decord_vr = VideoReader(io.BytesIO(mp4_stream), ctx=cpu(0))

    frame_id_list = None
    total_frames = len(decord_vr)
    if strategy == 'base':
        clip_end_sec = 60
        clip_start_sec = 0
        start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
        end_frame = min(total_frames,
                        int(clip_end_sec * decord_vr.get_avg_fps())) if clip_end_sec is not None else total_frames
        frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
    elif strategy == 'chat':
        timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
        timestamps = [i[0] for i in timestamps]
        max_second = round(max(timestamps)) + 1
        frame_id_list = []
        for second in range(max_second):
            closest_num = min(timestamps, key=lambda x: abs(x - second))
            index = timestamps.index(closest_num)
            frame_id_list.append(index)
            if len(frame_id_list) >= num_frames:
                break

    video_data = decord_vr.get_batch(frame_id_list)
    video_data = video_data.permute(3, 0, 1, 2)
    return video_data


tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True
).eval().to(DEVICE)


def predict(prompt, video_data, temperature):
    strategy = 'chat'

    video = load_video(video_data, strategy=strategy)

    history = []
    query = prompt
    inputs = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=query,
        images=[video],
        history=history,
        template_version=strategy
    )
    inputs = {
        'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
        'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
        'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
        'images': [[inputs['images'][0].to('cuda').to(TORCH_TYPE)]],
    }
    gen_kwargs = {
        "max_new_tokens": 2048,
        "pad_token_id": 128002,
        "top_k": 1,
        "do_sample": False,
        "top_p": 0.1,
        "temperature": temperature,
    }
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return response


def test():
    prompt = "Please describe this video in detail."
    temperature = 0.1
    video_data = open('test.mp4', 'rb').read()
    response = predict(prompt, video_data, temperature)
    print(response)


if __name__ == '__main__':
    test()

感谢大家花时间阅读我的文章,你们的支持是我不断前进的动力。期望未来能为大家带来更多有价值的内容,请多多关注我的动态!


http://www.kler.cn/news/326242.html

相关文章:

  • 因果推断学习
  • Flink集群部署
  • 面试知识点总结篇四
  • 【渗透实战系列】|App渗透 ,由sql注入、绕过人脸识别、成功登录APP
  • 介绍我经常使用的两款轻便易用的 JSON 工具
  • 2024年第一批因AI失业的人,已经出现了
  • 【hot100-java】【寻找重复数】
  • springboot电影售票系统小程序—计算机毕业设计源码36991
  • C++那些事之内存优化
  • Matlab_与CANoe联合仿真方案基础环境搭建
  • 基于微信小程序的美食外卖管理系统
  • 南沙C++信奥赛陈老师解一本通题 1269:【例9.13】庆功会
  • phpstudy简易使用
  • 基于OpenCV的实时年龄与性别识别(支持CPU和GPU)
  • AI多模态基础知识点:LLM小白也能看懂的分词(tokenization)解读
  • Zookeeper下载、安装配置
  • 从画质设置看游戏引擎(其一)
  • 正点原子阿波罗STM32F429IGT6移植zephyr rtos(二)---使用I2C驱动MPU6050
  • 使用apipost工具导入通过swag生成的golang接口文档步骤
  • 思科安全网络解决方案
  • FIOT/浙江信达可恩消防股份有限公司25周年庆典隆重召开
  • 【JavaEE初阶】网络原理
  • 大盘点|9月独家爆款SVG模版(互斥伸长、扑克出牌、预感应滑动等)
  • 【C#生态园】构建高效PDF应用:全面解析C#六大PDF生成库
  • Linux date命令(用于显示和设置系统的日期和时间,不仅可以显示时间,还能进行复杂的时间计算和格式化)
  • 苍穹外卖学习笔记(十四)
  • 【JavaEE】——CAS指令和ABA问题
  • 【Android】获取备案所需的公钥以及签名MD5值
  • Mybatis中遍历List内容进行动态SQL拼接
  • LeetCode 461. 汉明距离