当前位置: 首页 > article >正文

推荐系统(十六):基于ESMM的商品召回/推荐系统

《Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate》是阿里妈妈广告算法团队于 2018 年发表的一篇论文,论文提出了ESMM 模型(Entire Space Multi-Task Model),基于用户行为序列和多任务学习思想有效解决广告转化率(Conversion Rate,CVR)预估的样本选择偏差(Sample Selection Bias,SSB)和数据稀疏(Data Sparsity,DS)问题,并在淘宝推荐广告的公开和线上数据集上取得了当时最好的效果。本文笔者将基于 ESMM 思想打造一个商品召回/推荐系统 DEMO,麻雀虽小,五脏俱全。

1.系统架构

和上一篇文章《推荐系统(十五):基于双塔模型的多目标商品召回/推荐系统》不同,本文模型基于 ESMM(Entire Space Multi-Task Model)思想,如下图所示,用户塔和物品塔不共用,主要原因在于多任务学习的特性、任务目标差异及模型设计目标相关。以下是具体分析:

1 任务目标差异

  • CTR(点击率预估):关注用户对物品的点击行为,目标是在全曝光样本空间中预测用户点击概率。
  • CVR(转化率预估):关注用户点击后的转化行为(如购买、注册等),目标是在点击样本空间中预测用户转化概率。
  • CTCVR(点击且转化率预估):通过公式 pCTCVR = pCTR*pCVR 间接建模,需同时考虑点击和转化的联合概率。

CTR和CVR的任务目标不同,导致两者对用户和物品特征的关注点存在差异。例如:CTR任务可能更关注物品的吸引力特征(如标题、图片);CVR任务则需关注物品的转化相关特征(如价格、详情页信息)。

2. 样本空间的差异

  • CTR任务:训练数据为全曝光样本(包含点击和未点击样本)。
  • CVR任务:训练数据仅包含点击样本(即点击后的转化行为)。

若共用物品塔/用户塔,模型需在同一塔内同时适配全曝光样本和点击样本的特征分布,可能导致特征表示被稀释或干扰,影响模型性能。

3. 多任务学习的独立性要求

  • 跷跷板现象(Seesaw Phenomenon):若任务间相关性较弱,共享过多参数可能导致一个任务性能提升以牺牲另一个任务为代价(如 MMoE 中提到的现象)。
    ESMM的设计:通过独立塔结构保留任务特异性,避免负迁移(Negative Transfer)。共享底层 Embedding层(如用户ID、物品ID的Embedding)以缓解数据稀疏性,但上层网络独立,确保任务间差异被有效捕捉。

4. 特征交互的差异性

  • CTR 任务的特征交互可能更偏向于用户兴趣与物品曝光的匹配(如用户历史点击与物品标题的关联)。
  • CVR 任务的特征交互需捕捉用户决策行为与物品转化属性的关联(如用户购买力与物品价格的匹配)。

独立塔结构允许模型在不同任务中学习不同的高阶特征交互模式,提升任务适配性。

5. 模型复杂度与灵活性

共用塔结构会限制模型对不同任务的表达能力,而独立塔可通过调整网络深度、宽度等参数灵活适配任务需求。ESMM通过共享Embedding层减少参数量,同时通过独立塔保持多任务建模的灵活性,平衡模型复杂度和效果。

在这里插入图片描述

2.核心实现步骤

2.1 模拟数据构造

"""
Part-1:模拟数据构造

本部分模拟真实场景,人工构造用户数据、商品数据、用户-商品交互数据(点击、购买),并进行必要的预处
"""
# 计算机生成的随机数本质是伪随机数,由算法基于初始种子值(seed)生成固定序列。设置相同的种子会得到相同的随机数序列
# 在机器学习中,随机性会影响模型训练、数据划分(如训练集/测试集分割)、参数初始化等环节。设置种子后,多次运行代码会得到相同结果,便于调试和验证
np.random.seed(42)
tf.random.set_seed(42)
# 用户特征:用户ID、年龄、性别、职业
# 商品特征:商品ID、类别、品牌、价格
num_users = 100
num_items = 200
num_interactions = 1000

# 生成用户特征
user_data = {
   
    'user_id': np.arange(1, num_users + 1),
    'user_age': np.random.randint(18, 65, size=num_users),
    'user_gender': np.random.choice(['male', 'female'], size=num_users),
    'user_occupation': np.random.choice(['student', 'worker', 'teacher'], size=num_users),
    'city_code': np.random.randint(1, 2856, size=num_users),  # 城市编码,中国有 2856 个城市
    'device_type': np.random.randint(0, 5, size=num_users)  # 设备类型(0=Android,1=iOS等)
}

# 生成商品特征
item_data = {
   
    'item_id': np.arange(1, num_items + 1),
    'item_category': np.random.choice(['electronics', 'books', 'clothing'], size=num_items),
    'item_brand': np.random.choice(['brandA', 'brandB', 'brandC'], size=num_items),
    'item_price': np.random.randint(1, 199, size=num_items)
}

