当前位置: 首页 > article >正文

GNN多任务预测模型实现(二):将EXCEL数据转换为图数据

 

一. 引言

在图神经网络(Graph Neural Networks, GNNs)的多任务学习场景中,数据预处理是至关重要的一步。尤其是当我们的数据存储在表格格式(如Excel文件)中时,如何有效地将其转换为图数据格式,是搭建GNN模型的基础。


二. 加载和检查数据

第一步是加载数据并检查其格式。我们通常使用 pandas 库来读取和操作Excel文件。以下代码实现了从Excel文件中加载数据,并打印前几行以检查格式是否正确。

import pandas as pd

# 从Excel文件中读取数据
df = pd.read_excel('Participation_prediction_data.xlsx')

# 打印前几行数据以检查数据格式
print(df.head())

 知识点

  • pandas.read_excel():用于从Excel文件中读取数据并加载为DataFrame对象。
  • DataFrame.head():返回前5行数据,用于快速检查数据格式和内容。

三. 提取特征和标签

在机器学习任务中,我们通常需要将数据分为特征(features)和标签(labels)。在本例中,我们从表格中提取三列作为特征,并将“讨论参与”列作为标签。 

import numpy as np

# 提取特征列并转换为numpy数组
features = df[['讨论参与', '作业提交', '在线课堂出席时长']].values

# 提取标签列并确保其为整数类型
labels = df['讨论参与'].values.astype(int)

知识点

  • DataFrame的列选择:通过列名选择所需的列。
  • .values 属性:将DataFrame转换为NumPy数组。
  • 数据类型转换:通过 .astype() 将数据类型转换为所需的类型。

 四. 标准化特征

在模型训练之前,对特征进行标准化处理(即归一化到均值为0,标准差为1的范围)通常能够提高模型的收敛速度和性能。 

from sklearn.preprocessing import StandardScaler

# 初始化StandardScaler对象
scaler = StandardScaler()

# 对特征进行标准化处理
features = scaler.fit_transform(features)

知识点

  • StandardScaler:用于对数据进行标准化处理。
  • fit_transform():同时计算均值和标准差,并对数据进行标准化。

 五. 构建节点索引

在图数据中,每个节点通常需要一个唯一的索引。在本例中,我们使用DataFrame的索引作为节点的索引。 

# 获取所有节点的索引
node_indices = df.index.values

知识点

  • DataFrame.index.values:返回DataFrame的索引,通常是一个NumPy数组。

六. 构建边及其特征 

边的构建是图数据生成的关键步骤。在本例中,我们根据以下条件构建边:

  • 两个节点的时间差为1。
  • 两个节点属于同一学生。

同时,我们为每条边定义了特征,包括源节点的“讨论参与”、“作业提交”和“在线课堂出席时长”。

import torch

# 初始化边列表和边特征列表
edges = []
edge_features = []

for i in range(len(node_indices)):
    for j in range(len(node_indices)):
        # 如果两个节点的时间差为1且属于同一学生,则添加一条边
        if abs(df.loc[i, '时间'] - df.loc[j, '时间']) == 1 and df.loc[i, '学生id'] == df.loc[j, '学生id']:
            edges.append([i, j])
            edge_features.append([
                df.loc[i, '讨论参与'],
                df.loc[i, '作业提交'],
                df.loc[i, '在线课堂出席时长']
            ])

# 确保边列表和边特征列表长度一致
if len(edges) != len(edge_features):
    raise ValueError("Edge list and edge feature list should have the same length.")

# 将边列表转换为张量,并转置为 [2, num_edges] 形状
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

# 将边特征列表转换为张量
edge_attr = torch.tensor(edge_features, dtype=torch.float)

知识点

  • 嵌套循环:用于遍历所有可能的节点对。
  • DataFrame.loc[]:用于按索引访问DataFrame中的行。
  • 条件判断:用于确定是否添加一条边。
  • torch.tensor():将Python列表转换为PyTorch张量。
  • .t() 和 .contiguous():用于将边列表张量转置为 [2, num_edges] 形状。

七. 总结 

 通过上述步骤,我们成功地将Excel表格数据转换为了图数据格式,包括节点特征、节点索引、边列表和边特征。这些数据可以直接输入到GNN模型中进行训练和预测。

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
import torch

# 从Excel文件中读取数据
df = pd.read_excel('Participation_prediction_data.xlsx')

# 打印前几行数据以检查数据格式
print(df.head())

# 提取特征列并转换为numpy数组
features = df[['讨论参与', '作业提交', '在线课堂出席时长']].values

# 提取标签列并确保其为整数类型
labels = df['讨论参与'].values.astype(int)

# 初始化StandardScaler对象
scaler = StandardScaler()

# 对特征进行标准化处理
features = scaler.fit_transform(features)

# 获取所有节点的索引
node_indices = df.index.values

# 初始化边列表和边特征列表
edges = []
edge_features = []

for i in range(len(node_indices)):
    for j in range(len(node_indices)):
        if abs(df.loc[i, '时间'] - df.loc[j, '时间']) == 1 and df.loc[i, '学生id'] == df.loc[j, '学生id']:
            edges.append([i, j])
            edge_features.append([
                df.loc[i, '讨论参与'],
                df.loc[i, '作业提交'],
                df.loc[i, '在线课堂出席时长']
            ])

if len(edges) != len(edge_features):
    raise ValueError("Edge list and edge feature list should have the same length.")

edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
edge_attr = torch.tensor(edge_features, dtype=torch.float)

八. 结语

通过将Excel数据转换为图数据格式,我们为后续的GNN模型搭建和训练奠定了基础!


http://www.kler.cn/a/535242.html

相关文章:

  • 深入理解 Java 接口的回调机制 【学术会议-2025年人工智能与计算智能(AICI 2025)】
  • 哈希(Hashing)在 C++ STL 中的应用
  • vscode 如何通过Continue引入AI 助手deepseek
  • 在 Mac M2 上安装 PyTorch 并启用 MPS 加速的详细教程与性能对比
  • vs code 使用教程
  • IOPS与吞吐量、读写块大小及延迟之间的关系
  • 数据实时推送至前端的主流方法总结
  • 为何实现大语言模型的高效推理以及充分释放 AI 芯片的计算能力对于企业级落地应用来说,被认为具备显著的研究价值与重要意义?
  • 面向对象程序设计-实验1
  • 快速单机部署ollama v0.5.7 +openwebui(免去网络环境干扰)
  • 【后端开发】系统设计101——Devops,Git与CICD,云服务与云原生,Linux,安全性,案例研究(30张图详解)
  • 下标为3的倍数
  • 解锁C#数据校验:从基础到实战的进阶之路
  • 日志模块自定义@SkipLogAspect注解跳过切面
  • 三格电子-单串口服务器说明
  • [paddle] 矩阵乘法
  • 高性能音频分析仪,音频分析器、国产音频分析仪
  • QUIC协议详解
  • ES6- 代码编程风格(let、字符串、解构赋值)
  • 所遇皆温柔,佛系过生活
  • pycharm集成通义灵码应用
  • 【PyTorch】解决Boolean value of Tensor with more than one value is ambiguous报错
  • leetcode——组合总和(回溯算法详细讲解)
  • DNN(深度神经网络)近似 Lyapunov 函数
  • 解锁反序列化漏洞:从原理到防护的安全指南
  • 【OpenCV插值算法比较】