当前位置: 首页 > article >正文

faiss库中ivf-sq(ScalarQuantizer,标量量化)代码解读-2

文件ScalarQuantizer.h

主要介绍这里面的枚举以及一些函数内容:QuantizerType、RangeStat、ScalarQuantizer、train、compute_codes、decode、SQuantizer、FlatCodesDistanceComputer、get_distance_computer、select_InvertedListScanner

QuantizerType

量化类型中包含一些枚举类型,代码内容如下:

enum QuantizerType {
        QT_8bit,         ///< 8 bits per component
        QT_4bit,         ///< 4 bits per component
        QT_8bit_uniform, ///< same, shared range for all dimensions
        QT_4bit_uniform,
        QT_fp16,
        QT_8bit_direct, ///< fast indexing of uint8s
        QT_6bit,        ///< 6 bits per component
        QT_bf16,
        QT_8bit_direct_signed, ///< fast indexing of signed int8s ranging from
                               ///< [-128 to 127]
    };

QuantizerType 枚举内容解析

枚举值描述适用场景
QT_8bit每个分量用 8-bit 表示,分量间有独立范围(非均匀量化)。精度较高,适用于一般场景。
QT_4bit每个分量用 4-bit 表示,分量间有独立范围(非均匀量化)。精度较低但压缩率高,适用于存储敏感的大规模数据。
QT_8bit_uniform每个分量用 8-bit 表示,但所有分量共享一个范围(均匀量化)。适合数据分布较均匀的场景,计算简单,存储开销低。
QT_4bit_uniform每个分量用 4-bit 表示,所有分量共享一个范围(均匀量化)。存储极为节省,但精度较低,适合快速粗筛。
QT_fp16每个分量用 16-bit 浮点数(half-precision floating-point) 表示。高精度场景,兼顾压缩率和精度,但存储开销大于 8-bit 和 4-bit。
QT_8bit_direct直接使用无符号 uint8 值,无需量化,直接存储整数。数据本身就是无符号 8-bit 整数,例如像素值。
QT_6bit每个分量用 6-bit 表示,比 8-bit 压缩但比 4-bit 精度更高。存储与精度需求的折中选择,适合数据分布复杂但存储受限的场景。
QT_bf16使用 bfloat16(16-bit 浮点数) 表示,范围与精度介于 fp16 和 uint8 之间。适用于高动态范围但对精度要求不高的场景,例如部分深度学习推理中。
QT_8bit_direct_signed直接使用有符号 int8 值,无需量化,范围为 [-128, 127]。数据本身就是有符号的整数,例如分量可能包含负值的数据。

为什么设置为这些类型?

支持不同的精度需求

  • 数据量化需要在存储成本与精度之间找到平衡,不同类型的量化方式提供了不同的压缩率和精度。
    • 8-bit 和 4-bit:提供了常见的两种量化方式,适合大多数存储需求。
    • 6-bit:为 8-bit 和 4-bit 之间提供折中选择。
    • fp16 和 bf16:用于精度较高的场景,支持浮点值表示。
      适配多种数据分布
    • 有些数据在不同维度的范围不同(非均匀分布),需要 非均匀量化(如 QT_8bit 和 QT_4bit)。
    • 对于范围均匀的数据,可以使用 均匀量化(如 QT_8bit_uniform 和 QT_4bit_uniform),进一步减少存储复杂性。
      提高存储效率
    • QT_4bit 和 QT_4bit_uniform:只用 4-bit 表示每个分量,适合存储空间受限的大规模向量。
    • QT_6bit:比 8-bit 进一步压缩,但比 4-bit 提供更高精度。
      兼容特殊数据类型
    • QT_8bit_direct 和 QT_8bit_direct_signed:直接存储原始值,不需要量化过程,适合已经量化好的数据或离散整数值。

RangeStat

代码内容如下:

enum RangeStat {
  RS_minmax,    ///< [min - rs*(max-min), max + rs*(max-min)]
  RS_meanstd,   ///< [mean - std * rs, mean + std * rs]
  RS_quantiles, ///< [Q(rs), Q(1-rs)]
  RS_optim,     ///< alternate optimization of reconstruction error
};
枚举值描述用途与适用场景
RS_minmax根据向量的最小值和最大值生成范围:
[ min - rs * (max - min), max + rs * (max - min) ]。
- 数据分布均匀的场景
- 快速计算,但对异常值较为敏感。
RS_meanstd根据向量的均值和标准差生成范围:
[ mean - std * rs, mean + std * rs ]。
- 考虑数据集中趋势和离散程度
- 对异常值有一定鲁棒性,适合常规分布数据。
RS_quantiles根据分位数生成范围:
[ Q(rs), Q(1-rs) ]。
- 非参数方法,适合分布不均的数据
- 有效减少异常值的影响(如 5%-95% 分位数)。
RS_optim基于优化误差动态调整范围。- 重建精度要求高的场景
- 适合对误差敏感的任务,可能效率较低,但结果更精确。
  • rs 是范围缩放因子(Range Scale)。rs是用户定义的参数,根据具体场景和需求选择合格的值。
  • mean:向量分量的平均值(Mean)。
  • std:向量分量的标准差。
