当前位置: 首页 > article >正文

Wide Deep 模型:记忆能力与泛化能力

实验和完整代码

完整代码实现和jupyter运行:https://github.com/Myolive-Lin/RecSys--deep-learning-recommendation-system/tree/main

引言

Wide & Deep 模型是一种结合了线性模型(Wide)和深度神经网络(Deep)的混合架构,以结合记忆(Memorization)泛化(Generalization) 能力,并有效解决了推荐系统中的稀疏高秩特征交互问题。该模型最初由 Google的Cheng 等人(2016) 提出,广泛应用于推荐系统、广告点击率预测等领域

1. 问题背景

1.1 推荐系统的核心挑战

推荐系统的核心任务是预测用户对物品的交互概率:
P ( y = 1 ∣ x ) = σ ( f ( x ) ) P(y=1|\mathbf{x}) = \sigma(f(\mathbf{x})) P(y=1x)=σ(f(x))
其中 x \mathbf{x} x 包含用户特征、上下文特征和物品特征。关键挑战在于:

  • 记忆:捕获历史数据中频繁共现的特征组合
  • 泛化:探索稀疏甚至未见过的特征与目标相关性的能力

这里介绍一下这两种能力:

1. 记忆能力:记住“强规则”的能力

  • 什么是记忆能力
    模型像学生背公式一样,能直接记住历史数据中频繁出现的“特征组合”与结果的关系。例如:

    • 用户安装了 Netflix(特征A) + 看到过 Pandora(特征B) → 安装Pandora的概率极高(比如10%,而平均安装率仅1%)。
      这种强关联性会被模型直接捕捉,形成类似“看到A就推荐B”的规则。
  • 哪些模型擅长记忆能力
    简单模型(如逻辑回归、协同过滤)是记忆能力的“尖子生”。
    原因:模型结构简单,特征权重直接决定结果。例如,逻辑回归遇到“Netflix & Pandora”组合时,只需给这个特征分配一个很大的权重,就能记住这条规则。

  • 实际场景案例
    在Google Play推荐系统中,如果某个用户安装了视频类应用(如Netflix),同时历史数据显示这类用户安装音乐应用(如Pandora)的概率很高,模型会直接记住这种关联,优先推荐音乐类应用。

2.泛化能力:指模型在面对从未见过的稀疏特征时,能够捕捉到这些特征与最终标签之间的潜在关联。

模型能通过特征之间的隐含联系,推广到从未见过的场景。例如:

  • 例如,矩阵分解方法(如隐因子模型)比传统的协同过滤算法具有更强的泛化能力,因为它通过学习用户和物品的隐向量,使得稀疏数据(如用户与特定物品之间的互动较少)也能通过这些隐向量得到合适的推荐得分。在这种情况下,即使是稀疏的特征组合,模型也能通过隐向量进行预测,从而获得稳定的推荐。

  • 深度神经网络的泛化能力更为强大,因为它通过多个隐藏层的非线性变换,可以发掘数据中更为复杂和深层的潜在模式。即使是非常稀疏的特征向量输入,经过多层组合后,网络也能够平滑地输出推荐概率,这种能力是简单模型无法实现的。

1.2 传统方法的局限

方法优势缺陷
线性模型 (Wide)显式特征交叉,可解释性强依赖特征工程,无法泛化
深度模型 (Deep)自动学习特征交互,泛化能力强对稀疏高秩数据易过泛化

模型架构如下
在这里插入图片描述

2. Wide & Deep 模型架构

模型细节如下:
在这里插入图片描述

2.1 Wide 组件

广义线性模型
y w i d e = w w T [ x , ϕ ( x ) ] + b w y_{wide} = \mathbf{w}_w^T[\mathbf{x}, \phi(\mathbf{x})] + b_w ywide=wwT[x,ϕ(x)]+bw
其中:

  • x \mathbf{x} x:原始稀疏特征
  • ϕ ( x ) \phi(\mathbf{x}) ϕ(x):交叉特征变换,定义为:
    ϕ k ( x ) = ∏ i = 1 d x i c k i c k i ∈ { 0 , 1 } \phi_k(\mathbf{x}) = \prod_{i=1}^d x_i^{c_{ki}} \quad c_{ki} \in \{0,1\} ϕk(x)=i=1dxickicki{0,1}
  • c k i c_{ki} cki 为特征选择指示函数

