当前位置: 首页 > article >正文

faiss库中ivf-sq(ScalarQuantizer,标量量化)代码解读-7

流程

在这里插入图片描述

代码

void IndexIVF::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params_in) const {
    FAISS_THROW_IF_NOT(k > 0);
    const IVFSearchParameters* params = nullptr;
    if (params_in) {
        params = dynamic_cast<const IVFSearchParameters*>(params_in);
        FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
    }
    const size_t nprobe =
            std::min(nlist, params ? params->nprobe : this->nprobe);
    FAISS_THROW_IF_NOT(nprobe > 0);

    // search function for a subset of queries
    auto sub_search_func = [this, k, nprobe, params](
                                   idx_t n,
                                   const float* x,
                                   float* distances,
                                   idx_t* labels,
                                   IndexIVFStats* ivf_stats) {
        std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
        std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

        double t0 = getmillisecs();
        quantizer->search(
                n,
                x,
                nprobe,
                coarse_dis.get(),
                idx.get(),
                params ? params->quantizer_params : nullptr);

        double t1 = getmillisecs();
        invlists->prefetch_lists(idx.get(), n * nprobe);

        search_preassigned(
                n,
                x,
                k,
                idx.get(),
                coarse_dis.get(),
                distances,
                labels,
                false,
                params,
                ivf_stats);
        double t2 = getmillisecs();
        ivf_stats->quantization_time += t1 - t0;
        ivf_stats->search_time += t2 - t0;
    };

    if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
        int nt = std::min(omp_get_max_threads(), int(n));
        std::vector<IndexIVFStats> stats(nt);
        std::mutex exception_mutex;
        std::string exception_string;

#pragma omp parallel for if (nt > 1)
        for (idx_t slice = 0; slice < nt; slice++) {
            IndexIVFStats local_stats;
            idx_t i0 = n * slice / nt;
            idx_t i1 = n * (slice + 1) / nt;
            if (i1 > i0) {
                try {
                    sub_search_func(
                            i1 - i0,
                            x + i0 * d,
                            distances + i0 * k,
                            labels + i0 * k,
                            &stats[slice]);
                } catch (const std::exception& e) {
                    std::lock_guard<std::mutex> lock(exception_mutex);
                    exception_string = e.what();
                }
            }
        }

        if (!exception_string.empty()) {
            FAISS_THROW_MSG(exception_string.c_str());
        }

        // collect stats
        for (idx_t slice = 0; slice < nt; slice++) {
            indexIVF_stats.add(stats[slice]);
        }
    } else {
        // handle paralellization at level below (or don't run in parallel at
        // all)
        sub_search_func(n, x, distances, labels, &indexIVF_stats);
    }
}

代码解析

IndexIVF::search 函数是 FAISS 的 IndexIVF 类中实现的一个核心函数,用于在倒排文件(Inverted File List, IVF)索引中执行搜索操作。以下是对函数的详细解析:

函数功能

在倒排文件索引中搜索最近的 k 个向量,返回它们的距离和对应的索引。
支持多线程并行化以提高查询性能。

参数说明

void IndexIVF::search(
    idx_t n,                    // 查询向量的数量
    const float* x,             // 查询向量(每个向量有 d 个维度)
    idx_t k,                    // 每个查询向量要找到的最近邻个数
    float* distances,           // 输出的距离数组,大小为 n*k
    idx_t* labels,              // 输出的索引数组,大小为 n*k
    const SearchParameters* params_in // 搜索参数,可选
) const;
  • n:查询向量的数量。
  • x:指向查询向量的指针,形状为 (n, d)。
  • k:每个查询向量需要返回的最近邻数量。
  • distances:保存结果的距离数组。
  • labels:保存结果的索引数组。
  • params_in:可选的搜索参数对象,通常包括 nprobe(控制搜索的倒排列表数量)等。
函数实现解析
  1. 参数验证
FAISS_THROW_IF_NOT(k > 0);

确保 k > 0,即需要找到至少一个最近邻。
2. 处理搜索参数

const IVFSearchParameters* params = nullptr;
if (params_in) {
    params = dynamic_cast<const IVFSearchParameters*>(params_in);
    FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
}
const size_t nprobe = std::min(nlist, params ? params->nprobe : this->nprobe);
FAISS_THROW_IF_NOT(nprobe > 0);
  • 检查输入的搜索参数 params_in 是否是 IVFSearchParameters 类型。
  • 从参数中提取 nprobe,即查询时访问的倒排列表数量:
    • 如果参数提供了 nprobe,则使用参数中的值。
    • 如果未提供,则使用索引默认的 nprobe。
  • 确保 nprobe > 0。
  1. 定义子搜索函数
