当前位置: 首页 > article >正文

Clip结合Faiss+Flask简易版文搜图服务

一、实现

使用目录结构:

templates

        ---upload.html

 faiss_app.py

前端代码:upload.html

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Search and Show Multiple Images</title>
    <style>
        #image-container {
            display: flex;
            flex-wrap: wrap;
        }
        #image-container img {
            max-width: 150px;
            margin: 10px;
        }
    </style>
</head>
<body>
    <h1>Search Images</h1>
    
    <!-- 搜索框 -->
    <form id="search-form">
        <input type="text" id="search-input" name="query" placeholder="Enter search term" required>
        <input type="submit" value="Search">
    </form>

    <h2>Search Results</h2>
    <!-- 显示搜索返回的多张图片 -->
    <div id="image-container"></div>

    <!-- 使用JS处理表单提交 -->
    <script>
        document.getElementById('search-form').addEventListener('submit', async function(event) {
            event.preventDefault();  // 阻止表单默认提交行为
            
            const query = document.getElementById('search-input').value;  // 获取搜索框中的输入内容

            try {
                // 发送GET请求,将搜索关键词发送到后端
                const response = await fetch(`/search?query=${encodeURIComponent(query)}`, {
                    method: 'GET',
                });

                // 确保服务器返回JSON数据
                const data = await response.json();

                // 清空图片容器
                const imageContainer = document.getElementById('image-container');
                imageContainer.innerHTML = '';

                // 遍历后端返回的图片URL数组,动态创建<img>标签并渲染
                data.image_urls.forEach(url => {
                    const imgElement = document.createElement('img');
                    imgElement.src = url;  // 设置图片的src属性为返回的URL
                    imageContainer.appendChild(imgElement);  // 将图片添加到容器中
                });
            } catch (error) {
                console.error('Error searching for images:', error);
            }
        });
    </script>
</body>
</html>

后端代码 faiss_app.py:

from sentence_transformers import SentenceTransformer, util
from PIL import Image
from flask import Flask, request, jsonify, current_app, render_template, send_from_directory, url_for
from werkzeug.utils import secure_filename
import faiss
import os, glob
import numpy as np
from markupsafe import escape
import shutil

#Load CLIP model
model = SentenceTransformer('clip-ViT-B-32')
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}

UPLOAD_FOLDER = 'uploads/'
IMAGES_PATH  = "C:\\Users\\xxxx\\Pictures\\"

def generate_clip_embeddings(images_path, model):
    image_paths = []
    # 使用 os.walk 遍历所有子目录和文件
    for root, dirs, files in os.walk(images_path):
        for file in files:
            # 获取文件的扩展名并转换为小写
            ext = os.path.splitext(file)[1].lower()
            # 判断是否是图片文件
            if ext in IMAGE_EXTENSIONS:
                image_paths.append(os.path.join(root, file)) 
    embeddings = []
    for img_path in image_paths:
        image = Image.open(img_path)
        embedding = model.encode(image)
        embeddings.append(embedding)
    
    return embeddings, image_paths

def create_faiss_index(embeddings, image_paths, output_path):

    dimension = len(embeddings[0])

    # 分情况创建Faiss索引对象
    if len(image_paths) < 39 * 256:
        # 如果条目很少,直接用最普通的L2索引
        faiss_index = faiss.IndexFlatL2(dimension)
    elif len(image_paths) < 39 * 4096:
        # 如果条目少于39 × 4096,就只用PQ量化,不使用IVF
        faiss_index = faiss.index_factory(dimension, 'OPQ64_256,PQ64x8')
    else:
        # 否则就加上IVF
        faiss_index = faiss.index_factory(dimension, 'OPQ64_256,IVF4096,PQ64x8')
    res = faiss.StandardGpuResources()
    co = faiss.GpuClonerOptions()
    co.useFloat16 = True
    faiss_index = faiss.index_cpu_to_gpu(res, 0, faiss_index, co)

    #index = faiss.IndexFlatIP(dimension)
    faiss_index = faiss.IndexIDMap(faiss_index)
    
    vectors = np.array(embeddings).astype(np.float32)

    # Add vectors to the index with IDs
    faiss_index.add_with_ids(vectors, np.array(range(len(embeddings))))
    
    # Save the index
    faiss_index = faiss.index_gpu_to_cpu(faiss_index)
    faiss.write_index(faiss_index, output_path)
    print(f"Index created and saved to {output_path}")
    
    # Save image paths
    with open(output_path + '.paths', 'w') as f:
        for img_path in image_paths:
            f.write(img_path + '\n')
    
    return faiss_index