# 生成用户-商品交互数据
# 包括:点击和转化(购买)数据
interactions = []
for _ in range(num_interactions):
    user_id = np.random.randint(1, num_users + 1)
    item_id = np.random.randint(1, num_items + 1)
    # 点击标签。0: 未点击, 1: 点击。在真实场景中可通过客户端埋点上报获得用户的点击行为数据
    click_label = np.random.randint(0, 2)
    # 转化标签。由于转化的前提是点击,因此点击和转化之间是一个漏斗关系——转化显著低于点击
    conversion_label = 0
    if click_label == 1:
        conversion_label = np.random.binomial(1, 0.3)  # 假设点击后30%转化率
    interactions.append([user_id, item_id, click_label, conversion_label])

interaction_df = pd.DataFrame(
    interactions,
    columns=['user_id', 'item_id', 'click_label', 'conversion_label'])

# 合并用户特征、商品特征和交互数据
user_df = pd.DataFrame(user_data)
item_df = pd.DataFrame(item_data)
df = interaction_df.merge(user_df, on='user_id').merge(item_df, on='item_id')
df['ctcvr_label'] = df['click_label'] * df['conversion_label']  # 新增CTCVR标签

# 划分训练集和测试集
labels = df[['click_label', 'conversion_label', 'ctcvr_label']]
features = df.drop(['click_label', 'conversion_label', 'ctcvr_label'], axis=1)
train_features, test_features, train_labels, test_labels = train_test_split(
    features, labels,
    test_size=0.2,
    random_state=42
)

2.2 特征工程

对不同类型的特征进行差异化处理:
在这里插入图片描述

"""
Part-2:特征工程

本部分对原始用户数据、商品数据、用户-商品交互数据进行分类处理,加工为模型训练需要的特征
    1.数值型特征:如用户年龄、价格,少数场景下可直接使用,但最好进行标准化,从而消除量纲差异
    2.类别型特征:需要进行 Embedding 处理
    3.ID类特征:如用户ID、商品ID,属于高维稀疏特征,需要embedding处理为低维稠密特征

关于 Embedding 处理:
    1.无论是通过tf.keras.layers.Embedding还是feature_column.embedding_column,Embedding层的初始值通常是随机生成的(例如均匀分布或截断正态分布)
    2.在模型训练过程中,Embedding向量会通过反向传播不断更新,使得模型的预测结果(如用户-物品相似度)与目标(如点击标签)更接近
    3.训练后的Embedding向量会收敛到某种有意义的表示,与初始化的随机值完全不同

关于标准化处理:
    1.如果使用 feature_column 的 normalizer_fn:模型自动处理,无需手动干预
    2.如果手动标准化:必须保存训练阶段的参数(均值和标准差),并在预测时加载这些参数进行标准化
"""
""" 
用户特征预处理 
"""
# 高维稀疏特征处理
# 过程:先将用户ID定义为类别型特征,num_buckets=num_users 表示用户ID的取值范围是 [0, num_users-1] 的整数;然后,embedding处理
# 注意:在模拟数据中用户和商品数量较少(100用户/200商品),直接使用 ID embedding 容易导致尾部 ID 无法充分训练
# 双塔模型通常需要权衡记忆(ID特征)与泛化(属性特征)能力
user_id = feature_column.categorical_column_with_identity('user_id', num_buckets=num_users)
user_id_emb = feature_column.embedding_column(user_id, dimension=8)
# 数值特征处理
# StandardScaler 是 Scikit-Learn 提供的标准化工具,它会将数据转换为均值为 0、标准差为 1 的分布。
# 标准化(或采用归一化)可以消除不同特征间的量纲差异(例如年龄范围是 0-100,价格范围是 0-10000),使模型训练更稳定
scaler_age = StandardScaler()
df['user_age'] = scaler_age.fit_transform(df[['user_age']])
user_age = feature_column.numeric_column('user_age')

# 类别特征处理
# 先映射,后嵌入,生成低维稠密向量
# 将性别字符串(如“male”“female”)映射为整数ID,输入数据中的性别字符串会被转换为 0(male)或 1(female),然后进行嵌入转换,生成低维稠密向量
user_gender = feature_column.categorical_column_with_vocabulary_list(
    'user_gender', ['male', 'female'])
user_gender_emb = feature_column.embedding_column(user_gender, dimension=2)
# 将职业字符串映射为整数ID(如“student”→0,“worker”→1,依此类推),然后进行嵌入转换,生成低维稠密向量
user_occupation = feature_column.categorical_column_with_vocabulary_list(
    'user_occupation', ['student', 'worker', 'teacher'])
