FastGPT 引申:常见 Rerank 实现方案
文章目录
- FastGPT 引申:常见 Rerank 实现方案
- 1. 使用 BGE Reranker
- 2. 使用 Cohere Rerank API
- 3. 使用 Cross-Encoder 实现
- 4. 自定义 Reranker 实现
- 5. FastAPI 服务实现
- 6. 实现方案总结
FastGPT 引申:常见 Rerank 实现方案
下边介绍几种 Rerank 的具体实现方案。
1. 使用 BGE Reranker
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
class BGEReranker:
    """Rerank documents against a query with the BAAI/bge-reranker-base cross-encoder."""

    def __init__(self):
        # Load the model and tokenizer once; eval() disables dropout for inference.
        self.model_name = "BAAI/bge-reranker-base"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
        self.model.eval()

    def rerank(self, query: str, documents: list[str], batch_size: int = 32) -> list[dict]:
        """Score every document against *query* and return them sorted by relevance.

        Args:
            query: The search query.
            documents: Candidate document texts.
            batch_size: Number of (query, document) pairs scored per forward pass.

        Returns:
            A list of ``{"text": ..., "score": ...}`` dicts, highest score first.
        """
        if not documents:  # nothing to rank
            return []
        results = []
        # Actual batching: the original looped one document at a time despite
        # its "batch processing" comment, wasting a forward pass per document.
        for start in range(0, len(documents), batch_size):
            batch = documents[start:start + batch_size]
            inputs = self.tokenizer(
                text=[query] * len(batch),
                text_pair=batch,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            with torch.no_grad():
                scores = self.model(**inputs).logits.flatten()
            for doc, score in zip(batch, scores):
                results.append({"text": doc, "score": float(score)})  # tensor -> Python float
        # Sort by relevance score, descending.
        results.sort(key=lambda x: x["score"], reverse=True)
        return results
# Usage example: rank three candidate documents against a Chinese query.
reranker = BGEReranker()
query = "如何使用Python进行数据分析?"
docs = [
    "Python数据分析基础教程",
    "数据分析工具pandas使用指南",
    "Python编程基础入门"
]
# Results come back sorted by relevance score, highest first.
reranked_results = reranker.rerank(query, docs)
2. 使用 Cohere Rerank API
import cohere
from typing import List, Dict
class CohereReranker:
    """Rerank documents through the hosted Cohere Rerank API."""

    def __init__(self, api_key: str):
        self.co = cohere.Client(api_key)

    def rerank(
        self,
        query: str,
        documents: List[Dict[str, str]],
        top_n: int = 3
    ) -> List[Dict]:
        """Return the ``top_n`` documents most relevant to ``query``.

        Each input document is a ``{"id": ..., "text": ...}`` mapping; the
        output keeps the original id alongside the API's relevance score.
        Any failure is reported to stdout and yields an empty list
        (best-effort behaviour, preserved from the original).
        """
        try:
            # Call the Cohere API with the raw document texts.
            response = self.co.rerank(
                query=query,
                documents=[item["text"] for item in documents],
                top_n=top_n,
                model="rerank-multilingual-v2.0"
            )
            # Map each hit back to its source document by index.
            return [
                {
                    "id": documents[hit.index]["id"],
                    "text": hit.document["text"],
                    "relevance_score": hit.relevance_score,
                }
                for hit in response
            ]
        except Exception as e:
            print(f"Reranking error: {str(e)}")
            return []
# Usage example: Cohere expects documents as {"id", "text"} dicts here,
# and an API key obtained from the Cohere dashboard.
reranker = CohereReranker(api_key="your-api-key")
query = "数据分析方法"
docs = [
    {"id": "1", "text": "使用pandas进行数据处理"},
    {"id": "2", "text": "数据可视化技巧"},
    {"id": "3", "text": "机器学习算法"}
]
results = reranker.rerank(query, docs)
3. 使用 Cross-Encoder 实现
from sentence_transformers import CrossEncoder
import numpy as np
class CrossEncoderReranker:
    """Rerank documents with a sentence-transformers cross-encoder."""

    def __init__(self):
        # ms-marco-MiniLM-L-6-v2: a compact model trained for passage reranking.
        self.model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

    def rerank(
        self,
        query: str,
        documents: List[Dict],
        batch_size: int = 32
    ) -> List[Dict]:
        """Score ``documents`` (``{"id", "text"}`` dicts) against ``query``.

        Args:
            query: The search query.
            documents: Candidate documents with ``id`` and ``text`` keys.
            batch_size: Forward-pass batch size, delegated to ``predict``.

        Returns:
            ``{"id", "text", "score"}`` dicts sorted by score, highest first.
        """
        if not documents:  # guard: nothing to score
            return []
        # Build (query, document) pairs for the cross-encoder.
        pairs = [[query, doc["text"]] for doc in documents]
        # CrossEncoder.predict batches internally via batch_size — the
        # original's manual slicing loop duplicated that for no benefit.
        scores = self.model.predict(pairs, batch_size=batch_size)
        results = [
            {"id": doc["id"], "text": doc["text"], "score": float(score)}
            for doc, score in zip(documents, scores)
        ]
        # Sort by relevance score, descending.
        results.sort(key=lambda x: x["score"], reverse=True)
        return results
# Usage example. Note: the original snippet referenced an undefined
# ``documents`` variable (the earlier ``docs`` held plain strings, not the
# {"id", "text"} dicts this class expects) — define the inputs explicitly
# so the example runs on its own.
reranker = CrossEncoderReranker()
query = "数据分析方法"
documents = [
    {"id": "1", "text": "使用pandas进行数据处理"},
    {"id": "2", "text": "数据可视化技巧"},
    {"id": "3", "text": "机器学习算法"},
]
results = reranker.rerank(query, documents)
4. 自定义 Reranker 实现
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel
class CustomReranker(nn.Module):
    """Trainable reranker: a BERT encoder plus a linear scoring head over [CLS]."""

    def __init__(self, model_name: str = "bert-base-chinese"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Linear head mapping the [CLS] embedding to a single relevance logit.
        self.score = nn.Linear(self.encoder.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one relevance logit per (query, document) pair in the batch."""
        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Use the [CLS] token's hidden state as the pair representation.
        cls_embedding = encoded.last_hidden_state[:, 0]
        return self.score(cls_embedding)

    def rerank(self, query: str, documents: List[str]) -> List[Dict]:
        """Score each document against *query*; return dicts sorted by score, descending."""
        self.eval()
        scored = []
        with torch.no_grad():
            for candidate in documents:
                # Encode the (query, candidate) pair for the cross-encoder.
                encoding = self.tokenizer(
                    text=[query],
                    text_pair=[candidate],
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors="pt"
                )
                logit = self.forward(
                    encoding["input_ids"],
                    encoding["attention_mask"]
                )
                scored.append({
                    "text": candidate,
                    "score": float(logit[0])
                })
        return sorted(scored, key=lambda item: item["score"], reverse=True)
# Example training loop for a reranker model.
def train_reranker(model, train_dataloader, epochs=3):
    """Fine-tune a reranker with binary relevance labels.

    Args:
        model: Module whose ``forward(input_ids, attention_mask)`` returns a
            ``(batch, 1)`` logit tensor (e.g. ``CustomReranker``).
        train_dataloader: Iterable of batch dicts with ``input_ids``,
            ``attention_mask`` and ``labels`` (0/1 relevance per pair).
        epochs: Number of passes over the training data.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    loss_fn = nn.BCEWithLogitsLoss()
    for epoch in range(epochs):
        model.train()
        for batch in train_dataloader:
            optimizer.zero_grad()
            scores = model(batch["input_ids"], batch["attention_mask"])
            # Fix: squeeze the (batch, 1) logits to (batch,) and cast labels
            # to float — BCEWithLogitsLoss requires matching input/target
            # shapes and float targets; the original raised on 1-D labels.
            loss = loss_fn(scores.squeeze(-1), batch["labels"].float())
            loss.backward()
            optimizer.step()
5. FastAPI 服务实现
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
app = FastAPI()

class Document(BaseModel):
    # One candidate document to be reranked.
    id: str
    text: str

class RerankerRequest(BaseModel):
    # Request body for POST /rerank: a query plus its candidate documents.
    query: str
    documents: List[Document]

class RerankerResponse(BaseModel):
    # One reranked document together with its relevance score.
    id: str
    text: str
    score: float
# Shared, lazily-created reranker: constructing CrossEncoderReranker loads the
# model from disk, so doing it per request (as the original did) is far too slow.
_reranker: Optional[CrossEncoderReranker] = None

@app.post("/rerank", response_model=List[RerankerResponse])
async def rerank(request: RerankerRequest):
    """Rerank the request's documents against its query.

    Returns the documents sorted by relevance score; any internal failure is
    surfaced as an HTTP 500 with the error message in ``detail``.
    """
    global _reranker
    try:
        if _reranker is None:
            _reranker = CrossEncoderReranker()  # load the model only once
        results = _reranker.rerank(
            query=request.query,
            documents=[{
                "id": doc.id,
                "text": doc.text
            } for doc in request.documents]
        )
        return results
    except Exception as e:
        # Chain the original cause so tracebacks show what failed.
        raise HTTPException(status_code=500, detail=str(e)) from e
6. 实现方案总结
- BGE Reranker
- 开源模型
- 支持中英文
- 性能较好
- Cohere Rerank
- 商业API
- 多语言支持
- 无需维护模型
- Cross-Encoder
- 专门针对重排序优化
- 计算效率较高
- 易于使用
- 自定义实现
- 完全可控
- 可以针对特定场景优化
- 需要训练数据