当前位置: 首页 > article >正文

使用 SQLite 数据库筛选人脸信息

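# Purpose: rank the face images inside each person's folder. For example, folder "0" holds many photos of person 0 but is also mixed with very blurry shots or photos of other people. This script sorts those blurry or mismatched faces to the end of the listing, so cleaning up no longer requires hunting for the bad photos one by one.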

import os
import shutil

import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
from PIL import Image
import torch
import torchvision.transforms as transforms
from facenet_pytorch import InceptionResnetV1
import sqlite3
import threading

# 1. Load the pretrained face feature-extraction model, on the GPU when CUDA
#    is available and on the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = InceptionResnetV1(pretrained='vggface2')
model = model.eval().to(device)  # eval mode: the model is only used for inference

# 2. Image preprocessing applied to every image before it is fed to the model
transform = transforms.Compose(
    [
        transforms.Resize((160, 160)),  # FaceNet input size
        transforms.ToTensor(),
        # Per-channel normalisation with mean 0.5 and std 0.5.
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)


# 3. Extract the feature vector of a single image
def extract_feature(image_path):
    """Return the model output for one image as a flat NumPy array.

    The image is converted to RGB, run through the module-level ``transform``
    pipeline, and passed through the module-level ``model`` on ``device``.

    Args:
        image_path: Path of the image file to process.

    Returns:
        A 1-D ``numpy.ndarray`` containing the flattened model output.

    Exceptions raised while opening or decoding the file are not caught here;
    callers that must survive corrupt images need to handle them.
    """
    # Image.open is lazy and keeps the underlying file open. The context
    # manager releases the handle as soon as the RGB copy has been made.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert('RGB')

    # unsqueeze(0) adds the batch dimension the model call expects.
    batch = transform(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only, no gradient bookkeeping
        return model(batch).cpu().numpy().flatten()


# 4. Create the SQLite database
def create_database(db_path):
    """Create the ``features`` table in the database at ``db_path`` if missing.

    The table stores one row per ``(person_id, image_path)`` pair together
    with the image's feature vector as a BLOB. Calling this function again on
    an existing database is a no-op.

    Args:
        db_path: Path of the SQLite database file (created if it does not
            exist).
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS features (
                person_id TEXT,
                image_path TEXT,
                feature_vector BLOB,
                PRIMARY KEY (person_id, image_path)
            )
        ''')
        conn.commit()
    finally:
        # Close the connection even when the statement fails, so the file
        # handle is not leaked.
        conn.close()


# 5. Save a feature vector to the database
def save_feature_to_db(db_path, person_id, image_path, feature):
    """Store one image's feature vector unless a row for it already exists.

    An existing ``(person_id, image_path)`` row is left untouched and a notice
    is printed instead. This lets an interrupted run be restarted without
    primary-key conflicts.

    Args:
        db_path: Path of the SQLite database containing the ``features`` table.
        person_id: Identifier of the person the image belongs to.
        image_path: Path of the image the feature was extracted from.
        feature: Array-like feature vector; it is stored as raw float32 bytes.
    """
    # The readers in this file decode the BLOB with dtype=float32, so the
    # bytes are always written as float32 regardless of the input dtype.
    feature_blob = np.asarray(feature, dtype=np.float32).tobytes()

    conn = sqlite3.connect(db_path)
    try:
        # INSERT OR IGNORE skips the row atomically when the primary key is
        # already present; rowcount tells us whether anything was written.
        cursor = conn.execute('''
            INSERT OR IGNORE INTO features (person_id, image_path, feature_vector)
            VALUES (?, ?, ?)
        ''', (person_id, image_path, feature_blob))
        conn.commit()
        if cursor.rowcount == 0:
            print(f"Feature for {person_id} - {image_path} already exists,  skipping")
    finally:
        # Always release the connection, including on the "already exists"
        # path and when SQLite raises.
        conn.close()


# 6. Process one folder: extract features and save them to the database
def process_folder(db_path, folder_path, person_id):
    """Extract and store the feature of every image file in ``folder_path``.

    Entries whose name does not end in a known image extension are ignored.
    Any exception raised while handling one image is printed and processing
    continues with the next entry, so a corrupt file cannot abort the run.

    Args:
        db_path: Path of the SQLite database to write to.
        folder_path: Directory whose direct entries are scanned.
        person_id: Identifier stored alongside every feature from this folder.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    for entry in os.listdir(folder_path):
        candidate = os.path.join(folder_path, entry)
        # Skip anything that is not an image file.
        if not candidate.lower().endswith(image_suffixes):
            continue
        # A damaged image must not stop the whole batch.
        try:
            vector = extract_feature(candidate)
            save_feature_to_db(db_path, person_id, candidate, vector)
        except Exception as err:
            print(err)


# 7. Get the average feature vector of one person from the database
def get_avg_feature(db_path, person_id):
    """Return the element-wise mean of all stored features of ``person_id``.

    Args:
        db_path: Path of the SQLite database containing the ``features`` table.
        person_id: Identifier of the person whose features are averaged.

    Returns:
        A float32 ``numpy.ndarray`` with the mean feature vector, or a scalar
        NaN when no feature is stored for ``person_id``.
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute('''
            SELECT feature_vector FROM features WHERE person_id = ?
        ''', (person_id,)).fetchall()
    finally:
        # Release the connection even if the query fails.
        conn.close()

    if not rows:
        # np.mean over an empty list would also yield NaN, but only after
        # emitting a RuntimeWarning; return the same value quietly.
        return np.float64('nan')

    # Decode every BLOB back into a float32 vector and average them.
    features = np.stack([np.frombuffer(blob, dtype=np.float32) for (blob,) in rows])
    return features.mean(axis=0)


# 8. Sort images by Euclidean distance and copy them under rank-based names
def sort_and_rename_images(db_path, out_path, person_id):
    """Copy one person's images to ``out_path`` ordered by distance to the mean.

    Every stored feature of ``person_id`` is compared with the mean of those
    features. Images are copied to ``<out_path>/<person_id>/`` and named
    ``0000``, ``0001``, ... by increasing Euclidean distance, so the images
    least like the average end up last. The source files are left in place.

    Each copy keeps the (lower-cased) extension of its source file, falling
    back to ``.jpg`` when the source has none, so the file name matches the
    file's actual format.

    Args:
        db_path: Path of the SQLite database containing the ``features`` table.
        out_path: Root directory under which the per-person folder is created.
        person_id: Identifier of the person whose images are ranked.

    Nothing is created when no feature is stored for ``person_id``.
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute('''
            SELECT image_path, feature_vector FROM features WHERE person_id = ?
        ''', (person_id,)).fetchall()
    finally:
        # Release the connection even if the query fails.
        conn.close()

    if not rows:
        return

    # One query supplies both the vectors and their mean; the distance of
    # every image to that mean is then computed in a single NumPy call.
    image_paths = [path for path, _ in rows]
    vectors = np.stack([np.frombuffer(blob, dtype=np.float32) for _, blob in rows])
    avg_feature = vectors.mean(axis=0)
    distances = np.linalg.norm(vectors - avg_feature, axis=1)

    # Create the target folder if it does not exist yet.
    target_dir = os.path.join(out_path, str(person_id))
    os.makedirs(target_dir, exist_ok=True)

    # A stable sort keeps the query order for images at equal distance.
    for rank, row_index in enumerate(np.argsort(distances, kind='stable')):
        source = image_paths[row_index]
        extension = os.path.splitext(source)[1].lower() or '.jpg'
        shutil.copy(source, os.path.join(target_dir, f"{rank:04d}{extension}"))


# 9. Main entry point
def main(
    db_path=r'D:\FS_project2\Feature_extraction\sql_database\features.db2',
    base_path=r'D:\FS_project2\Feature_extraction\peopel_crop',
    out_path=r'D:\FS_project2\Feature_extraction\out',
):
    """Run the full pipeline: extract features, then rank and copy the images.

    Each sub-directory of ``base_path`` is treated as one person, with the
    directory name used as the person id.

    Args:
        db_path: SQLite database file used to store the feature vectors.
        base_path: Directory containing one sub-directory of face crops per
            person.
        out_path: Directory that receives the ranked copies, one
            sub-directory per person.

    The defaults are the original hard-coded locations, so calling ``main()``
    without arguments behaves as before.
    """
    create_database(db_path)

    # List the person folders once; both passes work on the same set.
    person_folders = [
        name
        for name in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, name))
    ]

    # Step 1: extract features and save them to the database.
    for folder in person_folders:
        process_folder(db_path, os.path.join(base_path, folder), folder)
        print(f"Processed folder: {folder}")

    # Step 2: rank the images of every person and copy them under new names.
    for folder in person_folders:
        sort_and_rename_images(db_path, out_path, folder)
        print(f"Sorted and renamed folder: {folder}")


# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()


http://www.kler.cn/a/547183.html

相关文章:

  • Oracle查看执行计划
  • 项目中菜单按照层级展示sql
  • SpringCloud面试题----微服务下为什需要链路追踪系统
  • 【C++】C++ 旅馆管理系统(含 源码+报告)【独一无二】
  • 统计安卓帧率和内存
  • (萌新入门)如何从起步阶段开始学习STM32 —— 1如何迁移一个开发版的工程
  • c#展示网页并获取网页上触发按钮的值进行系统业务逻辑处理
  • vue3 关于插槽的使用
  • 手写一个Java Android Binder服务及源码分析
  • 云创智城充电系统:基于 SpringCloud 的高可用、可扩展架构详解-多租户、多协议兼容、分账与互联互通功能实现
  • git bash在github的库中上传或更新本地文件
  • SOUI基于Zint生成Code 39码
  • 【面试】网络安全常问150道面试题
  • Vue 2 + Webpack 项目中集成 ESLint 和 Prettier
  • 前端包管理器的发展以及Npm、Yarn和Pnpm对比
  • Netty源码解析之异步处理(二):盛赞Promise中的集合设计
  • 三、k8s pod详解
  • SQLMesh系列教程-3:SQLMesh模型属性详解
  • 算法04-希尔排序
  • Windows搭建Docker+Ollama+Open-WebUI部署DeepSeek本地模型