当前位置: 首页 > article >正文

【以图搜图代码实现2】--faiss工具实现犬类以图搜图

第一篇:【以图搜图代码实现】–犬类以图搜图示例 使用保存成h5文件,使用向量积来度量相似性,实现了以图搜图,说明了可以优化的点。
第二篇:【使用resnet18训练自己的数据集】 准对模型问题进行了优化,取得了显著性的效果。
本篇继续第一篇中所说的优化方向,使用faiss实现以图搜图。

1.faiss使用介绍

Faiss的全称是Facebook AI Similarity Search,是FaceBook针对大规模相似度检索问题开发的一个工具,底层是使用C++代码实现的,提供了python的接口,号称对10亿量级的索引可以做到毫秒级检索。

使用faiss的基本步骤
1、数据转换:把原始数据转换为"float32"数据类型的向量。
2、index构建:用 faiss 构建index
3、数据添加:将向量add到创建的index中
4、通过创建的index进行检索

1.创建索引

import faiss

def create_index(datas_embedding):
    # 构建索引,L2代表构建的index采用的相似度度量方法为L2范数
    # 必须传入一个向量的维度,创建一个空的索引
    index = faiss.IndexFlatL2(datas_embedding.shape[1])  
    # 把向量数据加入索引
    index.add(datas_embedding)   
    return index

2.保存索引

def faiss_index_save(faiss_index, save_file_location):
    faiss.write_index(faiss_index, save_file_location)

3.加载索引

def faiss_index_load(faiss_index_save_file_location):
    index = faiss.read_index(faiss_index_save_file_location)
    return index

4.向索引中添加向量

def index_data_add(faiss_index, img_path):
    # 获得索引向量的数量
    print(faiss_index.ntotal)
    img_embedding = extract_image_features(img_path)
    faiss_index.add(img_embedding)
    print(faiss_index.ntotal)

5.删除索引中的向量

def index_data_delete(faiss_index):
    print(faiss_index.ntotal)
    # remove, 指定要删除的向量id,是一个np的array
    faiss_index.remove_ids(np.array([0]))
    print(faiss_index.ntotal)

可以看出使用Faiss工具更加的灵活,可以向索引中添加和删除向量。

2.faiss实现以图搜图

本篇代码有部分是在前两篇的基础之上的,这里使用11类犬类数据集微调之后的resnet18进行特征提取。
第一篇:【以图搜图代码实现】–犬类以图搜图示例
第二篇:【使用resnet18训练自己的数据集】

数据集准备和下载可以去看第二篇文章。

1.模型加载

为了更好的适配,对第一篇中的resnet18的初始化方法进行了修改,如下:

@Project :ImageRec
@File    :resnet18.py
@IDE     :PyCharm
@Author  :菜菜2024
@Date    :2024/9/30
'''
from PIL import Image
from torchvision import transforms
import torch
import torch.nn as nn
from torchvision import models

class ResNet18:
    def __init__(self,
                 out_feature = 11,
                 model_path='E:\\xxx\\ImageRec\\weights\\resnet18.pth'):
        self.trans = transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
        print("-----------loading resnet18------------")
        self.model = models.resnet18()
        num_feats = self.model.fc.in_features
        self.model.fc = nn.Linear(num_feats, out_feature)
        self.model.load_state_dict(torch.load(model_path))
        self.model.eval()


    def extract_image_features(self, img_path):

        image = Image.open(img_path).convert('RGB')
        image_tensor = self.trans(image).unsqueeze(0)
        with torch.no_grad():
            features = self.model(image_tensor)
        return features

其中out_feature 根据自己的数据集的类别个数进行更改,我这里的犬类是11种。model_path是训练好的保存的权重文件【训练过程可以去看第二篇】

2.文件名映射

在第一篇:【以图搜图代码实现】–犬类以图搜图示例 中使用的是保存成h5文件,索引是没有要求是整数的,这里faiss要求是整数,搞了一个映射方法,同时也是为了在后面可视化的时候,能根据索引再解码得到对应的文件路径。

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec 
@File    :Imgmap.py
@IDE     :PyCharm 
@Author  :菜菜2024
@Date    :2024/9/29 18:02 
'''
import os
import uuid
import numpy as np


def getImgMap(img_path):
	# 为类别生成一个映射文件
    subnames = [f.split('\\')[-1] for f in os.listdir(img_path)]
    element_mapping = {}
 
    for i in range(len(subnames)):
        unique_id = str(i+2024)
        element_mapping[unique_id] = subnames[i]

    return element_mapping

