当前位置: 首页 > article >正文

FFM 因子分解机原理与特征域概念解析

实验和完整代码

完整代码实现和jupyter运行:https://github.com/Myolive-Lin/RecSys--deep-learning-recommendation-system/tree/main

引言

因子分解机(Field-aware Factorization Machine,FFM)是一种广泛应用于推荐系统、CTR 预估等任务的机器学习模型。相比于传统的因子分解机(Factorization Machine,FM),FFM 通过引入特征域(Field) 的概念,能够更好地建模特征交叉,提高模型的预测性能。

1. FFM 的基本原理

因子分解机(FM)通过低维向量来表示每个特征,并计算特征之间的二阶交叉项,公式如下:

y ^ = w 0 + ∑ i = 1 n w i x i + ∑ i = 1 n ∑ j = i + 1 n ⟨ v i , v j ⟩ x i x j \hat{y} = w_0 + \sum_{i=1}^{n}w_ix_i + \sum_{i=1}^{n}\sum_{j=i+1}^{n} \langle v_i, v_j \rangle x_i x_j y^=w0+i=1nwixi+i=1nj=i+1nvi,vjxixj

其中,

  • w 0 w_0 w0 为偏置项,
  • w i w_i wi 为特征 x i x_i xi 的权重,
  • v i v_i vi 为特征$x_i $ 的嵌入向量,
  • ⟨ v i , v j ⟩ \langle v_i, v_j \rangle vi,vj表示向量 v i v_i vi v j v_j vj的内积。

然而,FM 仅为每个特征学习一个固定的向量 v i v_i vi,无法根据不同特征域的组合动态调整交叉信息。因此,FFM 在此基础上进一步细化了特征交叉的建模方式。

在 FFM 中,每个特征 $x_i $ 依据其所属的特征域(Field)来学习不同的嵌入向量。FFM 的预测公式如下:

y ^ = w 0 + ∑ i = 1 n w i x i + ∑ i = 1 n ∑ j = i + 1 n ⟨ v i , f j , v j , f i ⟩ x i x j \hat{y} = w_0 + \sum_{i=1}^{n}w_ix_i + \sum_{i=1}^{n}\sum_{j=i+1}^{n} \langle v_{i,f_j}, v_{j,f_i} \rangle x_i x_j y^=w0+i=1nwixi+i=1nj=i+1nvi,fj,vj,fixixj

与 FM 主要的区别在于:

  • v i , f j v_{i,f_j} vi,fj 代表特征 x i x_i xi 在特征域 f j f_j fj 上的嵌入向量,而非全局共享的 v i v_i vi
  • 每个特征在不同特征域上有不同的向量,从而可以更灵活地建模特征交互。

2. 特征域(Field)的概念

在 FFM 中,特征域(Field) 指的是特征的类别或属性集合。最典型的例子是由One-hot编码的特征之后也属于一个特征域

在实际应用中,许多分类特征通常会被转换为 One-hot 编码。例如,假设我们有一个“职业”特征,其可能的取值包括医生工程师教师,那么一名医生经过 One-hot 编码后可能会变为特征向量[1,0,0],这样一段特征向量也就是一个职业的特征域。如下所示:

职业_医生职业_工程师职业_教师
100

此外,在 CTR 预估任务中,数据集可能包含以下特征:

  • 用户 ID
  • 用户年龄
  • 设备类型
  • 广告 ID
  • 广告类别

在 FM 中,这些特征被视为独立个体,而在 FFM 中,它们会被分配到不同的特征域。例如:

  • 用户 ID、用户年龄属于“用户特征”域;
  • 设备类型属于“设备”域;
  • 广告 ID、广告类别属于“广告特征”域。

FFM 的核心思想是:

  • 对于每个特征 x i x_i xi,当它与另一特征 x j x_j xj 交互时,会使用该特征在 f j f_j fj 领域上的嵌入向量 ( v i , f j ) ( v_{i,f_j} ) (vi,fj)
  • 这样可以保证不同领域的特征交叉时,使用更加合适的嵌入表示,提高模型效果。

3. FFM举例

UserMovieGenrePrice
YuChin3IdiotsComedy, Drama (Co, Dr)$9.99

这个样本有4个特征,其中User,Moive,Genre是类别型特征,Price是数值型特征,经过One-hot处理后有

Field nameField indexFeature nameFeature index
Userfield 1User-YuChinfeature 1
Moviefield 2Movie-3Idiotsfeature 2
Genrefield 3Genre-Comedyfeature 3
pricefield 4Genre-Dramafeature 4
Pricefeature 5

对于二阶交叉项部分,对应会有特征数量5 * 特征域4 = 20 * 隐向量维度 = 20* k参数,其中蓝色是特征index,红色对应的特征域index

在这里插入图片描述

4.代码实现