参数含义作用
rs范围缩放因子,控制范围的扩展程度。调节范围宽度。rs 越大,范围越宽,适合分布离散的情况;rs 越小,范围越窄,适合分布集中的情况。
mean均值,表示数据的集中趋势。用作范围的中心值,减少偏移误差。
std标准差,表示数据的离散程度。用于计算范围宽度,离散程度越大,生成的范围越宽。

ScalarQuantizer

用于将浮点转化为相对应的量化结果
头文件C语言内容:

#ifndef SCALAR_QUANTIZER_H
#define SCALAR_QUANTIZER_H

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>

/* Quantizer 类型枚举 */
typedef enum {
    QT_8bit,            /* 8 位每分量 */
    QT_4bit,            /* 4 位每分量 */
    QT_8bit_uniform,    /* 每分量 8 位,范围统一 */
    QT_4bit_uniform,    /* 每分量 4 位,范围统一 */
    QT_fp16,            /* 16 位浮点 */
    QT_bf16,            /* 16 位 bfloat */
    QT_8bit_direct,     /* 无需额外计算的直接 8 位编码 */
    QT_8bit_direct_signed, /* 带符号的直接 8 位编码 */
    QT_6bit             /* 每分量 6 位 */
} QuantizerType;

/* ScalarQuantizer 数据结构 */
typedef struct {
    size_t d;            /* 输入向量的维度 */
    size_t code_size;    /* 每个向量的编码大小(字节数) */
    int bits;            /* 每分量的位数 */
    QuantizerType qtype; /* 量化器类型 */
} ScalarQuantizer;

/* 初始化函数 */
void ScalarQuantizer_init(ScalarQuantizer* sq, size_t d, QuantizerType qtype);

/* 设置 derived sizes */
void ScalarQuantizer_set_derived_sizes(ScalarQuantizer* sq);

/* 打印量化器信息 */
void ScalarQuantizer_print_info(const ScalarQuantizer* sq);

#endif /* SCALAR_QUANTIZER_H */

实现部分:

#include "ScalarQuantizer.h"

/* 初始化 ScalarQuantizer */
void ScalarQuantizer_init(ScalarQuantizer* sq, size_t d, QuantizerType qtype) {
    if (sq == NULL) {
        fprintf(stderr, "Error: ScalarQuantizer pointer is NULL.\n");
        exit(EXIT_FAILURE);
    }

    sq->d = d;
    sq->qtype = qtype;
    sq->code_size = 0; /* 初始化为 0 */
    sq->bits = 0;      /* 初始化为 0 */

    ScalarQuantizer_set_derived_sizes(sq);
}

/* 设置 code_size 和 bits */
void ScalarQuantizer_set_derived_sizes(ScalarQuantizer* sq) {
    if (sq == NULL) {
        fprintf(stderr, "Error: ScalarQuantizer pointer is NULL.\n");
        exit(EXIT_FAILURE);
    }

    switch (sq->qtype) {
        case QT_8bit:
        case QT_8bit_uniform:
        case QT_8bit_direct:
        case QT_8bit_direct_signed:
            sq->code_size = sq->d;
            sq->bits = 8;
            break;

        case QT_4bit:
        case QT_4bit_uniform:
            sq->code_size = (sq->d + 1) / 2; /* 每两个分量占 1 字节 */
            sq->bits = 4;
            break;

        case QT_6bit:
            sq->code_size = (sq->d * 6 + 7) / 8; /* 位对齐处理 */
            sq->bits = 6;
            break;

        case QT_fp16:
        case QT_bf16:
            sq->code_size = sq->d * 2; /* 每分量 2 字节 */
            sq->bits = 16;
            break;

        default:
            fprintf(stderr, "Error: Unsupported QuantizerType.\n");
            exit(EXIT_FAILURE);
    }

    /* 检查初始化结果 */
    if (sq->code_size == 0 || sq->bits == 0) {
        fprintf(stderr, "Error: Invalid derived sizes.\n");
        exit(EXIT_FAILURE);
    }
}

/* 打印 ScalarQuantizer 信息 */
void ScalarQuantizer_print_info(const ScalarQuantizer* sq) {
    if (sq == NULL) {
        fprintf(stderr, "Error: ScalarQuantizer pointer is NULL.\n");
        return;
    }

    printf("ScalarQuantizer Info:\n");
    printf("  Dimension (d): %zu\n", sq->d);
    printf("  Code Size (bytes per vector): %zu\n", sq->code_size);
    printf("  Bits per component: %d\n", sq->bits);
    printf("  Quantizer Type: %d\n", sq->qtype);
}