user_occupation_emb = feature_column.embedding_column(user_occupation, dimension=2)
# 用户所在城市编码embedding
# 城市ID的可能取值范围(1到2855,共2855个值,需设置为max_id + 1)
city_code_column = feature_column.categorical_column_with_identity(key='city_code', num_buckets=2856)
city_code_emb = feature_column.embedding_column(city_code_column, dimension=8)
# 用户设备编码embedding
device_types_column = feature_column.categorical_column_with_identity(key='device_type', num_buckets=5)
device_types_emb = feature_column.embedding_column(device_types_column, dimension=8)

""" 
商品特征预处理 
"""
# 高维稀疏特征处理
# 与 user_id 类似,商品ID被定义为 [0, num_items-1] 的整数类别
item_id = feature_column.categorical_column_with_identity('item_id', num_buckets=num_items)
item_id_emb = feature_column.embedding_column(item_id, dimension=8)

# 数值特征处理
# StandardScaler 是 Scikit-Learn 提供的标准化工具,它会将数据转换为均值为 0、标准差为 1 的分布。
# 标准化(或采用归一化)可以消除不同特征间的量纲差异(例如年龄范围是 0-100,价格范围是 0-10000),使模型训练更稳定
scaler_price = StandardScaler()
df['item_price'] = scaler_price.fit_transform(df[['item_price']])
item_price = feature_column.numeric_column('item_price')

# 类别特征处理
# 先映射,后嵌入,生成低维稠密向量
# 分别将商品类别和品牌字符串映射为整数ID,(如“electronics”→0,“books”→1,依此类推),然后进行嵌入转换,生成低维稠密向量
item_category = feature_column.categorical_column_with_vocabulary_list(
    'item_category', ['electronics', 'books', 'clothing'])
item_category_emb = feature_column.embedding_column(item_category, dimension=2)
item_brand = feature_column.categorical_column_with_vocabulary_list(
    'item_brand', ['brandA', 'brandB', 'brandC'])
item_brand_emb = feature_column.embedding_column(item_brand, dimension=2)

# 打印前几行数据以查看结构(设置display.max_columns为None,强制显示全部列)
pd.set_option('display.max_columns', None)
print(df.head())

2.3 ESMM 模型架构设计

ESMM(Entire Space Multi-Task Model)是阿里巴巴提出的多任务学习框架,通过联合建模点击率(CTR)和点击转化率(CVR)解决传统CVR模型样本稀疏问题。其核心思想是将 CVR 定义为 CTR 与 CTCVR 的条件概率,模型包含两个子任务:

  • CTR 预测任务:使用全量曝光样本训练
  • CTCVR 预测任务:通过CTR*CVR乘积形式建模
"""
Part-3:模型架构设计
"""
# 用户特征列
user_tower_columns = [
    user_id_emb,
    user_age,
    user_gender_emb,
    user_occupation_emb,
    city_code_emb,
    device_types_emb
]

# 商品特征列
item_tower_columns = [
    item_id_emb,
    item_category_emb,
    item_brand_emb,
    item_price
]


# 自定义多任务模型(基于ESMM模型)
def model_fn(features, labels, mode, params):
    """
    自定义多任务模型:基于TensorFlow Estimator的多任务学习模型,主要用于同时预测点击率(CTR)和点击转化率(CTCVR)
    ESMM 通过引入全样本空间建模解决CVR样本稀疏问题,核心包含两个子任务:
        1.CTR任务:预测点击率(全量样本参与训练)
        2.CTCVR任务:预测点击后转化率(CTR * CVR,全量样本参与训练)
    通过CTCVR任务间接训练CVR模型,使得CVR模型能利用全量曝光样本而非仅点击样本
    """
    # 通过 DenseFeatures 层,将不同的特征列(用户塔和商品塔)转换为模型可用的输入
    

http://www.kler.cn/a/612826.html

相关文章:

  • SpringBoot学习Day1
  • Appium 入门操作指南
  • 地理信息可视化技术大全【WebGIS 技术文档大全】
  • Nginx多域名HTTPS配置全攻略:从证书生成到客户端安装
  • 【矩阵快速幂】P2100 凌乱的地下室|省选-
  • UE4学习笔记 FPS游戏制作31 显示计分板
  • 31天Python入门——第16天:模块与库详解
  • 正则表达式-笔记
  • ArayTS:一个功能强大的 TypeScript 工具库
  • Docker 快速入门指南
  • 路由器、交换机、防火墙、服务器、负载均衡在网络中作用
  • 第三课:Stable Diffusion图生图入门及应用
  • d2025328
  • OSPF邻居状态机
  • Java-servlet(十)使用过滤器,请求调度程序和Servlet线程(附带图谱表格更好对比理解)
  • AIGC-评论金句引流回复创作智能体完整指令(DeepSeek,豆包,千问,Kimi,GPT)
  • vueRouter的hash模式跟history的区别
  • 鸿蒙篇:vp、fp、px
  • Java 大视界 -- Java 大数据在智慧港口集装箱调度与物流效率提升中的应用创新(159)
  • Typora使用Gitee作为图床