当前位置: 首页 > article >正文

【多模态RAG】多模态RAG ColPali实践

关于【RAG&多模态】多模态RAG-ColPali:使用视觉语言模型实现高效的文档检索前面已经介绍了(供参考),这次来看看ColPali实践。

所需权重:

  1. 多模态问答模型:Qwen2-VL-72B-Instruct,https://modelscope.cn/models/Qwen/Qwen2-VL-72B-Instruct

  2. 基于 PaliGemma-3B 和 ColBERT 策略的视觉检索器:

    • ColPali(LoRA):https://huggingface.co/vidore/colpali

    • ColPali(基座):https://huggingface.co/vidore/colpaligemma-3b-mix-448-base

多模态检索问答实践

  • lora的adapter_config.json字段base_model_name_or_path修改地址:ColPali(基座)存储路径

  • qwen_vl_utils下载地址:https://github.com/QwenLM/Qwen2-VL/tree/main/qwen-vl-utils/src/qwen_vl_utils

  • byaldi安装方式:https://github.com/AnswerDotAI/byaldi

  • 完整代码


from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from pdf2image import convert_from_path

class DocumentQA:
    def __init__(self, rag_model_name: str, vlm_model_name: str, device: str = 'cuda', system_prompt: str = None):
        self.rag_engine = RAGMultiModalModel.from_pretrained(rag_model_name)
        self.vlm = Qwen2VLForConditionalGeneration.from_pretrained(
            vlm_model_name,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map=device
        )
        self.processor = AutoProcessor.from_pretrained(vlm_model_name, trust_remote_code=True)
        self.device = device
        if system_prompt is None:
            self.system_prompt = (
                "你是一位专精于计算机科学和机器学习的AI研究助理。"
                "你的任务是分析学术论文,尤其是关于文档检索和多模态模型的研究。"
                "请仔细分析提供的图像和文本,提供深入的见解和解释。"
            )
        else:
            self.system_prompt = system_prompt

    def index_document(self, pdf_path: str, index_name: str = 'index', overwrite: bool = True):
        self.pdf_path = pdf_path
        self.rag_engine.index(
            input_path=pdf_path,
            index_name=index_name,
            store_collection_with_index=False,
            overwrite=overwrite
        )
        self.images = convert_from_path(pdf_path)

    def query(self, text_query: str, k: int = 3) -> str:
        results = self.rag_engine.search(text_query, k=k)
        print("搜索结果:", results)

        if not results:
            print("未找到相关查询结果。")
            return None

        try:
            page_num = results[0]["page_num"]
            image_index = page_num - 1
            image = self.images[image_index]
        except (KeyError, IndexError) as e:
            print("获取页面图像时出错:", e)
            return None

        messages = [
            {
                "role": "system",
                "content": self.system_prompt
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": text_query},
                ],
            }
        ]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        image_inputs, video_inputs = process_vision_info(messages)

        # 准备模型输入
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)

        generated_ids = self.vlm.generate(**inputs, max_new_tokens=1024)

        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        return output_text[0]

if __name__ == "__main__":
    # 初始化 DocumentQA 实例
    document_qa = DocumentQA(
        rag_model_name="./colpali",
        vlm_model_name="./Qwen2-VL-7B-Instruct",
        device='cuda'
    )

    # 索引 PDF 文档
    document_qa.index_document("test.pdf")

    # 定义查询
    text_query = (
        "文中模型在哪个数据集上相比其他模型有最大的优势?"
        "该优势的改进幅度是多少?"
    )

    # 执行查询并打印答案
    answer = document_qa.query(text_query)
    print("答案:", answer)

http://www.kler.cn/a/376046.html

相关文章:

  • 《GBDT 算法的原理推导》 11-12计算损失函数的负梯度 公式解析
  • 中仕公考:上海市25年公务员考试今日报名
  • Java的jackson库
  • HTB:Analytics[WriteUP]
  • docker-compose 什么情况下需要先down再up(ChatGPT回答)
  • Ubantu/Linux 采用Repo或Git命令报错!!
  • Unity WebGL项目中,如果想在网页端配置数字人穿红色上衣,并让桌面端保持同步
  • 3.使用ref定义页面元素,
  • ZooKeeper 客户端API操作
  • 工厂电气及PLC【1章各种元件符号】
  • T-Mobile股票分析:T-Mobile的股价还能继续上涨吗?
  • 动态ip如何自动更换ip
  • Apache Paimon主键表的一些最佳实践
  • 3d点在立方体内(numpy,不使用for循环)
  • 免费送源码:Java+Springboot+MySQL Springboot酒店客房管理系统的设计与实现 计算机毕业设计原创定制
  • [Python技术]利用akshare获取股票基本信息、K线图、最新新闻 以及大模型投资建议
  • 电脑换网络环境,IP地址会变吗?答案来了
  • 1008:计算(a+b)/c的值
  • 使用 ADB 在某个特定时间点点击 Android 设备上的某个按钮
  • 我的工具列表
  • DCN网络进行新冠肺炎影像分类
  • 浅谈C++深、浅拷贝
  • RPC和API关系
  • 2024三掌柜赠书活动第三十四期:破解深度学习
  • OpenMV的无人驾驶智能小车模拟系统
  • 使用 Q3D 计算并联和串联 RLCG 值