auto sub_search_func = [this, k, nprobe, params](
    idx_t n,
    const float* x,
    float* distances,
    idx_t* labels,
    IndexIVFStats* ivf_stats) {
    ...
};

定义一个局部 lambda 函数 sub_search_func,处理子查询任务。参数包括当前的查询向量、结果存储位置和统计信息。
内部实现的步骤:
量化查询向量:

quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get(), params ? params->quantizer_params : nullptr);

使用量化器将查询向量分配到 nprobe 个倒排列表中。idx 保存分配的倒排列表索引。coarse_dis 保存量化后的距离。
倒排列表的预取:

invlists->prefetch_lists(idx.get(), n * nprobe);

预取倒排列表数据以提高内存访问性能。
实际搜索:

search_preassigned(n, x, k, idx.get(), coarse_dis.get(), distances, labels, false, params, ivf_stats);

在分配好的倒排列表中执行搜索,返回最近邻结果的距离和索引。
更新统计信息:

ivf_stats->quantization_time += t1 - t0;
ivf_stats->search_time += t2 - t0;

记录量化时间和搜索时间。
4. 选择并行模式

if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
    ...
} else {
    sub_search_func(n, x, distances, labels, &indexIVF_stats);
}

根据 parallel_mode 决定并行模式:如果启用了并行模式,则使用 OpenMP 进行多线程查询。否则直接调用 sub_search_func 处理整个查询。
5. 并行查询

int nt = std::min(omp_get_max_threads(), int(n));
std::vector<IndexIVFStats> stats(nt);
std::mutex exception_mutex;
std::string exception_string;

#pragma omp parallel for if (nt > 1)
for (idx_t slice = 0; slice < nt; slice++) {
    ...
}
  • 线程数量:设置线程数量为查询向量数和最大线程数的较小值。
  • 分片查询:将查询向量分配到多个线程进行并行处理。
  • 异常处理:捕获并记录线程中的异常。
  • 统计合并:将各线程的统计结果合并到全局统计对象。
关键步骤总结
  • 查询向量量化:使用量化器将查询向量映射到倒排列表。
  • 倒排列表预取:优化内存访问以提高效率。
  • 倒排列表搜索:在分配的倒排列表中执行精确搜索。
  • 支持并行化:利用 OpenMP 将查询任务分片并行化处理。
函数作用
  • 高效搜索:支持通过 nprobe 调整查询范围,平衡搜索速度和准确率。
  • 并行优化:通过多线程实现大规模查询的加速。
  • 灵活性:支持自定义搜索参数(如量化器配置)以适应不同场景。
适用场景

海量数据的最近邻搜索,例如向量化文档、推荐系统和图像检索


http://www.kler.cn/a/455988.html

相关文章:

  • 计算机网络 (12)物理层下面的传输媒体
  • 简单发布一个npm包
  • 什么是ondelete cascade以及使用sqlite演示ondelete cascade使用案例
  • Spring Security3.0.2版本
  • 基于Pycharm与数据库的新闻管理系统(3)MongoDB
  • MyBatis知识点笔记
  • CSS面试题|[2024-12-24]
  • python中Windows系统使用 pywin32 来复制图像到剪贴板,并使用 Selenium 模拟 Ctrl+V 操作
  • 嵌入式科普(26)“相面”各大厂MCU和MPU
  • 再谈c++线性关系求值
  • pinia从0到1
  • 美食推荐系统|Java|SSM|JSP|
  • VSCode 插件开发实战(八):创建和管理任务 Task
  • day19
  • 【超简单】Python入门实用教程
  • 你不需要对其他成年人的情绪负责
  • 深入了解 Reactor:响应式编程的利器
  • QT,opencv制作界面化图片操作
  • Vue.js 入门与进阶:打造高效的前端开发体验
  • 机床数据采集网关在某机械制造企业的应用
  • Unity游戏环境交互系统
  • 回声函数 printf重定向 sht20温湿度传感器
  • 代码随想录38 322. 零钱兑换,279.完全平方数,本周小结动态规划,139.单词拆分,动态规划:关于多重背包,你该了解这些!背包问题总结篇。
  • 不修改内核镜像的情况下,使用内核模块实现“及时”的调度时间片超时事件上报
  • Redis-十大数据类型
  • 通过 `@Configuration` 和 `WebMvcConfigurer` 配置 Spring MVC 中的静态资源映射