A Hands-On Graph Neural Network Case Study: the COVID-19 Vaccine Project

段智华 · Published 2021-02-16

The Kaggle COVID-19 vaccine development competition

https://www.kaggle.com/c/stanford-covid-vaccine/overview

mRNA vaccines have taken the lead as the fastest vaccine candidates against the 2019 coronavirus, but they currently face key potential limitations. One of the biggest challenges right now is how to design super-stable messenger RNA (mRNA) molecules. Conventional vaccines are packaged in syringes and shipped under refrigeration around the world, but that is currently not possible for mRNA vaccines.

Researchers have observed that RNA molecules have a tendency to degrade. This is a serious limitation, because degradation can render an mRNA vaccine useless. Currently, little is known about which parts of the backbone of a given RNA are most vulnerable. Without this knowledge, current mRNA vaccines against COVID-19 must be prepared and shipped under intense refrigeration; unless they can be stabilized, they are unlikely to reach everyone on the planet.

The Eterna community, led by Professor Rhiju Das, a computational biologist at Stanford's School of Medicine (Stanford’s School of Medicine), brings together scientists and competitive gamers to solve puzzles and invent medicines. Eterna is an online competition platform that challenges players to solve scientific problems such as mRNA design; the designs are then synthesized and experimentally tested by Stanford researchers to gain new insights into RNA molecules. The Eterna community has previously unlocked new scientific principles, made new diagnostics for deadly diseases, and engaged the world's most potent intellectual resources to improve public life. Eterna has advanced biotechnology, including RNA biotechnology, through contributions to more than 20 publications.

In this competition, we hope to leverage the data-science expertise of the Kaggle community to develop models and design rules for RNA degradation. The models will predict the likely degradation rate at each base of an RNA molecule, trained on a subset of the Eterna dataset comprising over 3000 RNA molecules (spanning a diverse set of sequences and structures) together with their degradation rates at each position. The models will then be scored on a second generation of RNA sequences that Eterna players have just designed for COVID-19 mRNA vaccines. These final test sequences are currently being synthesized and experimentally characterized at Stanford, in parallel with the modeling effort; nature itself will score the models!

Work on improving the stability of mRNA vaccines is already under way, and this deep scientific challenge must be solved in order to accelerate mRNA vaccine research and deliver a refrigerator-stable vaccine against SARS-CoV-2, the virus behind COVID-19. The problem we are trying to solve welcomes help from academic labs, industrial R&D teams, and supercomputers; you can join the team of video-game players, scientists, and developers on Eterna to fight this devastating virus.

 

1. Case Overview

The encoded DNA is delivered into cells, and the cells use messenger RNA (mRNA) to assemble the protein. Once the immune system detects the assembled protein, the genetic code used to build the viral protein activates the immune system to produce antibodies, strengthening resistance to the coronavirus.

Different mRNA sequences can produce the same protein.

mRNA degrades over time and with changes in temperature.

How can we find mRNA with a more stable structure? A graph neural network is used to find more stable mRNA; in the visualization, the darker the color, the more stable the position.

2. COVID-19 Vaccine Project: Advanced Hands-On Practice

Code file structure:

Data distribution characteristics

Check the currently mounted dataset directory

# Inspect the currently mounted dataset directory; changes under it are reset when the environment restarts
# Here we can see our dataset is named: data60987
!ls /home/aistudio/data

Output

data60987

Inspect the dataset train.json. Data format:

{"index":401,"id":"id_2a983d026","sequence":"GGAAAAAGGCUCAAAAACUGUACGAAGGUACAGAAAAACCAUAGCGAAAGCUAUGGAAAAAGAGCCAACUACUGGUUCGCCAGUAGAAAAGAAACAACAACAACAAC","structure":".......(((((.....((((((....)))))).....(((((((....))))))).....)))))..(((((((....))))))).....................","predicted_loop_type":"EEEEEEESSSSSMMMMMSSSSSSHHHHSSSSSSMMMMMSSSSSSSHHHHSSSSSSSMMMMMSSSSSXXSSSSSSSHHHHSSSSSSSEEEEEEEEEEEEEEEEEEEEE","signal_to_noise":8.157,"SN_filter":1.0,"seq_length":107,"seq_scored":68,"reactivity_error":[0.1423,0.2177,0.139,0.0994,0.1153,0.0995,0.0582,0.0237,0.0226,0.0263,0.0235,0.0692,0.1025,0.0635,0.0713,0.0749,0.0542,0.0218,0.0075,0.0208,0.0213,0.018,0.024,0.0736,0.0713,0.0391,0.0696,0.0423,0.0273,0.0198,0.0203,0.0093,0.0508,0.0871,0.0622,0.0625,0.0623,0.0473,0.0159,0.0217,0.0155,0.0119,0.0145,0.0128,0.0133,0.0319,0.0558,0.0359,0.0346,0.0085,0.0096,0.0161,0.0129,0.0113,0.0137,0.0434,0.0588,0.0595,0.0624,0.0525,0.0378,0.0177,0.0141,0.016,0.0094,0.0228,0.0578,0.0383],"deg_error_Mg_pH10":[0.1878,0.3274,0.1631,0.0812,0.1629,0.1502,0.1275,0.0633,0.0685,0.0775,0.0695,0.1889,0.1619,0.1326,0.0742,0.0613,0.0547,0.0485,0.0289,0.0515,0.0497,0.0304,0.0233,0.0528,0.0466,0.03,0.0513,0.0561,0.0261,0.0387,0.0316,0.0289,0.112,0.1137,0.0846,0.0631,0.0433,0.0464,0.0265,0.0315,0.0346,0.0218,0.0254,0.0223,0.0176,0.0327,0.0335,0.0297,0.0262,0.03,0.0331,0.0201,0.0329,0.0186,0.0232,0.073,0.0625,0.0585,0.0593,0.0471,0.0453,0.0317,0.0195,0.0337,0.0311,0.0333,0.036,0.0562],"deg_error_pH10":[0.232,0.3104,0.1631,0.0778,0.1532,0.1399,0.1284,0.0564,0.0634,0.099,0.0594,0.1322,0.1365,0.136,0.1118,0.1049,0.0986,0.0473,0.0267,0.0433,0.0478,0.0266,0.0375,0.0597,0.0657,0.0551,0.0952,0.0624,0.0561,0.0417,0.0404,0.0317,0.1204,0.1383,0.1066,0.1015,0.0807,0.0884,0.0359,0.0497,0.0424,0.033,0.0313,0.0364,0.021,0.0476,0.0495,0.037,0.047,0.0428,0.0448,0.0425,0.0335,0.0269,0.0401,0.1032,0.0864,0.0977,0.0974,0.0821,0.0959,0.0556,0.033,0.0517,0.0453,0.0626,0.0841,0.1283],"deg_error_Mg_50C":[0.1342,0.2586,0.1547,0.0724,0.1516,0.1302,0.0857,0.0411,0.0349,0.0471,0.0448,0.13,0.1312,0.1216,0.0827,0.0713,0.0502,0.0332,0.0184,0.0269,0.0275,0.0183,0.0237,0.045,0.0522,0.0391,0.0611,0.0413,0.0269,0.021,0.0308,0.0218,0.1118,0.1188,0.0898,0.0648,0.0521,0.0458,0.0247,0.0272,0.0238,0.0166,0.0178,0.019,0.0136,0.0278,0.0366,0.0291,0.0282,0.0167,0.0221,0.0135,0.0189,0.0067,0.0156,0.0818,0.0718,0.0752,0.0815,0.0573,0.0617,0.0326,0.024,0.0299,0.0305,0.0389,0.0441,0.054],"deg_error_50C":[0.1858,0.2902,0.1741,0.0976,0.1655,0.1298,0.1092,0.0595,0.0464,0.0776,0.0601,0.1411,0.1319,0.1292,0.1263,0.1279,0.0998,0.0498,0.0386,0.0481,0.0635,0.0383,0.0499,0.0737,0.0802,0.0752,0.1019,0.0777,0.0529,0.0381,0.055,0.0631,0.1288,0.138,0.0833,0.1019,0.0992,0.081,0.0284,0.045,0.0326,0.0341,0.0316,0.0371,0.0257,0.0677,0.0606,0.0618,0.0519,0.0423,0.033,0.0504,0.0463,0.021,0.0474,0.107,0.0997,0.099,0.0964,0.0838,0.0769,0.0439,0.0315,0.0475,0.0379,0.0719,0.0805,0.099],"reactivity":[1.123,3.8721,1.713,0.8734,1.3266,0.9945,0.2319,0.0312,0.0196,0.0122,0.0234,0.3576,1.1503,0.31,0.5168,0.6628,0.3396,-0.0029,0.0,0.0221,0.0184,0.0193,0.0381,0.6968,0.676,0.1654,0.6669,0.2018,0.0571,0.0247,0.0079,-0.0037,0.228,1.1223,0.5402,0.6254,0.6763,0.3724,0.0151,0.0073,0.0068,0.0094,0.0187,-0.0026,0.014,0.1307,0.5515,0.219,0.1912,-0.0036,0.0044,0.0185,0.0052,0.0088,0.0177,0.2602,0.5248,0.7127,0.7374,0.548,0.2271,0.0377,0.0152,0.0374,-0.0072,0.0419,0.7803,0.2772],"deg_Mg_pH10":[0.712,4.2396,0.9996,0.1747,1.1575,1.0471,0.7494,0.1471,0.1808,0.242,0.1974,2.4576,2.1187,1.5409,0.4156,0.2857,0.2373,0.1608,0.047
3,0.2133,0.1952,0.0582,0.0223,0.2264,0.167,0.051,0.2164,0.2843,0.0308,0.1201,0.0582,0.0544,1.5511,1.8827,1.0921,0.617,0.2783,0.3367,0.0867,0.1096,0.1697,0.0581,0.0866,0.0523,0.0335,0.1483,0.1606,0.1344,0.0941,0.1342,0.1802,0.0468,0.1748,0.0442,0.0792,1.112,0.7844,0.8065,0.8327,0.5404,0.5097,0.2545,0.0775,0.3106,0.2497,0.278,0.362,1.045],"deg_pH10":[2.3831,5.385,1.4281,0.1975,1.3957,1.1176,0.8533,0.1538,0.1879,0.5337,0.1605,0.8779,1.0647,1.0386,0.7012,0.714,0.7048,0.0258,0.0211,0.0761,0.0889,0.021,0.0444,0.1558,0.217,0.1494,0.6119,0.1911,0.1477,0.0726,0.0185,0.0222,0.9114,1.3969,0.7886,0.8342,0.5404,0.6823,0.054,0.0596,0.081,0.0615,0.0526,0.0412,0.0131,0.1074,0.1244,0.0869,0.1393,0.1044,0.1463,0.1038,0.0436,0.0343,0.1111,0.8634,0.3376,0.8998,0.7099,0.5414,0.8286,0.267,0.0532,0.2501,0.1356,0.276,0.7102,1.9362],"deg_Mg_50C":[0.6751,4.3933,1.6426,0.2694,1.8023,1.3557,0.4751,0.0981,0.0543,0.1044,0.1166,1.4446,1.6431,1.4634,0.6241,0.4909,0.2331,0.0405,0.0148,0.0399,0.0355,0.0147,0.0251,0.1578,0.2427,0.1206,0.3671,0.1302,0.0291,0.0167,0.0456,0.0201,1.5332,1.951,1.1253,0.5821,0.3826,0.2864,0.0548,0.0236,0.0396,0.0214,0.0267,0.0128,0.0107,0.0616,0.1508,0.1051,0.0847,0.0114,0.0522,0.0009,0.0218,0.0,0.0182,1.0762,0.7187,1.035,1.2014,0.5732,0.7027,0.1796,0.0792,0.1617,0.1442,0.235,0.3647,0.5791],"deg_50C":[1.0915,3.7795,1.3767,0.3335,1.3792,0.7563,0.4283,0.1329,0.052,0.2067,0.1244,0.8698,0.7834,0.7217,0.7968,0.9452,0.5773,0.0347,0.0506,0.0815,0.1747,0.0502,0.0933,0.2459,0.3174,0.2917,0.6025,0.3061,0.0982,0.0406,0.1026,0.2047,0.9596,1.2014,0.3152,0.7039,0.7273,0.4594,0.0139,0.025,0.0175,0.0516,0.0412,0.0379,0.0206,0.2756,0.2029,0.2619,0.1538,0.0833,0.05,0.1426,0.1124,0.0099,0.1384,0.8286,0.561,0.8028,0.6205,0.5066,0.3916,0.1169,0.0374,0.1668,0.0611,0.3675,0.5446,0.8819]}
........
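
The JSON-lines format above can be inspected quickly in Python. A minimal sketch (the file path is an assumption based on the mounted dataset listing shown earlier; adjust it if your layout differs):

import json

# Read the first record of train.json (path assumed from the mount above)
with open("/home/aistudio/data/data60987/train.json") as f:
    rec = json.loads(f.readline())

print(sorted(rec.keys()))
# seq_length is the full sequence length; seq_scored is how many leading bases carry labels
print(len(rec["sequence"]), rec["seq_length"])    # e.g. 107 107
print(len(rec["reactivity"]), rec["seq_scored"])  # e.g. 68 68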

Install the graph learning framework PGL

!pip install pgl -q  # install PGL
# the main code files are in the ./src directory
%cd ./src

Import packages

import json
import random
import numpy as np
import pandas as pd


import matplotlib.pyplot as plt
import networkx as nx


from utils.config import prepare_config, make_dir
from utils.logger import prepare_logger, log_to_file
from data_parser import GraphParser


seed = 123
np.random.seed(seed)
random.seed(seed)

Load the training data

# https://www.kaggle.com/c/stanford-covid-vaccine/data


df = pd.read_json('../data/data60987/train.json', lines=True)
sample = df.loc[0]
print(sample)

Output:

SN_filter                                                              0
deg_50C                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_Mg_50C             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_Mg_pH10            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_error_50C          [191738.0886, 191738.0886, 191738.0886, 191738...
deg_error_Mg_50C       [171525.3217, 171525.3217, 171525.3217, 171525...
deg_error_Mg_pH10      [104235.1742, 104235.1742, 104235.1742, 104235...
deg_error_pH10         [222620.9531, 222620.9531, 222620.9531, 222620...
deg_pH10               [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
id                                                          id_2a7a4496f
index                                                                400
predicted_loop_type    EEEEESSSHHHSSSSSSSSSSSSSSSSSSSSSSSISSSSHHHHSSS...
reactivity             [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
reactivity_error       [146151.225, 146151.225, 146151.225, 146151.22...
seq_length                                                           107
seq_scored                                                            68
sequence               GGAAAGCCCGCGGCGCCGGGCGCCGCGGCCGCCCAGGCCGCCCGGC...
signal_to_noise                                                        0
structure              .....(((...)))((((((((((((((((((((.((((....)))...
Name: 0, dtype: object

This case predicts the degradation rate at each position of the RNA sequence. The training data provide several ground-truth vectors; the labels used here are reactivity, deg_Mg_pH10, and deg_Mg_50C. (The competition also provides deg_pH10 and deg_50C; as the inference code later shows, those two submission columns are simply filled with zeros.) A quick length check follows the list below.

  • reactivity - (1x68 vector in the training set, 1x91 in the test set) An array of floats with the same length as seq_scored, giving values for the first 68 bases in order, used to determine the likely secondary structure of the RNA sample.

  • deg_Mg_pH10 - (1x68 vector in the training set, 1x91 in the test set) An array of floats with the same length as seq_scored, giving values for the first 68 bases in order, used to determine the likelihood of degradation after incubating with magnesium at high pH (pH 10).

  • deg_Mg_50C - (1x68 vector in the training set, 1x91 in the test set) An array of floats with the same length as seq_scored, giving values for the first 68 bases in order, used to determine the likelihood of degradation after incubating with magnesium at high temperature (50 degrees Celsius).
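
A quick check on the record loaded above (sample = df.loc[0]) confirms that each of the three target vectors has length seq_scored:

# Sanity check: each target vector covers exactly the scored positions
for col in ["reactivity", "deg_Mg_pH10", "deg_Mg_50C"]:
    assert len(sample[col]) == sample["seq_scored"], col
print("all three target vectors have length", sample["seq_scored"])  # 68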

Parse the data

args = prepare_config("./config.yaml", isCreate=False, isSave=False)
parser = GraphParser(args)
gdata = parser.parse(sample)

Output

{'nfeat': array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 0., 0.],
        ...,
        [1., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.]], dtype=float32),
 'edges': array([[  0,   1],
        [  1,   0],
        [  1,   2],
        ...,
        [142, 105],
        [106, 142],
        [142, 106]]),
 'efeat': array([[ 0.,  0.,  0.,  1.,  1.],
        [ 0.,  0.,  0., -1.,  1.],
        [ 0.,  0.,  0.,  1.,  1.],
        ...,
        [ 0.,  1.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.,  0.]], dtype=float32),
 'labels': array([[ 0.    ,  0.    ,  0.    ],
        [ 0.    ,  0.    ,  0.    ],
        ...,
        [ 0.    ,  0.9213,  0.    ],
        [ 6.8894,  3.5097,  5.7754],
        [ 0.    ,  1.8426,  6.0642],
          ...,        
        [ 0.    ,  0.    ,  0.    ],
        [ 0.    ,  0.    ,  0.    ]], dtype=float32),
 'mask': array([[ True],
        [ True],
     ......
       [False]])}

Check the shape of each array

print(gdata['nfeat'].shape)
print(gdata['edges'].shape)
print(gdata['efeat'].shape)
print(gdata['labels'].shape)
print(gdata['mask'].shape)

Output

(143, 14)
(564, 2)
(564, 5)
(143, 3)
(143, 1)

Source code for parsing the training data:

class GraphParser(object):
    def __init__(self, config, mode="train"):
        self.config = config
        self.mode = mode


    def parse(self, sample):
        labels = []
        nfeat = []
        efeat = []
        edges = []
        train_mask = []
        test_mask = []


        sequence = sample['sequence']
        predicted_loop_type = sample['predicted_loop_type']
        seq_length = sample['seq_length']
        seq_scored = sample['seq_scored']


        pair_info = match_pair(sample['structure'])


        paired_nodes = {}
        for j in range(seq_length):
            add_base_node(nfeat, sequence[j], predicted_loop_type[j])


            if j + 1 < seq_length: # edge between current node and next node
                add_edges_between_base_nodes(edges, efeat, j, j + 1)


            if pair_info[j] != -1:
                if pair_info[j] not in paired_nodes:
                    paired_nodes[pair_info[j]] = [j]
                else:
                    paired_nodes[pair_info[j]].append(j)


            train_mask.append(j < seq_scored)
            test_mask.append(True)


        if self.config.add_edge_for_paired_nodes:
            for pair in paired_nodes.values():
                add_edges_between_paired_nodes(edges, efeat, pair[0], pair[1])


        if self.config.add_codon_nodes:
            codon_node_idx = seq_length - 1
            for j in range(seq_length):
                if j % 3 == 0:
                    # add codon node
                    add_codon_node(nfeat)
                    codon_node_idx += 1
                    train_mask.append(False)
                    test_mask.append(False)
                    if self.mode != "test":
                        labels.append([0, 0, 0])


                    if codon_node_idx > seq_length:
                        # add edges between adjacent codon nodes
                        add_edges_between_codon_nodes(
                                edges, efeat, codon_node_idx - 1, codon_node_idx)


                # add edges between codon node and base node
                add_edges_between_codon_and_base_node(
                        edges, efeat, j, codon_node_idx)


        if self.mode != 'test':
            react = sample['reactivity']
            deg_Mg_pH10 = sample['deg_Mg_pH10']
            deg_Mg_50C = sample['deg_Mg_50C']


            for j in range(seq_length):
                if j < seq_scored:
                    labels.append([react[j], deg_Mg_pH10[j], deg_Mg_50C[j] ])
                else:
                    labels.append([0, 0, 0])


        gdata = {}
        gdata['nfeat'] = np.array(nfeat, dtype="float32")
        gdata['edges'] = np.array(edges, dtype="int64")
        gdata['efeat'] = np.array(efeat, dtype="float32")
        if self.mode != "test":
            gdata['labels'] = np.array(labels, dtype="float32")
            gdata['mask'] = np.array(train_mask, dtype=bool).reshape(-1, 1)
        else:
            # fake labels
            gdata['labels'] = np.zeros((self.config.batch_size, self.config.num_class))
            gdata['mask'] = np.array(test_mask, dtype=bool).reshape(-1, 1)


        return gdata
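
The helper match_pair used above is not included in the excerpt. Judging by how GraphParser consumes its result (both ends of a base pair must map to the same dictionary key), it returns a shared pair ID per position, or -1 for unpaired bases. A hypothetical stack-based sketch, not the repository's actual code:

def match_pair(structure):
    """For each position of a dot-bracket string, return an ID shared by the two
    bases of a pair, or -1 if the base is unpaired. Hypothetical reimplementation."""
    pair_info = [-1] * len(structure)
    stack = []           # indices of unmatched '('
    next_pair_id = 0
    for i, ch in enumerate(structure):
        if ch == "(":
            stack.append(i)
        elif ch == ")":
            j = stack.pop()
            pair_info[i] = pair_info[j] = next_pair_id
            next_pair_id += 1
    return pair_info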


Visualize the data



fig = plt.figure(figsize=(24, 12))
nx_G = nx.Graph()
nx_G.add_nodes_from([i for i in range(len(gdata['nfeat']))])


nx_G.add_edges_from(gdata['edges'])
node_color = ['g' for _ in range(sample['seq_length'])] + \
['y' for _ in range(len(gdata['nfeat']) - sample['seq_length'])]
options = {
    "node_color": node_color,
}
pos = nx.spring_layout(nx_G, iterations=400, k=0.2)
nx.draw(nx_G, pos, **options)


plt.show()

Output (the rendered graph figure)

3. A Simplified GNN Model: Message Sending and Receiving

import paddle
import paddle.fluid as fluid




def copy_send(src_feat, dst_feat, edge_feat):
    """doc"""
    return src_feat["h"]




def mean_recv(feat):
    """doc"""
    return fluid.layers.sequence_pool(feat, pool_type="average")




def sum_recv(feat):
    """doc"""
    return fluid.layers.sequence_pool(feat, pool_type="sum")




def max_recv(feat):
    """doc"""
    return fluid.layers.sequence_pool(feat, pool_type="max")




def simple_gnn(gw, feature, hidden_size, act, name):
    """doc"""
    msg = gw.send(copy_send, nfeat_list=[("h", feature)])
    neigh_feature = gw.recv(msg, sum_recv)
    self_feature = feature
    output = self_feature + neigh_feature


    output = fluid.layers.fc(output,
                            hidden_size,
                            act=act,
                            param_attr=fluid.ParamAttr(name=name + '_w'),
                            bias_attr=fluid.ParamAttr(name=name + '_b'))


    return output
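
In equation form, each simple_gnn layer above computes, for every node v (this is my reading of the code, not a formula given in the original post):

$h_v^{\mathrm{out}} = \sigma\Big( W \big( h_v + \sum_{u \in \mathcal{N}(v)} h_u \big) + b \Big)$

where $\mathcal{N}(v)$ are the neighbors of v (messages are copied from source nodes and sum-pooled at the destination), and $\sigma$ is the activation passed in via act.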


4. The GNN Model Class
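
The class below references several names (propeller, L, F, BatchGraphWrapper, GNNlayers, paddle_helper) whose imports are not shown in the excerpt. Based on PGL 1.x / propeller conventions they likely resemble the following; the exact module paths are assumptions, not copied from the repository:

import paddle.fluid as F                          # fluid, aliased as F below
import paddle.fluid.layers as L                   # fluid layers, aliased as L
import propeller.paddle as propeller              # Paddle's propeller training framework
from pgl.graph_wrapper import BatchGraphWrapper   # builds a graph from batched tensors
from pgl.utils import paddle_helper               # masked_select may be project-specific
import layers as GNNlayers                        # the simple_gnn layer from section 3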



class GNNModel(propeller.train.Model):
    def __init__(self, hparam, mode, run_config):
        self.hparam = hparam
        self.mode = mode
        self.is_test = True if self.mode != propeller.RunMode.TRAIN else False
        self.run_config = run_config


    def forward(self, input_dict):
        gw = BatchGraphWrapper(input_dict['num_nodes'],
                               input_dict['num_edges'],
                               input_dict['edges'],
                               edge_feats={'efeat': input_dict['edge_feat']})


        feature = L.fc(input_dict['node_feat'], 
                    size=self.hparam.hidden_size,
                    act=None,
                    bias_attr=F.ParamAttr(name='embed_b'),
                    param_attr=F.ParamAttr(name="embed_w")
                    )


        for layer in range(self.hparam.num_layers):
            if layer == self.hparam.num_layers - 1:
                act = None
            else:
                act = 'leaky_relu'


            feature = GNNlayers.simple_gnn(
                    gw,
                    feature,
                    self.hparam.hidden_size,
                    act,
                    name="%s_%s" % (self.hparam.layer_type, layer))


        feature = L.dropout(
            feature,
            self.hparam.dropout_prob,
            dropout_implementation="upscale_in_train")


        logits = L.fc(feature, 
                size=self.hparam.num_class, 
                act=None,
                bias_attr=F.ParamAttr(name='final_b'),
                param_attr=F.ParamAttr(name="final_w"))


        mask = input_dict['mask']
        logits = paddle_helper.masked_select(logits, mask)


        return [logits, mask]


    def loss(self, predictions, label):
        logits = predictions[0]
        mask = predictions[1]
        label = paddle_helper.masked_select(label, mask)


        loss = L.mse_loss(input=logits, label=label)
        loss = L.reduce_mean(loss)


        return loss


    def backward(self, loss):
        optimizer = F.optimizer.Adam(learning_rate=self.hparam.lr)
        optimizer.minimize(loss)


    def metrics(self, predictions, label):
        result = {}
        logits = predictions[0]
        mask = predictions[1]
        label = paddle_helper.masked_select(label, mask)


        result["MCRMSE"] = propeller.metrics.MCRMSE(label, logits)


        return result




The evaluation metric used is MCRMSE (mean columnwise root mean squared error):
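
In the form used by the Kaggle competition (the propeller implementation referenced above is not shown in this excerpt):

$\mathrm{MCRMSE} = \frac{1}{N_t}\sum_{j=1}^{N_t}\sqrt{\frac{1}{n}\sum_{i=1}^{n}\big(y_{ij} - \hat{y}_{ij}\big)^2}$

where $N_t$ is the number of scored target columns (3 here: reactivity, deg_Mg_pH10, deg_Mg_50C) and $n$ is the number of scored positions.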

5. Training the GNN Model

......
def train(args):
    train_ds = CovidDataset(data_file=args.train_file, config=args, mode="train")
    valid_ds = CovidDataset(data_file=args.valid_file, config=args, mode="valid")


    log.info("train examples: %s" % len(train_ds))
    log.info("valid examples: %s" % len(train_ds))


    train_loader = Dataloader(train_ds, 
                              batch_size=args.batch_size,
                              shuffle=args.shuffle,
                              collate_fn=CollateFn())
    train_loader = multi_epoch_dataloader(train_loader, args.epochs)
    train_loader = PDataset.from_generator_func(train_loader)


    valid_loader = Dataloader(valid_ds, 
                            batch_size=1,
                            shuffle=False,
                            collate_fn=CollateFn())
    valid_loader = PDataset.from_generator_func(valid_loader)


    # warmup start setting
    ws = None
    propeller.train.train_and_eval(
            model_class_or_model_fn=GNNModel,
            params=args,
            run_config=args,
            train_dataset=train_loader,
            eval_dataset={"eval": valid_loader},
            warm_start_setting=ws,
            )




def infer(args):
    # predict for test data
    log.info('Reading %s' % args.test_file)
    test_ds = CovidDataset(args.test_file, args, 'test')
    test_loader = Dataloader(test_ds,
                             batch_size=args.batch_size,
                             shuffle=False,
                             collate_fn=CollateFn(mode='test'))
    test_loader = PDataset.from_generator_func(test_loader)


    est = propeller.Learner(GNNModel, args, args)


    output_path = args.model_path_for_infer.replace("checkpoints/", "outputs/")
    make_dir(output_path)
    filename = os.path.join(output_path, "submission.csv")


    id_seqpos = build_id_seqpos(args.test_file)


    preds = []
    for predictions in est.predict(test_loader,
                                   ckpt_path=args.model_path_for_infer, 
                                   split_batch=False):
        preds.append(predictions[0])


    preds = np.concatenate(preds)
    df_sub = pd.DataFrame({'id_seqpos': id_seqpos,
                           'reactivity': preds[:,0],
                           'deg_Mg_pH10': preds[:,1],
                           'deg_pH10': 0,
                           'deg_Mg_50C': preds[:,2],
                           'deg_50C': 0})
    log.info("saving result to %s" % filename)
    df_sub.to_csv(filename, index=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='gnn')
    parser.add_argument("--config", type=str, default="./config.yaml")
    parser.add_argument("--mode", type=str, default="train")
    args = parser.parse_args()


    if args.mode == "infer":
        config = prepare_config(args.config, isCreate=False, isSave=False)
        infer(config)
    else:
        config = prepare_config(args.config, isCreate=True, isSave=True)
        log_to_file(log, config.log_dir)
        train(config)

Script to run the GNN training:

python main.py

Output:

........
********** Start Loop ************
train loop has hook 
train loop has hook 
train loop has hook 
train loop has hook 
[training]  step: 20  steps/sec: -1.00000  loss: 305.70929  
[training]  step: 40  steps/sec: 141.66859  loss: 103.75951  
[training]  step: 60  steps/sec: 139.14825  loss: 44.02296  
END: epoch 0 ...
BEGIN: epoch 1 ...
[training]  step: 80  steps/sec: 145.89391  loss: 35.04534  
[training]  step: 100  steps/sec: 145.03201  loss: 24.50893  
[training]  step: 120  steps/sec: 143.29997  loss: 22.46627  
[training]  step: 140  steps/sec: 150.44456  loss: 17.93622  
[training]  step: 160  steps/sec: 148.98990  loss: 15.78190  
[training]  step: 180  steps/sec: 144.50660  loss: 14.56967  
END: epoch 1 ...
(240,)
{'MCRMSE': 0.54948246, 'loss': 0.30230046494398266}
write to tensorboard ../checkpoints/covid19/eval_history/eval
write to tensorboard ../checkpoints/covid19/eval_history/eval
[Eval:eval]:MCRMSE:0.5494824647903442  loss:0.30230046494398266
[training]  step: 24980  steps/sec: 5.72037  loss: 0.37832  
********** Stop Loop ************
saving step 25000 to ../checkpoints/covid19/model_25000

6. GNN Model Inference

Run the script:

python main.py --mode infer

Output

aistudio@jupyter-112853-1529805:~/src$ python main.py --mode infer
[WARNING] 2021-02-15 20:46:38,028 [ __init__.py:   41]: enabling old_styled_ckpt
[DEBUG] 2021-02-15 20:46:38,029 [distribution.py:  130]:        no PROPELLER_DISCONFIG found, try paddlestype setting
[DEBUG] 2021-02-15 20:46:38,030 [distribution.py:  133]:        no paddle stype setting found
[WARNING] 2021-02-15 20:46:38,031 [monitored_executor.py:  392]:        textone not found in ['/home/aistudio/src', '/opt/conda/envs/python35-paddle120-env/lib/python37.zip', '/opt/conda/envs/python35-paddle120-env/lib/python3.7', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/lib-dynload', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages', '../']! will not load encrepted model
[INFO] 2021-02-15 20:46:38,041 [     main.py:   85]:    Reading ../data/data60987/valid.json
[INFO] 2021-02-15 20:46:38,488 [  trainer.py:  220]:    Building Predict Graph
[INFO] 2021-02-15 20:46:38,488 [functional.py:  417]:   Try to infer data shapes & types from generator
[INFO] 2021-02-15 20:46:38,489 [functional.py:  434]:   Dataset `predict` has data_shapes: [[-1, 5], [-1, 2], [-1], [-1, 14], [-1], [-1]] data_types: ['float32', 'int64', 'bool', 'float32', 'int64', 'int64']
[INFO] 2021-02-15 20:46:38,541 [  trainer.py:  225]:    Done
[INFO] 2021-02-15 20:46:38,542 [  trainer.py:  240]:    Predict with: 
> Run_config: {'task_name': 'covid19', 'use_cuda': False, 'warm_start_from': '', 'model_path_for_infer': '../checkpoints/covid19/model_810', 'train_file': '../data/data60987/train.json', 'valid_file': '../data/data60987/valid.json', 'test_file': '../data/data60987/valid.json', 'percentage': 0.9, 'add_edge_for_paired_nodes': True, 'add_codon_nodes': True, 'num_layers': 5, 'layer_type': 'simple_gnn', 'emb_size': 64, 'hidden_size': 64, 'num_class': 3, 'dropout_prob': 0.1, 'epochs': 200, 'batch_size': 16, 'lr': 0.001, 'shuffle': True, 'save_steps': 200000000, 'log_steps': 20, 'max_ckpt': 8, 'skip_steps': 0, 'eval_steps': 320, 'eval_max_steps': 10000, 'stdout': True, 'log_dir': '../logs', 'log_filename': 'log.txt', 'save_dir': '../checkpoints', 'output_dir': '../outputs', 'files2saved': ['layers.py', 'data_parser.py', 'config.yaml', 'main.py', 'dataset.py', 'model.py'], 'model_dir': '../checkpoints'}
> Params: {'task_name': 'covid19', 'use_cuda': False, 'warm_start_from': '', 'model_path_for_infer': '../checkpoints/covid19/model_810', 'train_file': '../data/data60987/train.json', 'valid_file': '../data/data60987/valid.json', 'test_file': '../data/data60987/valid.json', 'percentage': 0.9, 'add_edge_for_paired_nodes': True, 'add_codon_nodes': True, 'num_layers': 5, 'layer_type': 'simple_gnn', 'emb_size': 64, 'hidden_size': 64, 'num_class': 3, 'dropout_prob': 0.1, 'epochs': 200, 'batch_size': 16, 'lr': 0.001, 'shuffle': True, 'save_steps': 200000000, 'log_steps': 20, 'max_ckpt': 8, 'skip_steps': 0, 'eval_steps': 320, 'eval_max_steps': 10000, 'stdout': True, 'log_dir': '../logs', 'log_filename': 'log.txt', 'save_dir': '../checkpoints', 'output_dir': '../outputs', 'files2saved': ['layers.py', 'data_parser.py', 'config.yaml', 'main.py', 'dataset.py', 'model.py'], 'model_dir': '../checkpoints'}
> Train_model_spec: ModelSpec(loss=None, predictions=[name: "gather_7.tmp_0"
......
[INFO] 2021-02-15 20:46:41,787 [monitored_executor.py:  540]:   propeller runs in CUDA mode
[INFO] 2021-02-15 20:46:41,787 [monitored_executor.py:  547]:   ********** Start Loop ************
[INFO] 2021-02-15 20:46:41,787 [  trainer.py:  390]:    Runining predict from dir: {'gstep': 0, 'step': 0, 'time': 1613393198.543097}
[INFO] 2021-02-15 20:46:42,180 [monitored_executor.py:  606]:   ********** Stop Loop ************
[INFO] 2021-02-15 20:46:42,195 [     main.py:  114]:    saving result to ../outputs/covid19/model_810/submission.csv
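
After inference, the generated submission should contain the five columns Kaggle expects, with deg_pH10 and deg_50C zero-filled as in infer() above. A quick way to verify, assuming the output path printed in the log:

import pandas as pd

sub = pd.read_csv("../outputs/covid19/model_810/submission.csv")
print(sub.columns.tolist())
# Expected, per infer() above:
# ['id_seqpos', 'reactivity', 'deg_Mg_pH10', 'deg_pH10', 'deg_Mg_50C', 'deg_50C']
print(sub.head())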

Note: The text and figures in this article come from AIStudio (an AI learning and training community).