数学意义
Wide模型本质上是一个线性模型,用于处理特征之间的交互。它通过引入交叉特征(例如“用户已安装Netflix”和“推荐应用为Pandora”)来捕捉频繁出现的特征组合。交叉特征可以看作是特征之间的乘积,表示多个特征同时为真时的情况。 例如,假设有两个特征:性别(gender)和语言(language),它们的交叉特征可以表示为“性别=女性 且 语言=英语”,当这两个条件同时满足时,交叉特征的值为1。 通过这种交叉变换,宽模型能够有效地记忆那些频繁出现的特征交互,确保在推荐系统中准确预测用户的常见行为。


2.2 Deep 组件

前馈神经网络

  1. 嵌入层:将稀疏特征映射为低维稠密向量
    e i = Embedding ( x i ) ∈ R m \mathbf{e}_i = \text{Embedding}(x_i) \in \mathbb{R}^m ei=Embedding(xi)Rm
  2. 隐层计算:
    a ( l + 1 ) = ReLU ( W ( l ) a ( l ) + b ( l ) ) \mathbf{a}^{(l+1)} = \text{ReLU}(\mathbf{W}^{(l)}\mathbf{a}^{(l)} + \mathbf{b}^{(l)}) a(l+1)=ReLU(W(l)a(l)+b(l))
  3. 最终输出:
    y d e e p = w d T a ( L ) + b d y_{deep} = \mathbf{w}_d^T \mathbf{a}^{(L)} + b_d ydeep=wdTa(L)+bd
    意义:深度模型通过将类别特征映射到低维的稠密嵌入空间来实现对稀疏数据的泛化。每个类别特征会被映射为一个稠密向量,这些向量在训练过程中被优化,以捕捉特征之间的隐藏关系。嵌入向量的维度通常是几百个维度,这比原始的稀疏特征空间要小得多。 在深度模型中,经过嵌入层后,特征被传递到多个隐藏层,每个隐藏层通过一个非线性激活函数(如ReLU)进行计算。深度模型的目标是通过这些隐藏层的变换来学习更复杂的特征交互,并提高对未见特征组合的预测能力。

2.3 联合训练

概率输出
P ( y = 1 ∣ x ) = σ ( y w i d e + y d e e p + b ) P(y=1|\mathbf{x}) = \sigma(y_{wide} + y_{deep} + b) P(y=1x)=σ(ywide+ydeep+b)

3. 代码实现

3.1 特征工程

特征类型处理方式
类别特征哈希分桶 + 嵌入层
连续特征分位数归一化
交叉特征笛卡尔积生成

3.2 Code