关键点说明

  1. ScalarQuantizer 数据结构:
  • 包含 d(维度)、code_size(编码大小)、bits(每分量位数)和 qtype(量化器类型)。
  1. 初始化函数:
  • ScalarQuantizer_init 接收维度和量化器类型作为输入,并调用 ScalarQuantizer_set_derived_sizes 设置计算派生值。
  1. 错误处理:
  • 添加了对空指针的检查以及对非法 QuantizerType 的错误处理。
  1. 位对齐逻辑:
  • 在 QT_6bit 和 QT_4bit 的计算中,处理了位对齐问题。

train、compute_codes和decode

上述函数的作用参考上一篇链接:标量量化
这里我定义的ScalarQuantizer结构体内容如下:

typedef struct ScalarQuantizer{
    size_t d;              /* 输入向量的维度 */
    size_t code_size;      /* 每个向量的编码大小(字节数) */
    int bits;              /* 每分量的位数 */
    QuantizerType qtype;   /* 量化器类型 */

    /* 训练参数 */
    float rangestat;       /* 范围统计参数 */
    float rangestat_arg;   /* 范围统计扩展参数 */
    int n_centroids;       /* 聚类中心数量 */

    /* 训练结果 */
    float* trained;        /* 存储训练的量化中心 */
    float* centroids;      /* 非均匀量化时的聚类中心 */

    /* 并行和状态跟踪 */
    int num_threads;       /* 并行计算的线程数量 */
    size_t progress;       /* 当前训练/编码的进度 */

    /* 高级量化支持 */
    int* per_dim_bits;     /* 每维分量的位数,支持动态调整 */
    float* residuals;      /* 残差向量存储 */
    int residual_levels;   /* 残差量化层数 */
} ScalarQuantizer;

然后将训练train分为3个部分:ScalarQuantizer_train(QT类4、6、8位)、train_Uniform(均匀量化训练 )和train_NonUniform(非均匀量化训练),其中train_Uniform和train_NonUniform在代码中呈现。

void ScalarQuantizer::train(size_t n, const float* x) {
    int bit_per_dim = qtype == QT_4bit_uniform ? 4
            : qtype == QT_4bit                 ? 4
            : qtype == QT_6bit                 ? 6
            : qtype == QT_8bit_uniform         ? 8
            : qtype == QT_8bit                 ? 8
                                               : -1;

    switch (qtype) {
        case QT_4bit_uniform:
        case QT_8bit_uniform:
            train_Uniform(
                    rangestat,
                    rangestat_arg,
                    n * d,
                    1 << bit_per_dim,
                    x,
                    trained);
            break;
        case QT_4bit:
        case QT_8bit:
        case QT_6bit:
            train_NonUniform(
                    rangestat,
                    rangestat_arg,
                    n,
                    d,
                    1 << bit_per_dim,
                    x,
                    trained);
            break;
        case QT_fp16:
        case QT_8bit_direct:
        case QT_bf16:
        case QT_8bit_direct_signed:
            // no training necessary
            break;
    }
}

http://www.kler.cn/a/414512.html

相关文章:

  • 『Linux学习笔记』linux系统有哪些方法计算文件的md5!
  • FPGA实现GTP光口视频转USB3.0传输,基于FT601+Aurora 8b/10b编解码架构,提供3套工程源码和技术支持
  • STL算法之基本算法<stl_algobase.h>
  • Java文件遍历那些事
  • day29|leetcode 134. 加油站 , 135. 分发糖果 ,860.柠檬水找零 , 406.根据身高重建队列
  • JAVA实现上传附件到服务器
  • 淘宝关键词挖掘:Python爬虫技术在电商领域的应用
  • 虚拟现实(VR)与增强现实(AR)有什么区别?
  • 【k8s深入理解之 Scheme 补充-6】理解资源外部版本之间的优先级
  • TypeScript中function和const定义函数的区别
  • java 排序 详解
  • 【Unity基础】初识Unity中的渲染管线
  • 中科亿海微SoM模组——波控处理软硬一体解决方案
  • HarmonyOS 5.0应用开发——装饰器的使用
  • NAT:连接私有与公共网络的关键技术(4/10)
  • NLP任务四大范式的进阶历程:从传统TF-IDF到Prompt-Tuning(提示词微调)
  • C++《二叉搜索树》
  • Vue3.0性能提升主要是通过哪几方面体现的?通过编译阶段、源码体积、响应式系统等进行讲解!
  • 什么是串联谐振
  • 【动态规划入门】【1.2打家劫舍问题】【从记忆化搜索到递推】【灵神题单】【刷题笔记】
  • 【新人系列】Python 入门(十四):文件操作
  • 【微服务】消息队列与微服务之微服务详解
  • 报错:java: 无法访问org.springframework.boot.SpringApplication
  • R 因子
  • 深度学习day4-模型
  • Java知识及热点面试题总结(三)