def load_faiss_index(index_path):
    faiss_index = faiss.read_index(index_path)
    with open(index_path + '.paths', 'r') as f:
        image_paths = [line.strip() for line in f]
    print(f"Index loaded from {index_path}")
    if not faiss_index.is_trained:
            raise RuntimeError(f'从[{index_path}]加载的Faiss索引未训练')
    res = faiss.StandardGpuResources()
    co = faiss.GpuClonerOptions()
    co.useFloat16 = True
    faiss_index = faiss.index_cpu_to_gpu(res, 0, faiss_index, co)
    return faiss_index, image_paths


def retrieve_similar_images(query, model, index, image_paths, top_k=3):
    
    # query preprocess:
    if query.endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')):
        query = Image.open(query)

    query_features = model.encode(query)
    query_features = query_features.astype(np.float32).reshape(1, -1)

    distances, indices = index.search(query_features, top_k)

    retrieved_images = [image_paths[int(idx)] for idx in indices[0]]

    return query, retrieved_images

# 检查文件扩展名是否允许
def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

def search():
    query = request.args.get('query')  # 获取搜索关键词
    safe_query = escape(query)

    if not query:
        return jsonify({"error": "No search query provided"}), 400
    index, image_paths = None, []
    OUTPUT_INDEX_PATH = f"{app.config['UPLOAD_FOLDER']}/vector.index"
    if os.path.exists(OUTPUT_INDEX_PATH):
        index, image_paths = load_faiss_index(OUTPUT_INDEX_PATH)
    else:
        embeddings, image_paths = generate_clip_embeddings(IMAGES_PATH, model)
        index = create_faiss_index(embeddings, image_paths, OUTPUT_INDEX_PATH)
    query, retrieved_images = retrieve_similar_images(query, model, index, image_paths, top_k=5)


    image_urls = []
    for path in retrieved_images:
        base_name = os.path.basename(path)
        shutil.copy(path, os.path.join(app.config['UPLOAD_FOLDER'], base_name))
        image_urls.append(url_for('uploaded_file_path', filename=base_name))

    return jsonify({"image_urls": image_urls})


def index():
    return render_template('upload.html')

# 提供静态文件的访问路径
def uploaded_file_path(filename):
    return send_from_directory(app.config['UPLOAD_FOLDER'], filename)

if __name__ == "__main__":
    app = Flask(__name__)
    app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
    if not os.path.exists(UPLOAD_FOLDER):
        os.makedirs(UPLOAD_FOLDER)
    # 主页显示上传表单
    app.route('/')(index)
    app.route('/search', methods=['GET'])(search)
    app.route('/uploads/images/<filename>')(uploaded_file_path)
    app.run(host='0.0.0.0', port=8080, debug=True)

二、效果


http://www.kler.cn/a/399953.html

相关文章:

  • 网络安全之国际主流网络安全架构模型
  • 51单片机--- 矩阵按键仿真
  • 【售前方案】工业园区整体解决方案,智慧园区方案,智慧城市方案,智慧各类信息化方案(ppt原件)
  • python读写excel等数据文件方法汇总
  • C++内存管理 - new/delete
  • 什么是SMARC?模块电脑(核心板)规范标准简介三
  • 使用PSpice进行第一个电路的仿真
  • ACE之单例
  • 把一个对象序列化为字符串,再反序列化回来
  • cisco防火墙在内网通过外网域名进行访问的配置
  • 汽车与摩托车分类数据集
  • 【Flask+Gunicorn+Nginx】部署目标检测模型API完整解决方案
  • 【gitlab】gitlabrunner部署
  • 基于差分、粒子群算法下的TSP优化对比
  • YOLOv11融合针对小目标FFCA-YOPLO中的FEM模块及相关改进思路
  • Tailscale 自建 Derp 中转服务器
  • 【Mac】卸载JAVA、jdk
  • Day02_AJAX综合案例 (黑马笔记)
  • 在 CentOS 7 上安装 MinIO 的步骤
  • 【爬虫实战】抓取某站评论
  • 【论文笔记】SCOPE: Sign Language Contextual Processing with Embedding from LLMs
  • 代码随想录第三十四天
  • 输出比较简介
  • 来LeetCode练下思维吧
  • uniapp微信小程序转发跳转指定页面
  • git环境开发问题-处理