class WideAndDeep(nn.Module):
    def __init__(self, num_wide, num_deep_dim, cat_deep_dims, cross_feature_indices,  hidden_units, embedding_dim):
        """
        Args:
            num_wide: Wide部分的特征数量
            cross_feature_indices: 需要交叉的特征索引列表,格式 [(i,j), (k,l), ...]
            num_deep_dim: Deep部分的数值型特征数量
            cat_deep_dims: Deep部分的类别特征维度列表
            hidden_units: 深度网络隐藏层维度列表
            embedding_dim: 嵌入维度
        """
        super(WideAndDeep, self).__init__()

        #wide 部分
        self.cross_indices = cross_feature_indices
        self.wide = nn.Linear(num_wide + len(cross_feature_indices), 1)

        #deep 部分
        self.embeddings = nn.ModuleList([
            nn.Embedding(dim, embedding_dim) 
            for dim in cat_deep_dims
        ])

        # 计算Deep部分输入维度
        deep_input_dim = num_deep_dim + len(cat_deep_dims) * embedding_dim

        #Deep 部分全连接层
        self.dnn = nn.Sequential()

        for i,unit in enumerate(hidden_units):
            self.dnn.add_module(
                name = f"fc_{i}",
                module = nn.Linear(deep_input_dim, unit)
            )

            self.dnn.add_module(
                name = f"relu_{i}",
                module = nn.ReLU()
            )

            deep_input_dim = unit

        #最终组合层
        self.final = nn.Linear(deep_input_dim + 1, 1) #wide部分和deep部分的输出相加

    def create_cross_features(self, x):
        """动态生成交叉特征"""
        cross_features = []
        for i, j in self.cross_indices:
            feature = x[:, i] * x[:, j]
            cross_features.append(feature.unsqueeze(1))

        return torch.cat(cross_features, dim=1)

    
    def forward(self, num_x, deep_x ,cat_x,):
        """
        Args:
            num_x: wide部分数值型特征
            deep_x: deep网络数值型特征
            cat_x: deep网络类别型特征
        """
        num_x = num_x.float()       # 确保转为float32
        deep_x = deep_x.float()     # 确保转为float32

        cross_features = self.create_cross_features(num_x)
        wide_input = torch.cat([num_x, cross_features], dim=1)
        wide_output = self.wide(wide_input)

        #deep部分
        embeds = [] 
        for i in range(len(cat_x[0])):
            embed = self.embeddings[i](cat_x[:, i])
            embeds.append(embed)
        
        deep_input = torch.cat(embeds, dim=1)
        deep_input = torch.cat([deep_input, deep_x], dim=1) #数值型特征和类别型特征拼接
        deep_output = self.dnn(deep_input)

        #wide和deep部分输出相加
        output = torch.cat([wide_output, deep_output], dim=1)
        output = self.final(output) #[batch_size,1]
        return torch.sigmoid(output).squeeze()



Reference

  1. Cheng, H. T., Koc, L., Harmsen, J., Shaked, T., Chandra, T., Aradhye, H., … & Ispir, M. (2016). Wide & Deep Learning for Recommender Systems. arXiv preprint arXiv:1606.07792.
  2. 王喆《深度学习推荐系统》

http://www.kler.cn/a/532643.html

相关文章:

  • 10.8 LangChain Output Parsers终极指南:从JSON解析到流式处理的规范化输出实践
  • OpenGL学习笔记(六):Transformations 变换(变换矩阵、坐标系统、GLM库应用)
  • 【R语言】R语言安装包的相关操作
  • vscode软件操作界面UI布局@各个功能区域划分及其名称称呼
  • 基于人脸识别的课堂考勤系统
  • C语言:链表排序与插入的实现
  • NSSCTF Pwn [SWPUCTF 2022 新生赛]shellcode?题解
  • 网安学习xss和php反序列后的心得
  • minikube 的 Kubernetes 入门教程--Dify
  • [C++]C++中的常见异常和自定义异常
  • 半导体器件与物理篇6 MESFET
  • 解释 Java 中的垃圾回收机制,以及如何优化垃圾回收性能?
  • directx12 3d开发过程中出现的报错 一
  • Python 与 PostgreSQL 集成:深入 psycopg2 的应用与实践
  • 排序算法--计数排序
  • 【NLP 20、Encoding编码 和 Embedding嵌入】
  • 文字加持:让 OpenCV 轻松在图像中插上文字
  • 逻辑运算短路现象记录
  • PostCss
  • 关于deepseek的一些普遍误读
  • Vant框架:助力移动端开发的利器
  • SpringBoot 连接Elasticsearch带账号密码认证 ES连接 加密连接
  • 7.2.背包DP
  • 获取 ARM Cortex - M 系列处理器中 PRIMASK 寄存器的值
  • Azure DevOps Server:集成奇安信开源卫士(OpenSourceSafe)
  • 16 旋转操作模块(rotation.rs)