当前位置: 首页 > article >正文

基于图卷积网络的轻量化推荐模型(论文复现)

基于图卷积网络的轻量化推荐模型(论文复现)

本文所涉及所有资源均在传知代码平台可获取

概述

图卷积网络(Graph Convolution Network,GCN)已经广泛的应用于推荐系统,基于GCN的协同过滤算法(例如NGCF)缺少消融研究,此模型对NGCF进行了消融实验并提出了轻量化卷积网络。传统的GCN推荐模型(以NGCF为例)

在这里插入图片描述

其中的线性变换和非线性激活函数导致模型庞大,速度很慢,难于理解。
通过消融实验,去掉线性变换W和非线性激活函数σ,得到以下结果:

在这里插入图片描述

可以看到,去掉fn的recall和ndcg在两个常用数据集上的效果更好。
本模型的优势在于,轻量化了NGCF模型,在参数更小,速度更快的基础上,还提升了性能

模型讲解

在这里插入图片描述

模型集合了Item和User的邻居信息,切只保留这部分信息,通过多层的GCN,最后求均值,得到了最终的u、i向量,最后进Prediction

模型公式:

在这里插入图片描述

目标函数:

在这里插入图片描述

演示效果

在这里插入图片描述

其中precision、recall、ndcg为模型评判标准,epoch为迭代次数(可改参数)、loss为损失,Sample为节点覆盖率

核心逻辑

核心代码逻辑:

class LightGCN(BasicModel):
    def __init__(self,
                 config:dict,
                 dataset:BasicDataset):
        super(LightGCN, self).__init__()
        self.config = config
        self.dataset : dataloader.BasicDataset = dataset
        self.__init_weight()
        self.attention_layer = AttentionLayer(input_dim=64)
        self.mlp = MLP(input_dim=64)
#        self.contrast = Contrast(64, 0.5, 0.5)

    def __init_weight(self):
        self.num_users  = self.dataset.n_users
        self.num_items  = self.dataset.m_items

        self.latent_dim = self.config['latent_dim_rec']
        self.n_layers = self.config['lightGCN_n_layers']
        self.keep_prob = self.config['keep_prob']
        self.A_split = self.config['A_split']

        self.embedding_user = torch.nn.Embedding(
            num_embeddings=self.num_users, embedding_dim=self.latent_dim)
        self.embedding_item = torch.nn.Embedding(
            num_embeddings=self.num_items, embedding_dim=self.latent_dim)


        if self.config['pretrain'] == 0:
            #nn.init.xavier_uniform_(self.embedding_user.weight, gain=1)
            #nn.init.xavier_uniform_(self.embedding_item.weight, gain=1)
            #print('use xavier initilizer')
            # random normal init seems to be a better choice when lightGCN actually don't use any non-linear activation function
            nn.init.normal_(self.embedding_user.weight, std=0.1)
            nn.init.normal_(self.embedding_item.weight, std=0.1)

            world.cprint('use NORMAL distribution initilizer')
        else:
            self.embedding_user.weight.data.copy_(torch.from_numpy(self.config['user_emb']))
            self.embedding_item.weight.data.copy_(torch.from_numpy(self.config['item_emb']))

            print('use pretarined data')
        self.f = nn.Sigmoid()
        self.Graph = self.dataset.getSparseGraph()
        print(f"lgn is already to go(dropout:{self.config['dropout']})")

        # print("save_txt")

核心逻辑就是去掉传统图卷积中的非线性激活函数和线性变换,轻量化了模型,只保留了图的语义信息,目标函数选择了BPRLOSS

使用方式

在这里插入图片描述

首先在/data文件中导入items和user数据,运行data_init.py文件进行数据初始化

在这里插入图片描述

在parse.py中修改模型参数;运行main.py

部署方式

python3.8即可,拥有pytorch环境;搭建环境

 pip install -r requirements.txt

文章代码资源点击附件获取


http://www.kler.cn/news/309469.html

相关文章:

  • 【Docker】docker的一些常用命令
  • 看Threejs好玩示例,学习创新与技术(二)
  • 创建索引遇到这个Bug,19c中还没有修复
  • echarts 自定义标注样式自定义tooltip弹窗样式
  • Redisson实现分布式锁(看门狗机制)
  • 【MySQL-初级】mysql基础操作(账户、数据库、表的增删查改)
  • 软考中级软件设计师——知识产权学习记录
  • Android Activity分屏设置
  • vue3前端开发-小兔鲜超市-本地购物车列表页面的统计计算
  • 新增的标准流程
  • Codeforces practice C++ 2024/9/11 - 2024/9/18
  • 常见数据湖的优劣对比
  • Rust表达一下中秋祝福,群发问候!
  • Spring Boot-依赖冲突问题
  • Verdin AM62 引脚复用配置
  • 检查和测绘室内防撞无人机技术详解
  • 机器学习的网络们
  • mysql 8.0 搭建主从集群注意事项
  • 从登录到免登录:JSP与Servlet结合Cookie的基本实现
  • react 组件通讯
  • 面试题篇: 跨域问题如何处理(Java和Nginx处理方式)
  • Linux 使用 tar 命令
  • C++掉血迷宫
  • 在vmvare安装飞牛私有云 fnOS体验教程
  • 自动化测试框架pytest命令参数
  • 如何在@GenericGenerator中显式指定schema
  • 友思特方案 | 搭建红外桥梁:嵌入式视觉接口助力红外热像仪传输
  • SpringBoot入门(黑马)
  • 【数据可视化】Arcgis api 4.x 专题图制作之分级色彩,采用自然间断法(使用simple-statistics JS数学统计库生成自然间断点)
  • 0072__ActiveX插件的使用