当前位置: 首页 > article >正文

FastGPT 引申:常见 Rerank 实现方案

文章目录

    • FastGPT引申:常见 Rerank 实现方案
      • 1. 使用 BGE Reranker
      • 2. 使用 Cohere Rerank API
      • 3. 使用 Cross-Encoder 实现
      • 4. 自定义 Reranker 实现
      • 5. FastAPI 服务实现
      • 6. 实现方案总结

FastGPT引申:常见 Rerank 实现方案

下边介绍几种 Rerank 的具体实现方案。

1. 使用 BGE Reranker

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

class BGEReranker:
    def __init__(self):
        # 加载模型和分词器
        self.model_name = "BAAI/bge-reranker-base"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
        self.model.eval()

    def rerank(self, query: str, documents: list[str]) -> list[dict]:
        results = []
        
        # 批处理文档
        for doc in documents:
            # 构造输入格式
            inputs = self.tokenizer(
                text=[query],
                text_pair=[doc],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            )
            
            # 模型推理
            with torch.no_grad():
                scores = self.model(**inputs).logits.flatten()
                
            results.append({
                "text": doc,
                "score": float(scores[0])  # 转换为Python float
            })
            
        # 按分数排序
        results.sort(key=lambda x: x["score"], reverse=True)
        return results

# 使用示例
reranker = BGEReranker()
query = "如何使用Python进行数据分析?"
docs = [
    "Python数据分析基础教程",
    "数据分析工具pandas使用指南",
    "Python编程基础入门"
]

reranked_results = reranker.rerank(query, docs)

2. 使用 Cohere Rerank API

import cohere
from typing import List, Dict

class CohereReranker:
    def __init__(self, api_key: str):
        self.co = cohere.Client(api_key)
    
    def rerank(
        self, 
        query: str, 
        documents: List[Dict[str, str]], 
        top_n: int = 3
    ) -> List[Dict]:
        try:
            # 调用Cohere API
            results = self.co.rerank(
                query=query,
                documents=[doc["text"] for doc in documents],
                top_n=top_n,
                model="rerank-multilingual-v2.0"
            )
            
            # 格式化结果
            reranked_results = []
            for result in results:
                reranked_results.append({
                    "id": documents[result.index]["id"],
                    "text": result.document["text"],
                    "relevance_score": result.relevance_score
                })
                
            return reranked_results
            
        except Exception as e:
            print(f"Reranking error: {str(e)}")
            return []

# 使用示例
reranker = CohereReranker(api_key="your-api-key")
query = "数据分析方法"
docs = [
    {"id": "1", "text": "使用pandas进行数据处理"},
    {"id": "2", "text": "数据可视化技巧"},
    {"id": "3", "text": "机器学习算法"}
]

results = reranker.rerank(query, docs)

3. 使用 Cross-Encoder 实现

from sentence_transformers import CrossEncoder
import numpy as np

class CrossEncoderReranker:
    def __init__(self):
        # 加载cross-encoder模型
        self.model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        
    def rerank(
        self, 
        query: str, 
        documents: List[Dict], 
        batch_size: int = 32
    ) -> List[Dict]:
        # 准备文档对
        pairs = [[query, doc["text"]] for doc in documents]
        
        # 批量计算相关性分数
        scores = []
        for i in range(0, len(pairs), batch_size):
            batch = pairs[i:i + batch_size]
            batch_scores = self.model.predict(batch)
            scores.extend(batch_scores)
            
        # 组合结果
        results = []
        for idx, score in enumerate(scores):
            results.append({
                "id": documents[idx]["id"],
                "text": documents[idx]["text"],
                "score": float(score)
            })
            
        # 按分数排序
        results.sort(key=lambda x: x["score"], reverse=True)
        return results

# 使用示例
reranker = CrossEncoderReranker()
results = reranker.rerank(query, documents)

4. 自定义 Reranker 实现

import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel

class CustomReranker(nn.Module):
    def __init__(self, model_name: str = "bert-base-chinese"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        # 相关性评分层
        self.score = nn.Linear(self.encoder.config.hidden_size, 1)
        
    def forward(self, input_ids, attention_mask):
        # 获取BERT输出
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        # 使用[CLS]标记的输出计算相关性分数
        pooled_output = outputs.last_hidden_state[:, 0]
        score = self.score(pooled_output)
        return score
    
    def rerank(self, query: str, documents: List[str]) -> List[Dict]:
        self.eval()
        results = []
        
        with torch.no_grad():
            for doc in documents:
                # 构造输入
                inputs = self.tokenizer(
                    text=[query],
                    text_pair=[doc],
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors="pt"
                )
                
                # 计算分数
                score = self.forward(
                    inputs["input_ids"],
                    inputs["attention_mask"]
                )
                
                results.append({
                    "text": doc,
                    "score": float(score[0])
                })
        
        # 排序
        results.sort(key=lambda x: x["score"], reverse=True)
        return results

# 训练函数示例
def train_reranker(model, train_dataloader, epochs=3):
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    loss_fn = nn.BCEWithLogitsLoss()
    
    for epoch in range(epochs):
        model.train()
        for batch in train_dataloader:
            optimizer.zero_grad()
            
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            labels = batch["labels"]
            
            scores = model(input_ids, attention_mask)
            loss = loss_fn(scores, labels)
            
            loss.backward()
            optimizer.step()

5. FastAPI 服务实现

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional

app = FastAPI()

class Document(BaseModel):
    id: str
    text: str

class RerankerRequest(BaseModel):
    query: str
    documents: List[Document]

class RerankerResponse(BaseModel):
    id: str
    text: str
    score: float

@app.post("/rerank", response_model=List[RerankerResponse])
async def rerank(request: RerankerRequest):
    try:
        reranker = CrossEncoderReranker()  # 或其他实现
        results = reranker.rerank(
            query=request.query,
            documents=[{
                "id": doc.id,
                "text": doc.text
            } for doc in request.documents]
        )
        return results
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

6. 实现方案总结

  • BGE Reranker
    • 开源模型
    • 支持中英文
    • 性能较好
  • Cohere Rerank
    • 商业API
    • 多语言支持
    • 无需维护模型
  • Cross-Encoder
    • 专门针对重排序优化
    • 计算效率较高
    • 易于使用
  • 自定义实现
    • 完全可控
    • 可以针对特定场景优化
    • 需要训练数据

http://www.kler.cn/a/571733.html

相关文章:

  • 知识篇 | 低代码开发(Low-Code Development)是个什么东东?
  • 第40天:安全开发-JavaEE应用SpringBoot框架JWT身份鉴权打包部署JARWAR
  • Stiring-PDF:开源免费的PDF文件处理软件
  • Vue路由器的工作模式
  • PPT 小黑第34套
  • Metal学习笔记目录
  • DFT之SSN架构
  • 备赛蓝桥杯之第十五届职业院校组省赛第五题:悠然画境
  • 医疗AR眼镜:FPC如何赋能科技医疗的未来之眼?【新立电子】
  • 神经网络:AI的网络神经
  • P8692 [蓝桥杯 2019 国 C] 数正方形--输出取模余数
  • DeepSeek DeepEP学习(一)low latency dispatch
  • Scaling Laws(缩放法则)详解
  • lamp平台介绍
  • 记录uniapp小程序对接腾讯IM即时通讯无ui集成(2)
  • 【损失函数(目标函数)在深度学习中的作用】
  • Opencv 直方图与模板匹配
  • 10、HTTP/3有了解过吗?【中高频】
  • 数据结构(纯C语言版)习题(1)
  • 表达式求值(后缀表达式)