def valueGetKey(mapping, target_value):


    for key, value in mapping.items():
        if value == target_value:
            # print(f"值 '{target_value}' 对应的键是: {key}")
            break
    return key


def nameMap(imgnames, img_path='E:\\xxx\\datas\\pet_dog\\train'):
    '''
    getImagVector函数得到的image_ids在保存为h5文件时进行了编码
    现在faiss工具中index需要是int类型的,这里进行映射转化
    :param img_path: 数据集目录,来得到类别映射
    :param imgnames: 需要映射的图片名称,解码之后是“中华田园犬_0”格式
    这里传参是列表
    :return:
    '''
    element_mapping = getImgMap(img_path)
    decode_names = [imgname.decode('utf-8') for imgname in imgnames]

    name_ids=[]
    for decode_name in decode_names:
        cla_name = decode_name.split("_")[0]
        img_name = decode_name.split("_")[-1]
        key = valueGetKey(element_mapping, cla_name)
        name_id = key+img_name
        name_ids.append(name_id)

    name_ids=np.array(name_ids).astype('int32')

    return name_ids




if __name__ == "__main__":

    database = 'E:\\xxx\\datas\\pet_dog\\train'
    element_mapping = getImgMap(database)
    print(element_mapping)
    print(element_mapping.get("2024"))

映射文件:

{‘2024’: ‘中华田园犬’, ‘2025’: ‘吉娃娃’, ‘2026’: ‘哈士奇’, ‘2027’: ‘德牧’, ‘2028’: ‘拉布拉多’, ‘2029’: ‘杜宾’, ‘2030’: ‘柴犬’, ‘2031’: ‘法国斗牛’, ‘2032’: ‘萨摩耶’, ‘2033’: ‘藏獒’, ‘2034’: ‘金毛’}
nameMap函数是将之前编码的图像名称进行解码,然后重新编码,编码成20240,20301,分别表示的中华田园犬文件夹下的0.jpg, 柴犬下面的1.jpg。这都是为了可视化的时候进行追溯,得到文件路径。
在这里插入图片描述

3.以图搜图实现

