多模态图文检索实战——基于CLIP实现图文检索系统(附源码)
在当今数字化时代,图文检索技术 变得愈发重要,它能够帮助我们在海量的图像和文本数据中快速找到匹配的内容。(如果手机相册能够出一个这个功能也挺不错,有时候想找一张图片,记不得是什么时候拍的,只记得大概的内容,如果能利用这种图文检索或许是个不错的解决方案)。今天,就和大家分享一下如何利用强大的CLIP模型来实现图文检索功能,并且结合实际代码来深入剖析整个过程。
文章目录
- 一、准备工作与环境搭建
- 二、核心功能函数解析
- 2.1 文本嵌入生成函数(text_embedding)
- 2.2 图片嵌入生成函数(get_image_embedding)
- 2.3 余弦相似度计算函数(cosine_similarity)
- 2.4 相似度计算主函数(calulate_similarity)
- 2.5 批量图片嵌入生成函数(getImage_embedding)
- 三、实例调用与结果展示
- 四、完整代码
- 五、总结与展望
一、准备工作与环境搭建
首先,我们需要导入一些必要的库。在Python环境中,我们引入了 time
用于记录时间,从 transformers
库中导入了 CLIPProcessor
和 CLIPModel
,这两个可是实现CLIP功能的核心组件呀,同时还引入了 torch
用于张量相关操作、 PIL
库中的 Image
模块来处理图片,以及 numpy
用于数值计算,另外,为了避免一些警告信息干扰,我们还对警告进行了过滤处理,代码如下:
import time
from transformers import CLIPProcessor,CLIPModel
import torch
from PIL import Image
import numpy as np
import warnings
warnings.filterwarnings("ignore")
接下来,就是加载CLIP模型和对应的处理器啦。这里我们指定了模型的预训练路径,并且根据是否有可用的 cuda
设备来选择合适的计算设备(优先使用GPU加速,如果没有GPU则使用CPU),如下所示:
# 加载模型和处理器
model = CLIPModel.from_pretrained("/root/model/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("/root/model/clip-vit-large-patch14")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
二、核心功能函数解析
2.1 文本嵌入生成函数(text_embedding)
这个函数的作用是将输入的文本转化为对应的嵌入表示(embedding)。它通过处理器对输入文本进行处理,使其符合模型的输入要求,然后利用模型获取文本特征,最后将结果转换为 numpy
数组格式返回,方便后续的计算和比较,代码如下:
def text_embedding(text):
inputs = processor(text=[text], return_tensors="pt", padding=True).to(device)
with torch.no_grad():
embedding = model.get_text_features(**inputs)
return embedding.cpu().numpy()
2.2 图片嵌入生成函数(get_image_embedding)
其功能是针对给定的图片路径,读取图片并将其转换为合适的格式后,通过模型获取图片的特征嵌入。如果在读取图片过程中出现错误,会进行相应的错误提示并返回 None
,代码如下:
def get_image_embedding(image_path):
try:
# 从本地路径读入图片
image = Image.open(image_path).convert("RGB")
inputs = processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
image_features = model.get_image_features(**inputs)
return image_features.cpu().numpy()
except Exception as e:
print(f"Error loading image {e}")
return None
2.3 余弦相似度计算函数(cosine_similarity)
在图文检索中,我们常常需要衡量文本嵌入和图片嵌入之间的相似度,这里采用了余弦相似度的计算方法。它将输入的向量转换为 numpy
数组后,按照余弦相似度的数学公式来计算两者的相似度数值,代码如下:
def cosine_similarity(vec1, vec2):
# 将列表转换为numpy数组
vec1 = np.array(vec1)
vec2 = np.array(vec2)
return np.dot(vec1, vec2.T) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
2.4 相似度计算主函数(calulate_similarity)
这个函数可以说是整个图文检索的核心协调者啦。它根据输入的查询内容(可以是文本或者图片)类型,调用相应的嵌入生成函数获取查询的嵌入表示,然后遍历候选的嵌入(这里是图片嵌入),通过余弦相似度函数计算每个候选与查询之间的相似度得分,并最终找出相似度最高的那个候选(也就是最佳匹配的图片),代码如下:
def calulate_similarity(query, candidates,query_type="text"):
if query_type == "text":
query_embedding = text_embedding(query).tolist() # 转换为可序列化格式
elif query_type == "image":
query_embedding = get_image_embedding(query)
if query_embedding is None:
raise ValueError("Error loading image")
query_embedding = query_embedding.tolist()
else:
raise ValueError("query_type must be 'text' or 'image'")
similarities = []
for candidate,candidate_embedding in zip(candidates,Image_embedding):
similarity_score = cosine_similarity(query_embedding, candidate_embedding)
similarities.append((candidate,float(similarity_score))) # 确保similarity_score是float类型
# 获取相似度最大的图片
best_similarity_score = max(similarities,key=lambda x:x[1])[1]
best_image_path = max(similarities,key=lambda x:x[1])[0]
return {"similarities":similarities,"best_match":best_image_path,"best_similarity_score":best_similarity_score}
2.5 批量图片嵌入生成函数(getImage_embedding)
当我们有多个候选图片时,就需要批量生成它们的嵌入表示了。这个函数会遍历所有的候选图片路径,依次调用 get_image_embedding
函数来获取嵌入,同时还会记录生成这些嵌入所花费的时间,代码如下:
def getImage_embedding(candidates):
# 生成候选嵌入
result = []
before_time = time.time()
for candidate in candidates:
candidate_embedding = get_image_embedding(candidate)
result.append(candidate_embedding)
after_time = time.time()
print(f"Time taken to generate image embeddings: {after_time - before_time:.2f} seconds")
return result
三、实例调用与结果展示
下面就是实际运用这些函数来进行图文检索的示例啦。我们首先定义了一组候选图片的路径,这里简单地用了 Data/image{i}.png
( i
从1到5)这样的形式来表示。然后通过 getImage_embedding
函数批量生成这些候选图片的嵌入表示。接着,我们设定了一个文本查询,比如“一张大象的照片”,指定查询类型为文本,再调用 calulate_similarity
函数来计算这个文本查询与候选图片之间的相似度,最终找到最佳匹配的图片,并输出相应的结果,代码如下:
# 实例调用
candidates = [f"Data/image{i}.png" for i in range(1,6)]
Image_embedding = getImage_embedding(candidates)
query = "一张老虎的照片"
query_type = "text"
similarity_result = calulate_similarity(query,candidates,query_type)
# print("similarities result:",similarity_result)
print("query:",query)
print("best_match:",similarity_result["best_match"])
输出
Time taken to generate image embeddings: 1.50 seconds
query: 一张老虎的照片
best_match: Data/image4.png
通过这样的流程,我们就利用CLIP模型成功地实现了图文检索功能,能够快速地从给定的候选图片中找到与文本描述最匹配的那一张图片啦,可以看到正确的检索到了老虎的图片:
四、完整代码
import time
from transformers import CLIPProcessor,CLIPModel
import torch
from PIL import Image
import numpy as np
import warnings
warnings.filterwarnings("ignore")
# 加载模型和处理器
model = CLIPModel.from_pretrained("/root/model/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("/root/model/clip-vit-large-patch14")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# 函数:生成文本嵌入
def text_embedding(text):
inputs = processor(text=[text], return_tensors="pt", padding=True).to(device)
with torch.no_grad():
embedding = model.get_text_features(**inputs)
return embedding.cpu().numpy()
def get_image_embedding(image_path):
try:
# 从本地路径读入图片
image = Image.open(image_path).convert("RGB")
inputs = processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
image_features = model.get_image_features(**inputs)
return image_features.cpu().numpy()
except Exception as e:
print(f"Error loading image {e}")
return None
def cosine_similarity(vec1, vec2):
# 将列表转换为numpy数组
vec1 = np.array(vec1)
vec2 = np.array(vec2)
return np.dot(vec1, vec2.T) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
def calulate_similarity(query, candidates,query_type="text"):
if query_type == "text":
query_embedding = text_embedding(query).tolist() # 转换为可序列化格式
elif query_type == "image":
query_embedding = get_image_embedding(query)
if query_embedding is None:
raise ValueError("Error loading image")
query_embedding = query_embedding.tolist()
else:
raise ValueError("query_type must be 'text' or 'image'")
similarities = []
for candidate,candidate_embedding in zip(candidates,Image_embedding):
similarity_score = cosine_similarity(query_embedding, candidate_embedding)
similarities.append((candidate,float(similarity_score))) # 确保similarity_score是float类型
# 获取相似度最大的图片
best_similarity_score = max(similarities,key=lambda x:x[1])[1]
best_image_path = max(similarities,key=lambda x:x[1])[0]
return {"similarities":similarities,"best_match":best_image_path,"best_similarity_score":best_similarity_score}
def getImage_embedding(candidates):
# 生成候选嵌入
result = []
before_time = time.time()
for candidate in candidates:
candidate_embedding = get_image_embedding(candidate)
result.append(candidate_embedding)
after_time = time.time()
print(f"Time taken to generate image embeddings: {after_time - before_time:.2f} seconds")
return result
# 实例调用
candidates = [f"Data/image{i}.png" for i in range(1,6)]
Image_embedding = getImage_embedding(candidates)
query = "一张老虎的照片"
query_type = "text"
similarity_result = calulate_similarity(query,candidates,query_type)
# print("similarities result:",similarity_result)
print("query:",query)
print("best_match:",similarity_result["best_match"])
五、总结与展望
利用CLIP模型实现图文检索为我们在多媒体数据处理等诸多领域提供了很大的便利,比如在图像搜索引擎、内容推荐系统等方面都有着广阔的应用前景。这也是RAG最核心的部分
,可以说这是多模态RAG的一个简单尝试,在实际应用中,我们还可以进一步优化模型参数、增加更多的图片和文本数据进行训练、改进相似度计算的策略等,来不断提升图文检索的准确性和效率,希望这篇博客能够帮助大家对利用CLIP实现图文检索有一个初步的了解和实践思路哦。