简介
1.1 Overview
Open Graph Benchmark(以下简称 OGB)是斯坦福大学的同学开源的 Python 库,其包含了图机器学习(以下简称图 ML)的基准数据集、数据加载器和评估器,目的在于促进可扩展的、健壮的、可复现的图 ML 的研究。
OGB 包含了多种图机器学习的多种任务,并且涵盖从社会和信息网络到生物网络,分子图和知识图的各种领域。没有数据集都有特定的数据拆分和评估指标,从而提供统一的评估协议。
OGB 提供了一个自动的端到端图 ML 的 pipeline,该 pipeline 简化并标准化了图数据加载,实验设置和模型评估的过程。如下图所示:
下图展示了 OGB 的三个维度,包括任务类型(Tasks)、可扩展性(Scale)、领域(Rich domains)。
来看一下 OGB 现在包含的数据集:
和数据集的统计明细:
OGB 也提供了标准化的评估人员和排行榜,以跟踪最新的结果,我们来看下不同任务下的部分 Leaderboard。
节点分类:
链接预测:
图分类:
官方给出的例子都是基于 PyG 实现的,我们这里实现一个基于 DGL 例子。
2.1 环境准备导入数据包
import dgl
import ogb
import math
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
查看版本
print(dgl.__version__)
print(torch.__version__)
print(ogb.__version__)
0.4.3post2 1.5.0+cu101 1.1.1
cuda 相关信息
print(torch.version.cuda) print(torch.cuda.is_available()) print(torch.cuda.device_count()) print(torch.cuda.get_device_name(0))
10.1 True 1 Tesla P100-PCIE-16GB 0
2.2 数据准备设置参数
device_id=0 # GPU 的使用 id n_layers=3 # 输入层 + 隐藏层 + 输出层的数量 n_hiddens=256 # 隐藏层节点的数量 dropout=0.5 lr=0.01 epochs=300 runs=10 # 跑 10 次,取平均 log_steps=50
定义训练函数、测试函数和日志记录
def train(model, g, feats, y_true, train_idx, optimizer): """ 训练函数 """ model.train() optimizer.zero_grad() out = model(g, feats)[train_idx] loss = F.nll_loss(out, y_true.squeeze(1)[train_idx]) loss.backward() optimizer.step() return loss.item() @torch.no_grad() def test(model, g, feats, y_true, split_idx, evaluator): """ 测试函数 """ model.eval() out = model(g, feats) y_pred = out.argmax(dim=-1, keepdim=True) train_acc = evaluator.eval({ 'y_true': y_true[split_idx['train']], 'y_pred': y_pred[split_idx['train']], })['acc'] valid_acc = evaluator.eval({ 'y_true': y_true[split_idx['valid']], 'y_pred': y_pred[split_idx['valid']], })['acc'] test_acc = evaluator.eval({ 'y_true': y_true[split_idx['test']], 'y_pred': y_pred[split_idx['test']], })['acc'] return train_acc, valid_acc, test_acc class Logger(object): """ 用于日志记录 """ def __init__(self, runs, info=None): self.info = info self.results = [[] for _ in range(runs)] def add_result(self, run, result): assert len(result) == 3 assert run >= 0 and run
关注打赏
最近更新
- 深拷贝和浅拷贝的区别(重点)
- 【Vue】走进Vue框架世界
- 【云服务器】项目部署—搭建网站—vue电商后台管理系统
- 【React介绍】 一文带你深入React
- 【React】React组件实例的三大属性之state,props,refs(你学废了吗)
- 【脚手架VueCLI】从零开始,创建一个VUE项目
- 【React】深入理解React组件生命周期----图文详解(含代码)
- 【React】DOM的Diffing算法是什么?以及DOM中key的作用----经典面试题
- 【React】1_使用React脚手架创建项目步骤--------详解(含项目结构说明)
- 【React】2_如何使用react脚手架写一个简单的页面?