定义了一个类ImageRetrival,使用faiss实现创建索引,保存索引,加载索引和图像检索功能

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec 
@File    :faiss_index.py
@IDE     :PyCharm 
@Author  :菜菜2024
@Date    :2024/9/30 15:04 
'''


import os
import faiss
from utils.split_data import array_norm
from utils.Imgmap import nameMap, getImgMap
from model import ResNet18
from save_feature import getImagVectors
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
# 设置全局字体为支持中文的字体
rc('font', family='SimHei')  # 黑体

class ImageRetrival:
    def __init__(self, model_path,
                 index_dim=None):
        self.index_dim = index_dim
        self.index = faiss.IndexFlatL2(self.index_dim)
        self.model_path = model_path

    def build_index(self, image_files):
        # image_vectors图片特征,image_ids对应的标签
        image_vectors, image_ids = getImagVectors(image_files)

        # image_ids 在之前保存为h5文件时进行了编码,这里进行映射
        name_ids = nameMap(image_ids)

        index = faiss.IndexIDMap(self.index)
        index.add_with_ids(image_vectors, name_ids)
        return index

    def save_index(self, index, index_path):
        faiss.write_index(index, index_path)

    def load_index(self, index_path):
        return faiss.read_index(index_path)

    def image_topK_search(self, index, input_image, topK=None):

        resnet18 = ResNet18(out_feature=11,
                            model_path=self.model_path)
        queryVec = resnet18.extract_image_features(input_image)


        dist, ind = index.search(queryVec, topK)
        dist, ind = dist.flatten(), ind.flatten()
        res = array_norm(dist, ind)
        return res

4.运行调用

if __name__=="__main__":

    model_path='E:\\xxx\\Pycharm_files\\ImageRec\\weights\\resnet18.pth'
    # 1.创建索引
    imageRetrival = ImageRetrival(model_path=model_path,
                                  index_dim=11)
    image_files = 'E:\\xxx\\datas\\pet_dog\\train'
    save_index = "./weights/dog.index"
    index = imageRetrival.build_index(image_files)

    # # 2.保存索引
    imageRetrival.save_index(index, save_index)

    # 3.加载索引
    index_load = imageRetrival.load_index(save_index)
    #
    # # 4.相似度匹配
    input_image = './data/pic/德牧.jpg'
    out = imageRetrival.image_topK_search(index_load, input_image, topK=3)
    print(out)
    showFaissRes(image_files, input_image, out)

运行时选择性注销其中的某一步骤。
最后是可视化实现showFaissRes

5.可视化实现


def showFaissRes(image_files, input_image, faissRes):
    '''
    对faiss得到的结果进行可视化
    :param image_files: 图片数据库
    :param input_image: 查询图片路径
    :param faissRes: 返回的topk跟距离最近的结果[(ind, score), (ind, score)]
    :return:
    '''
    scores = []
    imgs = []
    info = []

    # 1.得到图片名称的映射
    element_mapping = getImgMap(image_files)
    imgs.append(mpimg.imread(input_image))
    info.append(input_image.split("/")[-1])

    for i in range(len(faissRes)):
        score = faissRes[i][1]
        ind = str(faissRes[i][0])
        scores.append(score)

        # 根据索引构建原本的图像路径ind格式:20276,前四个是类别表示
        claName = element_mapping.get(ind[:4])
        imgName = ind[4:]+".jpg"
        imgpath = image_files +"\\"+ claName+ "\\"+imgName
        imgs.append(mpimg.imread(imgpath))

        info.append(claName+"_"+ imgName+"_"+ str(score))
        print("图片名称是: " + claName+ imgName + " 对应得分是: %f" %score)

    num = int((len(faissRes) + 1) // 2)+1
    fig, axs = plt.subplots(nrows=num, ncols=num, figsize=(10, 10))

    # 确保即使只有一个子图,也可以进行索引
    if not isinstance(axs, np.ndarray):
        axs = np.array([[axs]])

    # 显示图像
    flat_index = 0
    for i in range(num):
        for j in range(num):
            if flat_index < len(imgs):
                img = imgs[flat_index]
                axs[i, j].imshow(img, cmap='gray')
                axs[i, j].axis('off')
                axs[i, j].set_title(info[flat_index])
                flat_index += 1
            else:
                axs[i, j].set_visible(False)

    plt.tight_layout()
    plt.show()

3.效果对比

第一篇:【以图搜图代码实现】–犬类以图搜图示例 预训练的resnet18

第二篇:【使用resnet18训练自己的数据集】 微调的resnet18
在这里插入图片描述

本章 Faiss实现: 分数不重要,本篇对分数进行了归一化。
在这里插入图片描述
准确性更高了。


http://www.kler.cn/news/326546.html

相关文章:

  • mips指令系统简介
  • AI大模型面试大纲
  • 基于单片机的催眠电路控制系统
  • [云服务器15] 全网最全!手把手搭建discourse论坛,100%完成
  • 什么是 Apache Ingress
  • 钉钉H5微应用Springboot+Vue开发分享
  • win11下 keil报错Cannot load driver ‘D:\Keil_v5\ARM\Segger\JL2CM3.dll‘
  • WAF,全称Web Application Firewall,好用WAF推荐
  • 小巧机身,但强劲动力实现千元级净须,未野迷你剃须刀测评
  • Java 编码系列:反射详解与面试题解析
  • Julia的安装和使用(附vscode中使用)
  • WordPress 要求插件开发人员进行双因素身份验证
  • Python3 爬虫教程 - Web 网页基础
  • 前端工程规范-3:CSS规范(Stylelint)
  • 栈的最小值
  • 17、CPU缓存架构详解高性能内存队列Disruptor实战
  • Excel技巧:Excel批量提取文件名
  • 开源链动 2+1 模式 S2B2C 商城小程序助力品牌实现先营后销与品效合一
  • Skywalking告警配置
  • 图像生成大模型 Imagen:AI创作新纪元
  • Spring Shell基于注解定义命令
  • 外包干了1个多月,技术明显退步了。。。。。
  • 3-基于容器安装carla
  • Python——判断文件夹/文件是否存在、删除文件夹/文件、新建文件夹
  • SpringAOP学习
  • 【C语言软开面经】
  • pdf提取文字:分享3款pdf文字提取软件,赶快收藏起来!
  • Unity开发绘画板——03.简单的实现绘制功能
  • 配置ssh后又报错git@github.com: Permission denied (publickey)
  • Linux【基础指令汇总】