class FFM(nn.Module):
    def __init__(self, num_features, num_fields, embedding_dim):
        """
        FFM 模型初始化
        :param num_features: 特征总数(特征编码后的维度)
        :param num_fields: 域(Field)的总数
        :param embedding_dim: 隐向量维度
        """
        super(FFM, self).__init__()
        self.num_features = num_features
        self.num_fields = num_fields
        self.embedding_dim = embedding_dim

        # 初始化参数
        self.w0 =nn.Parameter(torch.zeros(1))
        self.w = nn.Parameter(torch.randn(num_features))

        # Field-aware 隐向量部分
        # 每个特征对每个其他域维护一个隐向量
        # 形状: (num_features, num_fields, embedding_dim)
        self.embeddings = nn.Parameter(
            torch.randn(num_features,num_fields,embedding_dim)
        )

    def forward(self,X,filed_map):
        """
        :param x: 输入张量,形状 (batch_size, num_features),稀疏 one-hot 编码
        :param field_map: 每个特征所属的域编号,形状 (num_features,)
        """
        # -------------------- 线性部分计算 --------------------
        # 线性项: w0 + sum(wi * xi)
        linear_terms = self.w0 + torch.sum(self.w * X, dim = 1) # (batch_size,)

        # -------------------- 交叉部分计算 --------------------
        batch_size = X.shape[0]
        cross_terms = torch.zeros(batch_size,device = X.device)# 存储交叉项的结果

        for i in range(self.num_features):
            for j in range(i+1,self.num_features):
                # 只计算非零元素的交叉项

                xi = X[:,i] # (batch_size,)
                xj = X[:,j] # (batch_size,)
                non_zero = (xi != 0) & (xj != 0) #仅同时处理非0特征对 逻辑与(&)操作会逐元素执行,返回一个布尔数组,表示 xi 和 xj 在相应位置上同时非零的位置。


                if non_zero.any(): #至少存一个非0特征对
                    fi = filed_map[i]
                    fj = filed_map[j]

                    #提取对应的隐向量
                    vi = self.embeddings[i,fi,:]
                    vj = self.embeddings[j,fj,:]

                    #计算点积并加权
                    interaction = torch.sum(vi * vj)# 标量
                    cross = interaction * xi[non_zero] * xj[non_zero]
                    cross_terms[non_zero] += cross
        # -------------------- 输出 --------------------
        output = linear_terms + cross_terms
        return torch.sigmoid(output) # 适用于二分类(如CTR预测)
    
    def predict(self,X,filed_map):
        """
        预测函数
        :param x: 输入张量,形状 (batch_size, num_features),稀疏 one-hot 编码
        :param field_map: 每个特征所属的域编号,形状 (num_features,)
        """
        # 计算模型输出
        output = self.forward(X,filed_map)
        
        return output
    
        

4. FFM 相比 FM 的优势

FFM 由于采用了针对不同特征域的独立嵌入向量,在实际应用中具有以下优势:

  • 更细粒度的特征交互:不同特征域的交互关系能够被更精准地捕捉。
  • 更高的预测精度:尤其在 CTR 预估等高维稀疏数据场景下,FFM 的表现通常优于 FM。
  • 适用于广告推荐场景:FM 可能会遗漏某些特征组合的影响,而 FFM 能够更灵活地学习不同特征域的组合模式。

5. FFM 的计算复杂度

FFM 虽然增强了特征交互的表达能力,但也带来了计算和存储上的额外开销。

  • 存储方面,FM 的嵌入参数数量为 O ( n k ) O(nk) O(nk),而 FFM 由于为每个特征在不同特征域上都学习了独立的向量,参数规模变为 O ( n ⋅ m ⋅ k ) O(n\cdot m\cdot k) O(nmk)(其中 m为特征域数量)。
  • 另外在训练方面,由于不能像FM一样简化,训练复杂度为 O ( k n 2 ) O(kn^2) O(kn2)

在训练过程中,FFM 采用随机梯度下降(SGD)分布式优化方法来减少计算成本。此外,也可以采用一些降维策略,如利用哈希技巧来减少嵌入向量的维度。

6. 结论

因子分解机(FFM)通过引入特征域的概念,相较于 FM 提供了更精细的特征交叉建模能力,在 CTR 预估等任务中取得了显著的效果。然而,FFM 也带来了更高的计算和存储开销,因此在实际应用中需要权衡性能和效率。

Reference

  1. 王喆 《深度学习推荐系统》
  2. (二)FFM(Field-aware Factorization Machine)原理
  3. Field-aware Factorization Machines

http://www.kler.cn/a/531816.html

相关文章:

  • 读书笔记--分布式架构的异步化和缓存技术原理及应用场景
  • 【自学笔记】GitHub的重点知识点-持续更新
  • 【LLM-agent】(task2)用llama-index搭建AI Agent
  • Vue 3 30天精进之旅:Day 13 - 路由守卫
  • 深入理解 `box-sizing: border-box;`:CSS 布局的利器
  • 【零基础学JAVA】数据类型
  • 追逐低空经济,无人机研学技术详解
  • 【双指针题目】
  • Vue3学习笔记-Vue开发前准备-1
  • Rust场景示例:为什么要使用切片类型
  • Deep Sleep 96小时:一场没有硝烟的科技保卫战
  • 即梦(Dreamina)技术浅析(三):数据库与存储
  • 手写单例模式
  • Java循环操作哪个快
  • bootstrap.yml文件未自动加载问题解决方案
  • 【回溯+剪枝】优美的排列 N皇后(含剪枝优化)
  • 【游戏设计原理】98 - 时间膨胀
  • SpringBoot 引⼊MybatisGenerator
  • 【C++ STL】vector容器详解:从入门到精通
  • IBM Cognos Analytics配置LTPA SSO单点登录
  • 【02】智能合约与虚拟机
  • Node 服务器数据响应类型处理
  • SLAM技术栈 ——《视觉SLAM十四讲》学习笔记(一)
  • c++ stl 遍历算法和查找算法
  • BMS和无刷电机产品拆解学习
  • TryHackMe: TryPwnMe Two