【以图搜图代码实现2】--faiss工具实现犬类以图搜图
第一篇:【以图搜图代码实现】–犬类以图搜图示例 使用保存成h5文件,使用向量积来度量相似性,实现了以图搜图,说明了可以优化的点。
第二篇:【使用resnet18训练自己的数据集】 准对模型问题进行了优化,取得了显著性的效果。
本篇继续第一篇中所说的优化方向,使用faiss实现以图搜图。
1.faiss使用介绍
Faiss的全称是Facebook AI Similarity Search,是FaceBook针对大规模相似度检索问题开发的一个工具,底层是使用C++代码实现的,提供了python的接口,号称对10亿量级的索引可以做到毫秒级检索。
使用faiss的基本步骤
1、数据转换:把原始数据转换为"float32"数据类型的向量。
2、index构建:用 faiss 构建index
3、数据添加:将向量add到创建的index中
4、通过创建的index进行检索
1.创建索引
import faiss
def create_index(datas_embedding):
# 构建索引,L2代表构建的index采用的相似度度量方法为L2范数
# 必须传入一个向量的维度,创建一个空的索引
index = faiss.IndexFlatL2(datas_embedding.shape[1])
# 把向量数据加入索引
index.add(datas_embedding)
return index
2.保存索引
def faiss_index_save(faiss_index, save_file_location):
faiss.write_index(faiss_index, save_file_location)
3.加载索引
def faiss_index_load(faiss_index_save_file_location):
index = faiss.read_index(faiss_index_save_file_location)
return index
4.向索引中添加向量
def index_data_add(faiss_index, img_path):
# 获得索引向量的数量
print(faiss_index.ntotal)
img_embedding = extract_image_features(img_path)
faiss_index.add(img_embedding)
print(faiss_index.ntotal)
5.删除索引中的向量
def index_data_delete(faiss_index):
print(faiss_index.ntotal)
# remove, 指定要删除的向量id,是一个np的array
faiss_index.remove_ids(np.array([0]))
print(faiss_index.ntotal)
可以看出使用Faiss工具更加的灵活,可以向索引中添加和删除向量。
2.faiss实现以图搜图
本篇代码有部分是在前两篇的基础之上的,这里使用11类犬类数据集微调之后的resnet18进行特征提取。
第一篇:【以图搜图代码实现】–犬类以图搜图示例
第二篇:【使用resnet18训练自己的数据集】
数据集准备和下载可以去看第二篇文章。
1.模型加载
为了更好的适配,对第一篇中的resnet18的初始化方法进行了修改,如下:
@Project :ImageRec
@File :resnet18.py
@IDE :PyCharm
@Author :菜菜2024
@Date :2024/9/30
'''
from PIL import Image
from torchvision import transforms
import torch
import torch.nn as nn
from torchvision import models
class ResNet18:
def __init__(self,
out_feature = 11,
model_path='E:\\xxx\\ImageRec\\weights\\resnet18.pth'):
self.trans = transforms.Compose([
transforms.Resize(size=(256, 256)),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
print("-----------loading resnet18------------")
self.model = models.resnet18()
num_feats = self.model.fc.in_features
self.model.fc = nn.Linear(num_feats, out_feature)
self.model.load_state_dict(torch.load(model_path))
self.model.eval()
def extract_image_features(self, img_path):
image = Image.open(img_path).convert('RGB')
image_tensor = self.trans(image).unsqueeze(0)
with torch.no_grad():
features = self.model(image_tensor)
return features
其中out_feature 根据自己的数据集的类别个数进行更改,我这里的犬类是11种。model_path是训练好的保存的权重文件【训练过程可以去看第二篇】
2.文件名映射
在第一篇:【以图搜图代码实现】–犬类以图搜图示例 中使用的是保存成h5文件,索引是没有要求是整数的,这里faiss要求是整数,搞了一个映射方法,同时也是为了在后面可视化的时候,能根据索引再解码得到对应的文件路径。
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec
@File :Imgmap.py
@IDE :PyCharm
@Author :菜菜2024
@Date :2024/9/29 18:02
'''
import os
import uuid
import numpy as np
def getImgMap(img_path):
# 为类别生成一个映射文件
subnames = [f.split('\\')[-1] for f in os.listdir(img_path)]
element_mapping = {}
for i in range(len(subnames)):
unique_id = str(i+2024)
element_mapping[unique_id] = subnames[i]
return element_mapping
def valueGetKey(mapping, target_value):
for key, value in mapping.items():
if value == target_value:
# print(f"值 '{target_value}' 对应的键是: {key}")
break
return key
def nameMap(imgnames, img_path='E:\\xxx\\datas\\pet_dog\\train'):
'''
getImagVector函数得到的image_ids在保存为h5文件时进行了编码
现在faiss工具中index需要是int类型的,这里进行映射转化
:param img_path: 数据集目录,来得到类别映射
:param imgnames: 需要映射的图片名称,解码之后是“中华田园犬_0”格式
这里传参是列表
:return:
'''
element_mapping = getImgMap(img_path)
decode_names = [imgname.decode('utf-8') for imgname in imgnames]
name_ids=[]
for decode_name in decode_names:
cla_name = decode_name.split("_")[0]
img_name = decode_name.split("_")[-1]
key = valueGetKey(element_mapping, cla_name)
name_id = key+img_name
name_ids.append(name_id)
name_ids=np.array(name_ids).astype('int32')
return name_ids
if __name__ == "__main__":
database = 'E:\\xxx\\datas\\pet_dog\\train'
element_mapping = getImgMap(database)
print(element_mapping)
print(element_mapping.get("2024"))
映射文件:
{‘2024’: ‘中华田园犬’, ‘2025’: ‘吉娃娃’, ‘2026’: ‘哈士奇’, ‘2027’: ‘德牧’, ‘2028’: ‘拉布拉多’, ‘2029’: ‘杜宾’, ‘2030’: ‘柴犬’, ‘2031’: ‘法国斗牛’, ‘2032’: ‘萨摩耶’, ‘2033’: ‘藏獒’, ‘2034’: ‘金毛’}
nameMap函数是将之前编码的图像名称进行解码,然后重新编码,编码成20240,20301,分别表示的中华田园犬文件夹下的0.jpg, 柴犬下面的1.jpg。这都是为了可视化的时候进行追溯,得到文件路径。
3.以图搜图实现
定义了一个类ImageRetrival,使用faiss实现创建索引,保存索引,加载索引和图像检索功能
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec
@File :faiss_index.py
@IDE :PyCharm
@Author :菜菜2024
@Date :2024/9/30 15:04
'''
import os
import faiss
from utils.split_data import array_norm
from utils.Imgmap import nameMap, getImgMap
from model import ResNet18
from save_feature import getImagVectors
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
# 设置全局字体为支持中文的字体
rc('font', family='SimHei') # 黑体
class ImageRetrival:
def __init__(self, model_path,
index_dim=None):
self.index_dim = index_dim
self.index = faiss.IndexFlatL2(self.index_dim)
self.model_path = model_path
def build_index(self, image_files):
# image_vectors图片特征,image_ids对应的标签
image_vectors, image_ids = getImagVectors(image_files)
# image_ids 在之前保存为h5文件时进行了编码,这里进行映射
name_ids = nameMap(image_ids)
index = faiss.IndexIDMap(self.index)
index.add_with_ids(image_vectors, name_ids)
return index
def save_index(self, index, index_path):
faiss.write_index(index, index_path)
def load_index(self, index_path):
return faiss.read_index(index_path)
def image_topK_search(self, index, input_image, topK=None):
resnet18 = ResNet18(out_feature=11,
model_path=self.model_path)
queryVec = resnet18.extract_image_features(input_image)
dist, ind = index.search(queryVec, topK)
dist, ind = dist.flatten(), ind.flatten()
res = array_norm(dist, ind)
return res
4.运行调用
if __name__=="__main__":
model_path='E:\\xxx\\Pycharm_files\\ImageRec\\weights\\resnet18.pth'
# 1.创建索引
imageRetrival = ImageRetrival(model_path=model_path,
index_dim=11)
image_files = 'E:\\xxx\\datas\\pet_dog\\train'
save_index = "./weights/dog.index"
index = imageRetrival.build_index(image_files)
# # 2.保存索引
imageRetrival.save_index(index, save_index)
# 3.加载索引
index_load = imageRetrival.load_index(save_index)
#
# # 4.相似度匹配
input_image = './data/pic/德牧.jpg'
out = imageRetrival.image_topK_search(index_load, input_image, topK=3)
print(out)
showFaissRes(image_files, input_image, out)
运行时选择性注销其中的某一步骤。
最后是可视化实现showFaissRes
5.可视化实现
def showFaissRes(image_files, input_image, faissRes):
'''
对faiss得到的结果进行可视化
:param image_files: 图片数据库
:param input_image: 查询图片路径
:param faissRes: 返回的topk跟距离最近的结果[(ind, score), (ind, score)]
:return:
'''
scores = []
imgs = []
info = []
# 1.得到图片名称的映射
element_mapping = getImgMap(image_files)
imgs.append(mpimg.imread(input_image))
info.append(input_image.split("/")[-1])
for i in range(len(faissRes)):
score = faissRes[i][1]
ind = str(faissRes[i][0])
scores.append(score)
# 根据索引构建原本的图像路径ind格式:20276,前四个是类别表示
claName = element_mapping.get(ind[:4])
imgName = ind[4:]+".jpg"
imgpath = image_files +"\\"+ claName+ "\\"+imgName
imgs.append(mpimg.imread(imgpath))
info.append(claName+"_"+ imgName+"_"+ str(score))
print("图片名称是: " + claName+ imgName + " 对应得分是: %f" %score)
num = int((len(faissRes) + 1) // 2)+1
fig, axs = plt.subplots(nrows=num, ncols=num, figsize=(10, 10))
# 确保即使只有一个子图,也可以进行索引
if not isinstance(axs, np.ndarray):
axs = np.array([[axs]])
# 显示图像
flat_index = 0
for i in range(num):
for j in range(num):
if flat_index < len(imgs):
img = imgs[flat_index]
axs[i, j].imshow(img, cmap='gray')
axs[i, j].axis('off')
axs[i, j].set_title(info[flat_index])
flat_index += 1
else:
axs[i, j].set_visible(False)
plt.tight_layout()
plt.show()
3.效果对比
第一篇:【以图搜图代码实现】–犬类以图搜图示例 预训练的resnet18
第二篇:【使用resnet18训练自己的数据集】 微调的resnet18
本章 Faiss实现: 分数不重要,本篇对分数进行了归一化。
准确性更高了。