
GraphSage-TF Code Walkthrough

Posted by 静静喜欢大白 on 2021-01-04 21:15:51

Repost

This walkthrough was run locally on Windows 10 with Python 3.6; only the flag values passed to the run functions need to be changed, i.e. just one or two lines.

Table of Contents

Code Structure Analysis
  Structure
Datasets
  PPI dataset details (example_data)
    toy-ppi-G.json: graph data
    toy-ppi-class_map.json
    toy-ppi-id_map.json
    toy-ppi-walks.txt
    toy-ppi-feats.npy
Environment Requirements
  Setting up the environment
Running
  Running unsupervised_train.py
  Running supervised_train.py
Code Analysis
  __init__.py
  utils.py
  neigh_samplers.py
  models.py
  layers.py
  minibatch.py
  aggregators.py
  prediction.py
  supervised_train.py
  unsupervised_train.py
  inits.py
  citation_eval.py
  ppi_eval.py
  reddit_eval.py
References

Code Structure Analysis

Structure

1. Top-level directory:

├── eval_scripts   // evaluation scripts
├── example_data   // PPI dataset
└── graphsage      // model structures, GCN layer definitions, ...

2. eval_scripts (evaluation scripts):

├── citation_eval.py
├── ppi_eval.py
└── reddit_eval.py

3. example_data (PPI dataset):

├── toy-ppi-class_map.json  // maps graph node ids to classes
├── toy-ppi-feats.npy       // pretrained node features
├── toy-ppi-G.json          // graph data
├── toy-ppi-walks.txt       // random walks from each node to its neighbors, 198 walks sampled per node
└── toy-ppi-id_map.json     // one-to-one mapping between node ids and indices

4. graphsage (model definitions):

├── __init__.py           // module imports
├── aggregators.py        // aggregator function definitions
├── inits.py              // shared initialization helpers
├── layers.py             // GCN layer definitions
├── metrics.py            // evaluation-metric computation
├── minibatch.py          // minibatch iterator definitions
├── models.py             // model structure definitions
├── neigh_samplers.py     // samplers that draw from a node's neighbors
├── prediction.py
├── supervised_models.py
├── supervised_train.py
├── unsupervised_train.py
└── utils.py              // utility functions

Datasets

| Dataset  | #Graphs | #Nodes | #Edges   | #Features | #Labels (y) |
|----------|---------|--------|----------|-----------|-------------|
| Cora     | 1       | 2708   | 5429     | 1433      | 7           |
| Citeseer | 1       | 3327   | 4732     | 3703      | 6           |
| Pubmed   | 1       | 19717  | 44338    | 500       | 3           |
| PPI      | 24      | 56944  | 818716   | 50        | 121         |
| Reddit   | 1       | 232965 | 11606919 | 602       | 41          |
| Nell     | 1       | 65755  | 266144   | 61278     | 210         |

PPI dataset details (example_data)

toy-ppi-G.json: graph data

The dataset contains a single graph, used for node classification. The graph is undirected and consists of a nodes set and a links set; each is a list in which every node or link is stored as a dictionary. Data format:

{
  directed: false
  graph: {
    name: disjoint_union(,)
  }
  nodes: [
    {
      test: false
      id: 0
      features: [ ... ]
      val: false
      label: [ ... ]
    }
    {...}
    ...
  ]
  links: [
    {
      test_removed: false
      train_removed: false
      target: 800  # id of the node pointed to (by default edges point from the smaller id to the larger)
      source: 0    # source node, listed in order starting from node 0
    }
    {...}
    ...
  ]
}
  • name: disjoint_union(,) is the name of the graph.
  • toy-ppi-G.json contains only a single graph (likely because node classification needs just one graph, whereas graph classification tasks need many).
  • As this shows, the graph is undirected and built from a nodes set and a links set; each is a list whose entries (nodes or links) are stored as dictionaries.
  • The source downloaded from GitHub appears to be missing the links data? It is actually present; the file is just too large to display in full, e.g. the viewer only shows nodes up to 1883 out of 14754 in total. A loading sketch follows below.
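As a minimal sketch (assuming the example_data layout above; this mirrors what utils.py does), the graph file can be loaded with networkx's node-link helpers:

import json
from networkx.readwrite import json_graph

# Parse toy-ppi-G.json (node-link format) into a networkx graph.
with open('example_data/toy-ppi-G.json') as f:
    G = json_graph.node_link_graph(json.load(f))

print(G.number_of_nodes(), G.number_of_edges())
print(G.node[0]['test'], G.node[0]['val'])  # networkx 1.x node-attribute access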
toy-ppi-class_map.json

Maps graph node ids to classes. Format: {"0": [1, 0, 0, …], …, "14754": [1, 1, 0, 0, …]}

toy-ppi-id_map.json

One-to-one mapping between node ids and indices. Format: {"0": 0, "1": 1, …, "14754": 14754}
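A minimal loading sketch for these two files (assuming the formats above; JSON object keys are always strings, so they are cast back to ints):

import json

# id_map: node id -> consecutive integer index
with open('example_data/toy-ppi-id_map.json') as f:
    id_map = {int(k): int(v) for k, v in json.load(f).items()}

# class_map: node id -> multi-hot label list
with open('example_data/toy-ppi-class_map.json') as f:
    class_map = {int(k): v for k, v in json.load(f).items()}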

toy-ppi-walks.txt
0	708
0	3163
0	276
0	1789
...
1	15
1	1455
1	1327
1	317
1	63
1	1420
...
9715	7369
9715	8983
9715	6983
  • Records random walks from each node to neighboring nodes; 198 walks are sampled per node (so duplicates are possible).
  • For example, "0 708" means a walk starting at node 0 reached node 708.
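A minimal sketch for reading the walk file (assuming the two-column format above); these pairs serve as the positive co-occurrence examples during unsupervised training:

# Each line "u v" is one co-occurrence pair produced by a random walk.
pairs = []
with open('example_data/toy-ppi-walks.txt') as f:
    for line in f:
        u, v = map(int, line.split())
        pairs.append((u, v))

print(pairs[0])  # (0, 708): a walk starting at node 0 reached node 708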
toy-ppi-feats.npy

The pretrained node features.

Data handling mainly uses two functions: (1) np.save("test.npy", data) to store data, and (2) data = np.load('test.npy') to read it back. For example, saving a list:

import numpy as np

z = [[[1, 2, 3], ['w']], [[1, 2, 3], ['w']]]
np.save('test.npy', z)
x = np.load('test.npy')

x:
->array([[list([1, 2, 3]), list(['w'])],
       [list([1, 2, 3]), list(['w'])]], dtype=object)

For example, saving a dict:

x
-> {0: 'wpy', 1: 'scg'}
np.save('test.npy',x)
x = np.load('test.npy')
x
->array({0: 'wpy', 1: 'scg'}, dtype=object)

After loading data that was saved as a dict, you must first call data.item() to convert the numpy.ndarray object back into a dict, as in the sketch below.
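A minimal sketch of the full round trip (note: on newer numpy versions np.load would need allow_pickle=True for object arrays; the numpy==1.14.5 pinned here allows it by default):

import numpy as np

d = {0: 'wpy', 1: 'scg'}
np.save('test.npy', d)
data = np.load('test.npy')  # a 0-d object ndarray, not a dict
d2 = data.item()            # .item() unwraps it back into a dict
assert d2 == d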

Environment Requirements
  • networkx must be version 1.11 or lower: pip install networkx==1.11
  • Follow the rest of the pinned environment strictly as well, to avoid unnecessary trouble.
  • Python version 3.6
absl-py==0.2.2
astor==0.6.2
backports.weakref==1.0.post1
bleach==1.5.0
decorator==4.3.0
enum34==1.1.6
funcsigs==1.0.2
futures==3.2.0
gast==0.2.0
grpcio==1.12.1
html5lib==0.9999999
Markdown==2.6.11
mock==2.0.0
networkx==1.11
numpy==1.14.5
pbr==4.0.4
protobuf==3.6.0
scikit-learn==0.19.1
scipy==1.1.0
six==1.11.0
sklearn==0.0
tensorboard==1.8.0
tensorflow==1.8.0
termcolor==1.1.0
Werkzeug==0.14.1
Setting up the environment

# Python 3.6 was chosen initially
conda activate GraphSAGE-master
conda install tensorflow==1.8.0
pip list
conda install networkx==1.11
pip list
conda install scikit-learn==0.19.1
pip list
# enum34==1.1.6 and futures==3.2.0 would not install under Python 3.6,
# but this turned out not to affect running the code


# Because enum34==1.1.6 and futures==3.2.0 would not install under Python 3.6,
# a second environment on Python 2.7 was attempted (unsuccessfully)
conda create -n py27 python=2.7
conda activate py27
conda install enum34==1.1.6
conda install futures==3.2.0
conda install tensorflow==1.8.0  # apparently not doable: Windows no longer supports installing TF for Python 2.7, so an offline install is needed; it was never completed
conda install networkx==1.11
conda install scikit-learn==0.19.1

[First environment]

[Second environment: unsuccessful, only the offline install of TF 1.8 was left undone]

Running

When selecting the interpreter in PyCharm, you will find some packages are incompatible; just install them when PyCharm prompts you.

[Some packages still would not install; this was set aside for now, and it turned out not to affect running the code later.]

Running unsupervised_train.py

# run from cmd
python -m graphsage.unsupervised_train --train_prefix ./example_data/toy-ppi --model graphsage_mean --max_total_steps 1000 --validate_iter 10

which is equivalent to

# run from PyCharm
python ./graphsage/unsupervised_train.py  --train_prefix ./example_data/toy-ppi --model graphsage_mean --max_total_steps 1000 --validate_iter 10

Note that the dataset path above differs from the official one. When running in PyCharm, you need to change the values of the train_prefix and model flags, and mind the difference between argument formats on the command line and in the IDE. In the IDE, change them to:

#### ./ is the current directory, ../ is the parent directory ####

flags.DEFINE_string('model', 'graphsage_mean', 'model names. See README for possible values.')
flags.DEFINE_string('train_prefix', '../example_data/toy-ppi', 'prefix identifying training data. must be specified.')

That is, edit the lines above in the code, then simply right-click the .py file and run it; a sketch of the flag behavior follows below.
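As a minimal sketch of why this works (hypothetical standalone file; with no command-line arguments, tf flags fall back to their declared defaults, which is exactly what happens on a right-click run in the IDE):

import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('model', 'graphsage_mean', 'model names. See README for possible values.')
flags.DEFINE_string('train_prefix', '../example_data/toy-ppi', 'prefix identifying training data.')

def main(argv=None):
    # IDE run: no CLI args, so the defaults above apply.
    # CLI run:  python this_file.py --model gcn  overrides them.
    print(FLAGS.model, FLAGS.train_prefix)

if __name__ == '__main__':
    tf.app.run()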

[Run output]

D:\Anaconda\envs\GraphSAGE-master\python.exe F:/code/GraphSAGE-master/graphsage/unsupervised_train.py
Loading training data..
Removed 0 nodes that lacked proper annotations due to networkx versioning issues
Loaded data.. now preprocessing..
Done loading training data..
Unexpected missing: 0
9716 train nodes
5039 test nodes
2021-01-05 12:40:09.431313: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2
Epoch: 0001
Iter: 0000 train_loss= 18.78066 train_mrr= 0.23649 train_mrr_ema= 0.23649 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 1.80019
Iter: 0050 train_loss= 18.67712 train_mrr= 0.16173 train_mrr_ema= 0.21749 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.30518
Iter: 0100 train_loss= 18.41344 train_mrr= 0.17981 train_mrr_ema= 0.20753 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.29077
Iter: 0150 train_loss= 18.10207 train_mrr= 0.21065 train_mrr_ema= 0.19910 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28748
Iter: 0200 train_loss= 17.45003 train_mrr= 0.18005 train_mrr_ema= 0.19214 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.29262
Iter: 0250 train_loss= 16.71679 train_mrr= 0.21261 train_mrr_ema= 0.18919 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28971
Iter: 0300 train_loss= 16.64080 train_mrr= 0.20941 train_mrr_ema= 0.18904 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28670
Iter: 0350 train_loss= 16.33514 train_mrr= 0.18145 train_mrr_ema= 0.18745 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28372
Iter: 0400 train_loss= 15.88267 train_mrr= 0.18800 train_mrr_ema= 0.18749 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28125
Iter: 0450 train_loss= 15.74382 train_mrr= 0.18654 train_mrr_ema= 0.18716 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27997
Iter: 0500 train_loss= 15.58050 train_mrr= 0.17311 train_mrr_ema= 0.18805 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28234
Iter: 0550 train_loss= 15.37372 train_mrr= 0.19895 train_mrr_ema= 0.18720 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28504
Iter: 0600 train_loss= 15.11785 train_mrr= 0.18306 train_mrr_ema= 0.18627 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28553
Iter: 0650 train_loss= 15.04833 train_mrr= 0.17784 train_mrr_ema= 0.18642 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28460
Iter: 0700 train_loss= 14.93566 train_mrr= 0.17898 train_mrr_ema= 0.18615 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28595
Iter: 0750 train_loss= 14.94030 train_mrr= 0.16468 train_mrr_ema= 0.18470 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28714
Iter: 0800 train_loss= 14.82021 train_mrr= 0.17996 train_mrr_ema= 0.18407 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28590
Iter: 0850 train_loss= 14.75895 train_mrr= 0.20370 train_mrr_ema= 0.18402 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28603
Iter: 0900 train_loss= 14.79193 train_mrr= 0.17865 train_mrr_ema= 0.18508 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28522
Iter: 0950 train_loss= 14.68051 train_mrr= 0.18984 train_mrr_ema= 0.18638 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28426
Iter: 1000 train_loss= 14.66581 train_mrr= 0.18520 train_mrr_ema= 0.18604 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28330
Iter: 1050 train_loss= 14.64359 train_mrr= 0.18334 train_mrr_ema= 0.18624 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28250
Iter: 1100 train_loss= 14.66787 train_mrr= 0.16166 train_mrr_ema= 0.18589 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28193
Iter: 1150 train_loss= 14.65202 train_mrr= 0.19368 train_mrr_ema= 0.18547 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28159
Iter: 1200 train_loss= 14.65571 train_mrr= 0.17497 train_mrr_ema= 0.18529 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28211
Iter: 1250 train_loss= 14.63282 train_mrr= 0.18568 train_mrr_ema= 0.18499 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28408
Iter: 1300 train_loss= 14.63904 train_mrr= 0.17733 train_mrr_ema= 0.18549 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28504
Iter: 1350 train_loss= 14.62205 train_mrr= 0.17858 train_mrr_ema= 0.18492 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28581
Iter: 1400 train_loss= 14.59377 train_mrr= 0.18335 train_mrr_ema= 0.18578 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28678
Iter: 1450 train_loss= 14.61559 train_mrr= 0.19628 train_mrr_ema= 0.18639 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28726
Iter: 1500 train_loss= 14.58464 train_mrr= 0.18871 train_mrr_ema= 0.18576 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28664
Iter: 1550 train_loss= 14.61813 train_mrr= 0.17187 train_mrr_ema= 0.18629 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28593
Iter: 1600 train_loss= 14.61389 train_mrr= 0.19341 train_mrr_ema= 0.18711 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28527
Iter: 1650 train_loss= 14.61634 train_mrr= 0.19737 train_mrr_ema= 0.18766 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28483
Iter: 1700 train_loss= 14.57671 train_mrr= 0.19137 train_mrr_ema= 0.18717 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28439
Iter: 1750 train_loss= 14.55233 train_mrr= 0.20713 train_mrr_ema= 0.18679 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28405
Iter: 1800 train_loss= 14.58431 train_mrr= 0.20119 train_mrr_ema= 0.18758 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28359
Iter: 1850 train_loss= 14.59033 train_mrr= 0.18874 train_mrr_ema= 0.18673 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28320
Iter: 1900 train_loss= 14.61115 train_mrr= 0.18718 train_mrr_ema= 0.18686 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28276
Iter: 1950 train_loss= 14.59950 train_mrr= 0.17403 train_mrr_ema= 0.18849 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28236
Iter: 2000 train_loss= 14.59908 train_mrr= 0.18091 train_mrr_ema= 0.18714 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28214
Iter: 2050 train_loss= 14.58339 train_mrr= 0.19607 train_mrr_ema= 0.18727 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28182
Iter: 2100 train_loss= 14.56937 train_mrr= 0.19161 train_mrr_ema= 0.18767 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28153
Iter: 2150 train_loss= 14.61444 train_mrr= 0.19147 train_mrr_ema= 0.18828 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28113
Iter: 2200 train_loss= 14.61586 train_mrr= 0.19568 train_mrr_ema= 0.18844 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28100
Iter: 2250 train_loss= 14.58835 train_mrr= 0.17902 train_mrr_ema= 0.18864 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28069
Iter: 2300 train_loss= 14.59437 train_mrr= 0.18586 train_mrr_ema= 0.18851 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28042
Iter: 2350 train_loss= 14.58622 train_mrr= 0.19174 train_mrr_ema= 0.18775 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28034
Iter: 2400 train_loss= 14.59255 train_mrr= 0.19002 train_mrr_ema= 0.18706 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28019
Iter: 2450 train_loss= 14.61124 train_mrr= 0.19741 train_mrr_ema= 0.18872 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28009
Iter: 2500 train_loss= 14.61114 train_mrr= 0.18167 train_mrr_ema= 0.18885 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27986
Iter: 2550 train_loss= 14.58769 train_mrr= 0.21136 train_mrr_ema= 0.18786 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27974
Iter: 2600 train_loss= 14.59534 train_mrr= 0.19214 train_mrr_ema= 0.18814 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27957
Iter: 2650 train_loss= 14.57917 train_mrr= 0.17698 train_mrr_ema= 0.18707 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27946
Iter: 2700 train_loss= 14.59813 train_mrr= 0.18994 train_mrr_ema= 0.18711 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27930
Iter: 2750 train_loss= 14.58086 train_mrr= 0.18377 train_mrr_ema= 0.18592 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27909
Iter: 2800 train_loss= 14.60047 train_mrr= 0.18927 train_mrr_ema= 0.18653 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27886
Iter: 2850 train_loss= 14.60010 train_mrr= 0.18386 train_mrr_ema= 0.18697 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27862
Iter: 2900 train_loss= 14.60427 train_mrr= 0.18447 train_mrr_ema= 0.18848 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27853
Iter: 2950 train_loss= 14.56377 train_mrr= 0.20184 train_mrr_ema= 0.18818 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27839
Iter: 3000 train_loss= 14.59801 train_mrr= 0.16331 train_mrr_ema= 0.18764 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27824
Iter: 3050 train_loss= 14.60474 train_mrr= 0.18347 train_mrr_ema= 0.18711 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27809
Iter: 3100 train_loss= 14.59281 train_mrr= 0.18725 train_mrr_ema= 0.18756 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27805
Iter: 3150 train_loss= 14.62110 train_mrr= 0.19122 train_mrr_ema= 0.18804 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27790
Iter: 3200 train_loss= 14.57584 train_mrr= 0.17738 train_mrr_ema= 0.18747 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27779
Iter: 3250 train_loss= 14.60866 train_mrr= 0.17803 train_mrr_ema= 0.18844 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27764
Iter: 3300 train_loss= 14.58529 train_mrr= 0.20240 train_mrr_ema= 0.18789 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27751
Iter: 3350 train_loss= 14.62195 train_mrr= 0.18435 train_mrr_ema= 0.18832 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27737
Iter: 3400 train_loss= 14.56922 train_mrr= 0.19166 train_mrr_ema= 0.18750 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27717
Iter: 3450 train_loss= 14.58548 train_mrr= 0.19197 train_mrr_ema= 0.18727 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27697
Iter: 3500 train_loss= 14.58611 train_mrr= 0.18371 train_mrr_ema= 0.18716 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27681
Iter: 3550 train_loss= 14.58547 train_mrr= 0.18298 train_mrr_ema= 0.18861 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27674
Iter: 3600 train_loss= 14.57893 train_mrr= 0.18505 train_mrr_ema= 0.18868 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27701
Iter: 3650 train_loss= 14.57411 train_mrr= 0.17655 train_mrr_ema= 0.18987 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27730
Iter: 3700 train_loss= 14.59241 train_mrr= 0.19698 train_mrr_ema= 0.18930 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27722
Optimization Finished!

Process finished with exit code 0
Running supervised_train.py

Note that the value of the train_prefix flag needs the same change: ../example_data/toy-ppi

python -m graphsage.supervised_train --train_prefix ./example_data/toy-ppi --model graphsage_mean --sigmoid
which is equivalent to
python ./graphsage/supervised_train.py --train_prefix ./example_data/toy-ppi --model graphsage_mean --sigmoid

The code change is the same as above.

[Run output]

D:\Anaconda\envs\GraphSAGE-master\python.exe F:/code/GraphSAGE-master/graphsage/supervised_train.py
Loading training data..
Removed 0 nodes that lacked proper annotations due to networkx versioning issues
Loaded data.. now preprocessing..
Done loading training data..
WARNING:tensorflow:From F:\code\GraphSAGE-master\graphsage\supervised_models.py:118: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See @{tf.nn.softmax_cross_entropy_with_logits_v2}.

2021-01-05 14:24:52.211993: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2
Epoch: 0001
D:\Anaconda\envs\GraphSAGE-master\lib\site-packages\sklearn\metrics\classification.py:1135: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.
  'precision', 'predicted', average, warn_for)
D:\Anaconda\envs\GraphSAGE-master\lib\site-packages\sklearn\metrics\classification.py:1137: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no true samples.
  'recall', 'true', average, warn_for)
Iter: 0000 train_loss= 160.39902 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 191.88171 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 1.31548
Iter: 0005 train_loss= 177.72525 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 191.88171 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.35738
Iter: 0010 train_loss= 168.67435 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 191.88171 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.26747
Iter: 0015 train_loss= 174.82602 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 191.88171 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.23475
Epoch: 0002
Iter: 0001 train_loss= 169.43646 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 196.30547 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.22217
Iter: 0006 train_loss= 171.03656 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 196.30547 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.21048
Iter: 0011 train_loss= 168.95322 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 196.30547 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.20330
Iter: 0016 train_loss= 164.48836 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 196.30547 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.19861
Epoch: 0003
Iter: 0002 train_loss= 170.99802 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 181.05011 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.19645
Iter: 0007 train_loss= 170.51253 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 181.05011 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.19307
Iter: 0012 train_loss= 174.38806 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 181.05011 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.19067
Iter: 0017 train_loss= 162.57272 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 181.05011 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18782
Epoch: 0004
Iter: 0003 train_loss= 170.45332 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 190.83347 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18709
Iter: 0008 train_loss= 169.09729 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 190.83347 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18537
Iter: 0013 train_loss= 166.32990 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 190.83347 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18381
Iter: 0018 train_loss= 173.10933 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 190.83347 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18272
Epoch: 0005
Iter: 0004 train_loss= 165.87482 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 191.74899 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18264
Iter: 0009 train_loss= 168.55566 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 191.74899 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18191
Iter: 0014 train_loss= 173.05153 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 191.74899 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18122
Epoch: 0006
Iter: 0000 train_loss= 168.48744 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 177.71576 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18131
Iter: 0005 train_loss= 164.95117 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 177.71576 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18065
Iter: 0010 train_loss= 166.21835 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 177.71576 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17936
Iter: 0015 train_loss= 177.44318 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 177.71576 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17864
Epoch: 0007
Iter: 0001 train_loss= 167.81136 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 186.21609 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17919
Iter: 0006 train_loss= 174.58884 train_f1_mic= 0.00195 train_f1_mac= 0.00022 val_loss= 186.21609 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17845
Iter: 0011 train_loss= 165.81683 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 186.21609 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17768
Iter: 0016 train_loss= 171.65659 train_f1_mic= 0.00195 train_f1_mac= 0.00020 val_loss= 186.21609 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17679
Epoch: 0008
Iter: 0002 train_loss= 174.44943 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 192.19485 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17647
Iter: 0007 train_loss= 176.99825 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 192.19485 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17634
Iter: 0012 train_loss= 167.94687 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 192.19485 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17634
Iter: 0017 train_loss= 172.83304 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 192.19485 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17586
Epoch: 0009
Iter: 0003 train_loss= 176.01657 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 197.73112 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17600
Iter: 0008 train_loss= 169.39464 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 197.73112 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17550
Iter: 0013 train_loss= 168.62959 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 197.73112 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17529
Iter: 0018 train_loss= 168.67769 train_f1_mic= 0.00200 train_f1_mac= 0.00023 val_loss= 197.73112 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17495
Epoch: 0010
Iter: 0004 train_loss= 165.34845 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 185.15099 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17522
Iter: 0009 train_loss= 169.52484 train_f1_mic= 0.00195 train_f1_mac= 0.00019 val_loss= 185.15099 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17464
Iter: 0014 train_loss= 168.72287 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 185.15099 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17434
Optimization Finished!
Full validation stats: loss= 184.90506 f1_micro= 0.00055 f1_macro= 0.00005 time= 0.54853
Writing test set stats to file (don't peak!)

Process finished with exit code 0
  • graphsage_mean – GraphSage with mean-based aggregator
  • graphsage_seq – GraphSage with LSTM-based aggregator
  • graphsage_maxpool – GraphSage with max-pooling aggregator (as described in the NIPS 2017 paper)
  • graphsage_meanpool – GraphSage with mean-pooling aggregator (a variant of the pooling aggregator, where the element-wise mean replaces the element-wise max).
  • gcn – GraphSage with GCN-based aggregator
  • n2v – an implementation of DeepWalk (called n2v for short in the code)
  • As the logs show, unsupervised_train.py ran only 1 epoch, 3700 iterations in total, with validation every 50 iterations; batch_size: 512.
  • supervised_train.py ran 10 epochs, roughly 40 iterations in total, with validation every 5 iterations; batch_size: 512.
  • python -m graphsage.unsupervised_train runs the script as a module, without giving a concrete path.
  • python ./graphsage/unsupervised_train.py runs it directly as a script file.

Note

D:\Anaconda\envs\GraphSAGE-master\lib\site-packages\sklearn\metrics\classification.py:1135: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.
  'precision', 'predicted', average, warn_for)
D:\Anaconda\envs\GraphSAGE-master\lib\site-packages\sklearn\metrics\classification.py:1137: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no true samples.
  'recall', 'true', average, warn_for)
  • Cause: some labels present in y_true are never produced in y_pred, i.e. the warning fires when the ground truth contains a class that the predictions never hit; the F1 for that class is then set to 0. For example, with y_true = (0, 1, 2, 3, 4) and y_pred = (0, 1, 1, 3, 4), label '2' is never predicted, so no F-score can be computed for it and it is treated as 0.0. Since the average over all classes must still include these zero scores, scikit-learn warns you about it; it is a warning, not an error. See the sketch below.
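A minimal sketch reproducing the warning with the example above:

from sklearn.metrics import f1_score

y_true = [0, 1, 2, 3, 4]
y_pred = [0, 1, 1, 3, 4]  # label 2 is never predicted

# Emits UndefinedMetricWarning: label 2 gets F1 = 0.0,
# but that zero is still included in the macro average.
print(f1_score(y_true, y_pred, average='macro'))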
Code Analysis

__init__.py

from __future__ import print_function
# Even under Python 2.x, print must then be used with parentheses, as in Python 3.x.

from __future__ import division
# Imports Python's future "true division" feature: without it, "/" performs
# truncating division; with it, "/" is true division and "//" truncates.
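A quick illustration of the difference (a minimal sketch; under Python 3.x, or Python 2.x with the import above):

from __future__ import division

print(7 / 2)   # 3.5 -- true division
print(7 // 2)  # 3   -- truncating division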
utils.py

from __future__ import print_function

import numpy as np                           # numpy
import random                                # random module
import json                                  # json module
import sys                                   # system module
import os                                    # operating-system module

import networkx as nx                        # networkx (graph library), used to create and manipulate graphs
from networkx.readwrite import json_graph    # used to save/load networkx graphs as JSON

version_info = list(map(int, nx.__version__.split('.')))  # get the networkx version and turn it into a list of ints
major = version_info[0]  # the number before the first dot
minor = version_info[1]  # the number after the first dot
assert (major <= 1) and (minor <= 11), "networkx major version > 1.11"

unsupervised_train.py

(tail of the train() function:)

            if total_steps > FLAGS.max_total_steps:
                break

        if total_steps > FLAGS.max_total_steps:
                break
    
    print("Optimization Finished!")
    if FLAGS.save_embeddings:  # whether to store the node embeddings after training
        sess.run(val_adj_info.op)

        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir())

        if FLAGS.model == "n2v":
            # stopping the gradient for the already trained nodes
            train_ids = tf.constant([[id_map[n]] for n in G.nodes_iter() if not G.node[n]['val'] and not G.node[n]['test']],
                    dtype=tf.int32)
            test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter() if G.node[n]['val'] or G.node[n]['test']], 
                    dtype=tf.int32)
            update_nodes = tf.nn.embedding_lookup(model.context_embeds, tf.squeeze(test_ids))
            no_update_nodes = tf.nn.embedding_lookup(model.context_embeds,tf.squeeze(train_ids))
            update_nodes = tf.scatter_nd(test_ids, update_nodes, tf.shape(model.context_embeds))
            no_update_nodes = tf.stop_gradient(tf.scatter_nd(train_ids, no_update_nodes, tf.shape(model.context_embeds)))
            model.context_embeds = update_nodes + no_update_nodes
            sess.run(model.context_embeds)

            # run random walks
            from graphsage.utils import run_random_walks
            nodes = [n for n in G.nodes_iter() if G.node[n]["val"] or G.node[n]["test"]]
            start_time = time.time()
            pairs = run_random_walks(G, nodes, num_walks=50)
            walk_time = time.time() - start_time

            test_minibatch = EdgeMinibatchIterator(G, 
                id_map,
                placeholders, batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree, 
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs = pairs,
                n2v_retrain=True,
                fixed_n2v=True)
            
            start_time = time.time()
            print("Doing test training for n2v.")
            test_steps = 0
            for epoch in range(FLAGS.n2v_test_epochs):
                test_minibatch.shuffle()
                while not test_minibatch.end():
                    feed_dict = test_minibatch.next_minibatch_feed_dict()
                    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                    outs = sess.run([model.opt_op, model.loss, model.ranks, model.aff_all, 
                        model.mrr, model.outputs1], feed_dict=feed_dict)
                    if test_steps % FLAGS.print_every == 0:
                        print("Iter:", '%04d' % test_steps, 
                              "train_loss=", "{:.5f}".format(outs[1]),
                              "train_mrr=", "{:.5f}".format(outs[-2]))
                    test_steps += 1
            train_time = time.time() - start_time
            save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir(), mod="-test")
            print("Total time: ", train_time+walk_time)
            print("Walk time: ", walk_time)
            print("Train time: ", train_time)

    
# main function: load the data and train
def main(argv=None):
    print("Loading training data..")
    # load_data is defined in graphsage.utils and loads the labeled dataset
    train_data = load_data(FLAGS.train_prefix, load_walks=True)
    print("Done loading training data..")
    # train is defined in this file as: def train(train_data, test_data=None)
    train(train_data)

if __name__ == '__main__':
    tf.app.run()  # parse the command-line flags, then call the main function main(sys.argv)
'''
What tf.app.run() does: it handles flag parsing and then executes the main function.
If the entry point of your code is not called main() but something else, say test(),
you should write the entry point as tf.app.run(main=test).
If your entry point is called main(), you can simply write tf.app.run().

When tf.app.run() is used, FLAGS = tf.app.flags.FLAGS above has already set up flag parsing.

Inside tf.app.run(), argv=None, so args = argv[1:] if argv else None yields args=None
(i.e. unspecified; the command line is then parsed automatically).

f = flags.FLAGS builds the parser f used to parse args; f._parse_flags(args) parses the args list
or the command-line input (if the args list is empty, the command line is parsed). The returned
flags_passthrough holds whatever could not be parsed (excluding the file name).
'''
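A minimal sketch of the non-main entry-point case described above (hypothetical standalone file, TF 1.x API):

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string('model', 'graphsage_mean', 'model name')
FLAGS = flags.FLAGS

def test(argv=None):
    # Flags are already parsed by the time this runs.
    print('running with model =', FLAGS.model)

if __name__ == '__main__':
    tf.app.run(main=test)  # pass the non-main entry point explicitly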
inits.py

citation_eval.py

ppi_eval.py

reddit_eval.py

References

[Source code] https://github.com/williamleif/GraphSAGE

[Code analysis] https://www.cnblogs.com/shiyublog/tag/graphsage/
