Kaggle COVID-19 mRNA Vaccine Development Competition
https://www.kaggle.com/c/stanford-covid-vaccine/overview
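mRNA vaccines have become the fastest-moving vaccine candidates against COVID-19, but they currently face key potential limitations. One of the biggest challenges right now is designing super-stable messenger RNA (mRNA) molecules. Conventional vaccines are packaged in syringes and shipped under refrigeration around the world, but that is not yet possible for mRNA vaccines.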
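Researchers have observed that RNA molecules tend to degrade. This is a serious limitation, because degradation can render an mRNA vaccine useless. Little is currently known about which parts of a given RNA's backbone are most prone to being affected. Without this knowledge, current mRNA vaccines against COVID-19 must be prepared and shipped under intense refrigeration, and unless they can be stabilized they are unlikely to reach everyone on the planet.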
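The Eterna community, led by Professor Rhiju Das, a computational biologist at Stanford's School of Medicine, brings together scientists and gamers to solve puzzles and invent medicine. Eterna is an online game platform that challenges players to solve scientific problems such as mRNA design through puzzles. Player designs are synthesized and experimentally tested by researchers at Stanford University to gain new insights into RNA molecules. The Eterna community has previously unlocked new scientific principles, created new diagnostics for deadly diseases, and harnessed the world's most powerful intellectual resources to improve public life. It has advanced biotechnology through its contributions to more than 20 publications, including advances in RNA biotechnology.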
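In this competition, the organizers want to leverage the data science expertise of the Kaggle community to develop models and design rules for RNA degradation. A model predicts the likely degradation rate at each base of an RNA molecule. It is trained on a subset of an Eterna dataset of more than 3000 RNA molecules (spanning a wide range of sequences and structures) together with their degradation rates at each position. Models are then scored on a second generation of RNA sequences that Eterna players have just designed for COVID-19 mRNA vaccines. These final test sequences are being synthesized and experimentally characterized at Stanford University in parallel with the modeling work, so nature itself will score the models!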
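Improving the stability of mRNA vaccines was already being explored, but this deep scientific challenge must now be solved to accelerate mRNA vaccine research and deliver a refrigerator-stable vaccine against SARS-CoV-2, the virus behind COVID-19. The problem has so far eluded academic labs, industry R&D teams, and supercomputers; you can join the team of gamers, scientists, and developers on Eterna to fight this devastating virus.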
1: Case Overview
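Genetic instructions for a viral protein are delivered into cells (for an mRNA vaccine, in the form of mRNA). The cells use this messenger RNA (mRNA) to assemble the protein; once the immune system detects the assembled viral protein, it is activated to produce antibodies, strengthening the body's defenses against the coronavirus.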
Different mRNAs can produce the same protein
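Different mRNAs can yield the same protein because the genetic code is redundant: several codons (three-base words) map to the same amino acid. Below is a minimal sketch that uses only a hand-picked subset of the standard codon table. `CODON_TABLE` and `translate` are illustrative names and are not part of the project code.

```python
# A small subset of the standard genetic code (RNA codons -> one-letter amino acids).
# Only the codons used in this example are listed; "*" marks a stop codon.
CODON_TABLE = {
    "AUG": "M",                                      # methionine (start)
    "GCU": "A", "GCC": "A", "GCA": "A", "GCG": "A",  # alanine
    "UUU": "F", "UUC": "F",                          # phenylalanine
    "UAA": "*", "UAG": "*", "UGA": "*",              # stop codons
}


def translate(mrna: str) -> str:
    """Translate an mRNA string codon by codon until the first stop codon."""
    protein = []
    for i in range(0, len(mrna) - 2, 3):
        amino_acid = CODON_TABLE[mrna[i:i + 3]]
        if amino_acid == "*":
            break
        protein.append(amino_acid)
    return "".join(protein)


# Two different mRNA sequences that differ only in synonymous codons...
seq_a = "AUGGCUUUUUAA"
seq_b = "AUGGCGUUCUGA"
# ...encode the same peptide.
print(translate(seq_a), translate(seq_b))  # MAF MAF
```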
mRNA degrades over time and as the temperature changes
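How can we find mRNAs with more stable structures? A graph neural network is used to identify more stable mRNAs; in the figure, darker colors indicate greater stability.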
2: Hands-On COVID-19 Vaccine Project (Advanced Practice)
Code file structure:
Data distribution characteristics
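List the currently mounted dataset directory
```
# List the mounted dataset directory; changes made under it are reverted automatically when the environment restarts
# The dataset here is named: data60987
!ls /home/aistudio/data
```
Output: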
data60987
Inspect the dataset file train.json. Each line is one JSON record in the following format:
{"index":401,"id":"id_2a983d026","sequence":"GGAAAAAGGCUCAAAAACUGUACGAAGGUACAGAAAAACCAUAGCGAAAGCUAUGGAAAAAGAGCCAACUACUGGUUCGCCAGUAGAAAAGAAACAACAACAACAAC","structure":".......(((((.....((((((....)))))).....(((((((....))))))).....)))))..(((((((....))))))).....................","predicted_loop_type":"EEEEEEESSSSSMMMMMSSSSSSHHHHSSSSSSMMMMMSSSSSSSHHHHSSSSSSSMMMMMSSSSSXXSSSSSSSHHHHSSSSSSSEEEEEEEEEEEEEEEEEEEEE","signal_to_noise":8.157,"SN_filter":1.0,"seq_length":107,"seq_scored":68,"reactivity_error":[0.1423,0.2177,0.139,0.0994,0.1153,0.0995,0.0582,0.0237,0.0226,0.0263,0.0235,0.0692,0.1025,0.0635,0.0713,0.0749,0.0542,0.0218,0.0075,0.0208,0.0213,0.018,0.024,0.0736,0.0713,0.0391,0.0696,0.0423,0.0273,0.0198,0.0203,0.0093,0.0508,0.0871,0.0622,0.0625,0.0623,0.0473,0.0159,0.0217,0.0155,0.0119,0.0145,0.0128,0.0133,0.0319,0.0558,0.0359,0.0346,0.0085,0.0096,0.0161,0.0129,0.0113,0.0137,0.0434,0.0588,0.0595,0.0624,0.0525,0.0378,0.0177,0.0141,0.016,0.0094,0.0228,0.0578,0.0383],"deg_error_Mg_pH10":[0.1878,0.3274,0.1631,0.0812,0.1629,0.1502,0.1275,0.0633,0.0685,0.0775,0.0695,0.1889,0.1619,0.1326,0.0742,0.0613,0.0547,0.0485,0.0289,0.0515,0.0497,0.0304,0.0233,0.0528,0.0466,0.03,0.0513,0.0561,0.0261,0.0387,0.0316,0.0289,0.112,0.1137,0.0846,0.0631,0.0433,0.0464,0.0265,0.0315,0.0346,0.0218,0.0254,0.0223,0.0176,0.0327,0.0335,0.0297,0.0262,0.03,0.0331,0.0201,0.0329,0.0186,0.0232,0.073,0.0625,0.0585,0.0593,0.0471,0.0453,0.0317,0.0195,0.0337,0.0311,0.0333,0.036,0.0562],"deg_error_pH10":[0.232,0.3104,0.1631,0.0778,0.1532,0.1399,0.1284,0.0564,0.0634,0.099,0.0594,0.1322,0.1365,0.136,0.1118,0.1049,0.0986,0.0473,0.0267,0.0433,0.0478,0.0266,0.0375,0.0597,0.0657,0.0551,0.0952,0.0624,0.0561,0.0417,0.0404,0.0317,0.1204,0.1383,0.1066,0.1015,0.0807,0.0884,0.0359,0.0497,0.0424,0.033,0.0313,0.0364,0.021,0.0476,0.0495,0.037,0.047,0.0428,0.0448,0.0425,0.0335,0.0269,0.0401,0.1032,0.0864,0.0977,0.0974,0.0821,0.0959,0.0556,0.033,0.0517,0.0453,0.0626,0.0841,0.1283],"deg_error_Mg_50C":[0.1342,0.2586,0.1547,0.0724,
0.1516,0.1302,0.0857,0.0411,0.0349,0.0471,0.0448,0.13,0.1312,0.1216,0.0827,0.0713,0.0502,0.0332,0.0184,0.0269,0.0275,0.0183,0.0237,0.045,0.0522,0.0391,0.0611,0.0413,0.0269,0.021,0.0308,0.0218,0.1118,0.1188,0.0898,0.0648,0.0521,0.0458,0.0247,0.0272,0.0238,0.0166,0.0178,0.019,0.0136,0.0278,0.0366,0.0291,0.0282,0.0167,0.0221,0.0135,0.0189,0.0067,0.0156,0.0818,0.0718,0.0752,0.0815,0.0573,0.0617,0.0326,0.024,0.0299,0.0305,0.0389,0.0441,0.054],"deg_error_50C":[0.1858,0.2902,0.1741,0.0976,0.1655,0.1298,0.1092,0.0595,0.0464,0.0776,0.0601,0.1411,0.1319,0.1292,0.1263,0.1279,0.0998,0.0498,0.0386,0.0481,0.0635,0.0383,0.0499,0.0737,0.0802,0.0752,0.1019,0.0777,0.0529,0.0381,0.055,0.0631,0.1288,0.138,0.0833,0.1019,0.0992,0.081,0.0284,0.045,0.0326,0.0341,0.0316,0.0371,0.0257,0.0677,0.0606,0.0618,0.0519,0.0423,0.033,0.0504,0.0463,0.021,0.0474,0.107,0.0997,0.099,0.0964,0.0838,0.0769,0.0439,0.0315,0.0475,0.0379,0.0719,0.0805,0.099],"reactivity":[1.123,3.8721,1.713,0.8734,1.3266,0.9945,0.2319,0.0312,0.0196,0.0122,0.0234,0.3576,1.1503,0.31,0.5168,0.6628,0.3396,-0.0029,0.0,0.0221,0.0184,0.0193,0.0381,0.6968,0.676,0.1654,0.6669,0.2018,0.0571,0.0247,0.0079,-0.0037,0.228,1.1223,0.5402,0.6254,0.6763,0.3724,0.0151,0.0073,0.0068,0.0094,0.0187,-0.0026,0.014,0.1307,0.5515,0.219,0.1912,-0.0036,0.0044,0.0185,0.0052,0.0088,0.0177,0.2602,0.5248,0.7127,0.7374,0.548,0.2271,0.0377,0.0152,0.0374,-0.0072,0.0419,0.7803,0.2772],"deg_Mg_pH10":[0.712,4.2396,0.9996,0.1747,1.1575,1.0471,0.7494,0.1471,0.1808,0.242,0.1974,2.4576,2.1187,1.5409,0.4156,0.2857,0.2373,0.1608,0.0473,0.2133,0.1952,0.0582,0.0223,0.2264,0.167,0.051,0.2164,0.2843,0.0308,0.1201,0.0582,0.0544,1.5511,1.8827,1.0921,0.617,0.2783,0.3367,0.0867,0.1096,0.1697,0.0581,0.0866,0.0523,0.0335,0.1483,0.1606,0.1344,0.0941,0.1342,0.1802,0.0468,0.1748,0.0442,0.0792,1.112,0.7844,0.8065,0.8327,0.5404,0.5097,0.2545,0.0775,0.3106,0.2497,0.278,0.362,1.045],"deg_pH10":[2.3831,5.385,1.4281,0.1975,1.3957,1.1176,0.8533,0.1538,0.1879,0.5337,0.1605,0.8779,1.0647,1.0386,0.7012,0.714,0.7048,0.0258,0.0211,0.0761,0.0889,0.021,0.0444,0.1558,0.217,0.1494,0.6119,0.1911,0.1477,0.0726,0.0185,0.0222,0.9114,1.3969,0.7886,0.8342,0.5404,0.6823,0.054,0.0596,0.081,0.0615,0.0526,0.0412,0.0131,0.1074,0.1244,0.0869,0.1393,0.1044,0.1463,0.1038,0.0436,0.0343,0.1111,0.8634,0.3376,0.8998,0.7099,0.5414,0.8286,0.267,0.0532,0.2501,0.1356,0.276,0.7102,1.9362],"deg_Mg_50C":[0.6751,4.3933,1.6426,0.2694,1.8023,1.3557,0.4751,0.0981,0.0543,0.1044,0.1166,1.4446,1.6431,1.4634,0.6241,0.4909,0.2331,0.0405,0.0148,0.0399,0.0355,0.0147,0.0251,0.1578,0.2427,0.1206,0.3671,0.1302,0.0291,0.0167,0.0456,0.0201,1.5332,1.951,1.1253,0.5821,0.3826,0.2864,0.0548,0.0236,0.0396,0.0214,0.0267,0.0128,0.0107,0.0616,0.1508,0.1051,0.0847,0.0114,0.0522,0.0009,0.0218,0.0,0.0182,1.0762,0.7187,1.035,1.2014,0.5732,0.7027,0.1796,0.0792,0.1617,0.1442,0.235,0.3647,0.5791],"deg_50C":[1.0915,3.7795,1.3767,0.3335,1.3792,0.7563,0.4283,0.1329,0.052,0.2067,0.1244,0.8698,0.7834,0.7217,0.7968,0.9452,0.5773,0.0347,0.0506,0.0815,0.1747,0.0502,0.0933,0.2459,0.3174,0.2917,0.6025,0.3061,0.0982,0.0406,0.1026,0.2047,0.9596,1.2014,0.3152,0.7039,0.7273,0.4594,0.0139,0.025,0.0175,0.0516,0.0412,0.0379,0.0206,0.2756,0.2029,0.2619,0.1538,0.0833,0.05,0.1426,0.1124,0.0099,0.1384,0.8286,0.561,0.8028,0.6205,0.5066,0.3916,0.1169,0.0374,0.1668,0.0611,0.3675,0.5446,0.8819]}
........
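Each line of train.json is a self-contained JSON object. Fields such as sequence, structure and predicted_loop_type describe every base. The measured arrays cover only the first seq_scored bases. Below is a minimal sketch that parses one line with the standard `json` module and checks those length relationships. It uses a made-up, shortened toy record rather than real data, and `check_record` is a hypothetical helper.

```python
import json

# A made-up, shortened record with the same fields as a train.json line
# (the record shown above has seq_length 107 and seq_scored 68).
line = json.dumps({
    "id": "id_toy",
    "sequence": "GGAAAC",
    "structure": "((..))",
    "predicted_loop_type": "SSHHSS",
    "seq_length": 6,
    "seq_scored": 4,
    "reactivity": [0.1, 0.5, 0.9, 0.2],
    "deg_Mg_pH10": [0.3, 0.4, 1.1, 0.2],
    "deg_Mg_50C": [0.2, 0.6, 0.8, 0.1],
})


def check_record(record):
    """Raise ValueError if the per-base fields have inconsistent lengths."""
    n = record["seq_length"]
    # sequence, structure and loop type describe every base
    for key in ("sequence", "structure", "predicted_loop_type"):
        if len(record[key]) != n:
            raise ValueError(f"{key}: expected {n} characters")
    # the targets only cover the first seq_scored bases
    for key in ("reactivity", "deg_Mg_pH10", "deg_Mg_50C"):
        if len(record[key]) != record["seq_scored"]:
            raise ValueError(f"{key}: expected {record['seq_scored']} values")


record = json.loads(line)  # one line of the file -> one dict
check_record(record)
print(record["id"], record["seq_length"], record["seq_scored"])  # id_toy 6 4
```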
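Install the graph learning framework PGL (Paddle Graph Learning)
```
!pip install pgl -q  # install PGL
# the main code files are in the ./src directory
%cd ./src
```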
Import packages
```python
import json
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx

# project modules under ./src
from utils.config import prepare_config, make_dir
from utils.logger import prepare_logger, log_to_file
from data_parser import GraphParser

# fix the random seeds for reproducibility
seed = 123
np.random.seed(seed)
random.seed(seed)
```
Load the training data
```python
# https://www.kaggle.com/c/stanford-covid-vaccine/data
df = pd.read_json('../data/data60987/train.json', lines=True)  # one JSON record per line
sample = df.loc[0]  # take the first record as an example
print(sample)
```
Output:
SN_filter 0
deg_50C [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_Mg_50C [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_Mg_pH10 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
deg_error_50C [191738.0886, 191738.0886, 191738.0886, 191738...
deg_error_Mg_50C [171525.3217, 171525.3217, 171525.3217, 171525...
deg_error_Mg_pH10 [104235.1742, 104235.1742, 104235.1742, 104235...
deg_error_pH10 [222620.9531, 222620.9531, 222620.9531, 222620...
deg_pH10 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
id id_2a7a4496f
index 400
predicted_loop_type EEEEESSSHHHSSSSSSSSSSSSSSSSSSSSSSSISSSSHHHHSSS...
reactivity [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
reactivity_error [146151.225, 146151.225, 146151.225, 146151.22...
seq_length 107
seq_scored 68
sequence GGAAAGCCCGCGGCGCCGGGCGCCGCGGCCGCCCAGGCCGCCCGGC...
signal_to_noise 0
structure .....(((...)))((((((((((((((((((((.((((....)))...
Name: 0, dtype: object
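The task is to predict the degradation rate at each position of an RNA sequence. The training data provides several ground-truth values; the labels used here are reactivity, deg_Mg_pH10, and deg_Mg_50C: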
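reactivity - (1x68 vector in the training set, 1x91 in the test set) An array of floats with the same length as seq_scored, holding the reactivity values of the first seq_scored bases (68 in the training set) in sequence order. It is used to determine the likely secondary structure of the RNA sample.
deg_Mg_pH10 - (1x68 vector in the training set, 1x91 in the test set) An array of floats with the same length as seq_scored, holding values for the first seq_scored bases in sequence order. It is used to determine the likelihood of degradation after incubation with magnesium at high pH (pH 10).
deg_Mg_50C - (1x68 vector in the training set, 1x91 in the test set) An array of floats with the same length as seq_scored, holding values for the first seq_scored bases in sequence order. It is used to determine the likelihood of degradation after incubation with magnesium at high temperature (50 °C).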
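Every sequence has seq_length bases, but only the first seq_scored of them carry the three labels. Below is a minimal NumPy sketch that turns one record into a per-base label matrix plus a mask over the scored positions. `build_labels` and the toy record are hypothetical; in the project, GraphParser.parse() does the equivalent work.

```python
from typing import Tuple

import numpy as np

TARGETS = ("reactivity", "deg_Mg_pH10", "deg_Mg_50C")


def build_labels(record: dict) -> Tuple[np.ndarray, np.ndarray]:
    """Return a (seq_length, 3) label matrix and a (seq_length,) boolean mask.

    Only the first seq_scored bases have measurements; later positions are
    zero-filled and masked out so they do not contribute to the loss.
    """
    n, scored = record["seq_length"], record["seq_scored"]
    labels = np.zeros((n, len(TARGETS)), dtype=np.float32)
    for col, key in enumerate(TARGETS):
        labels[:scored, col] = record[key]
    mask = np.arange(n) < scored
    return labels, mask


# Toy record: 5 bases, of which the first 3 are scored.
toy = {
    "seq_length": 5, "seq_scored": 3,
    "reactivity": [0.5, 1.0, 0.0],
    "deg_Mg_pH10": [0.2, 0.4, 0.6],
    "deg_Mg_50C": [0.1, 0.3, 0.9],
}
labels, mask = build_labels(toy)
print(labels.shape, mask.tolist())  # (5, 3) [True, True, True, False, False]
```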
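Parse the data (convert one sample into a graph)
```python
args = prepare_config("./config.yaml", isCreate=False, isSave=False)
parser = GraphParser(args)
gdata = parser.parse(sample)
gdata  # display the resulting dict in the notebook
```
Output: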
{'nfeat': array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 1., 0., ..., 0., 0., 0.],
...,
[1., 0., 0., ..., 0., 0., 0.],
[1., 0., 0., ..., 0., 0., 0.],
[1., 0., 0., ..., 0., 0., 0.]], dtype=float32),
'edges': array([[ 0, 1],
[ 1, 0],
[ 1, 2],
...,
[142, 105],
[106, 142],
[142, 106]]),
'efeat': array([[ 0., 0., 0., 1., 1.],
[ 0., 0., 0., -1., 1.],
[ 0., 0., 0., 1., 1.],
...,
[ 0., 1., 0., 0., 0.],
[ 0., 1., 0., 0., 0.],
[ 0., 1., 0., 0., 0.]], dtype=float32),
'labels': array([[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ],
...,
[ 0. , 0.9213, 0. ],
[ 6.8894, 3.5097, 5.7754],
[ 0. , 1.8426, 6.0642],
...,
[ 0. , 0. , 0. ],
[ 0. , 0. , 0. ]], dtype=float32),
'mask': array([[ True],
[ True],
......
[False]])}
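Check the shape of each array
```python
print(gdata['nfeat'].shape)   # node features: (num_nodes, feature_dim)
print(gdata['edges'].shape)   # edges: (num_edges, 2) as (src, dst) pairs
print(gdata['efeat'].shape)   # edge features: (num_edges, edge_feature_dim)
print(gdata['labels'].shape)  # labels: (num_nodes, 3)
print(gdata['mask'].shape)    # mask: (num_nodes, 1)
```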
Output:
(143, 14)
(564, 2)
(564, 5)
(143, 3)
(143, 1)
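These shapes can be checked by hand. With add_codon_nodes enabled, the parser creates one codon node for every j with j % 3 == 0, which gives ceil(107 / 3) = 36 codon nodes and 107 + 36 = 143 nodes in total. The sketch below also counts directed edges, under two assumptions read off the printed edge list. First, each edge is stored in both directions, as [0, 1] and [1, 0] are. Second, every base links to its codon node, as [106, 142] and [142, 106] suggest. If both hold, the 564 edges imply 34 base pairs in this sample. The function names are illustrative.

```python
import math

SEQ_LENGTH = 107  # bases in the parsed sample


def codon_node_of(j, seq_length=SEQ_LENGTH):
    """Index of the codon node that base j links to (codon nodes follow the bases)."""
    return seq_length + j // 3


def directed_edges(seq_length, num_pairs):
    """Directed edge count, assuming each undirected edge is stored in both directions."""
    num_codons = math.ceil(seq_length / 3)
    undirected = (
        (seq_length - 1)       # backbone: base j <-> base j + 1
        + num_pairs            # base pairs from the dot-bracket structure
        + (num_codons - 1)     # adjacent codon nodes
        + seq_length           # every base <-> its codon node
    )
    return 2 * undirected


num_nodes = SEQ_LENGTH + math.ceil(SEQ_LENGTH / 3)   # 107 base + 36 codon nodes
num_pairs = (564 - directed_edges(SEQ_LENGTH, 0)) // 2
print(num_nodes, codon_node_of(106), num_pairs)  # 143 142 34
```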
Source code of the training-data parser (GraphParser):
```python
import numpy as np

# Helpers such as match_pair, add_base_node, add_codon_node and the
# add_edges_* functions belong to the project but are not shown in this article.


class GraphParser(object):
    """Turn one RNA record into a graph: one node per base plus optional codon nodes."""

    def __init__(self, config, mode="train"):
        self.config = config
        self.mode = mode

    def parse(self, sample):
        labels = []
        nfeat = []
        efeat = []
        edges = []
        train_mask = []
        test_mask = []

        sequence = sample['sequence']
        predicted_loop_type = sample['predicted_loop_type']
        seq_length = sample['seq_length']
        seq_scored = sample['seq_scored']
        pair_info = match_pair(sample['structure'])

        # base nodes, backbone edges, and bookkeeping for base pairs
        paired_nodes = {}
        for j in range(seq_length):
            add_base_node(nfeat, sequence[j], predicted_loop_type[j])
            if j + 1 < seq_length:  # edge between current node and next node
                add_edges_between_base_nodes(edges, efeat, j, j + 1)
            if pair_info[j] != -1:
                if pair_info[j] not in paired_nodes:
                    paired_nodes[pair_info[j]] = [j]
                else:
                    paired_nodes[pair_info[j]].append(j)
            train_mask.append(j < seq_scored)  # only the first seq_scored bases are labeled
            test_mask.append(True)

        # labels for the base nodes, built before any codon node is added
        # so that row i of `labels` matches node i of `nfeat`
        if self.mode != 'test':
            react = sample['reactivity']
            deg_Mg_pH10 = sample['deg_Mg_pH10']
            deg_Mg_50C = sample['deg_Mg_50C']
            for j in range(seq_length):
                if j < seq_scored:
                    labels.append([react[j], deg_Mg_pH10[j], deg_Mg_50C[j]])
                else:
                    labels.append([0, 0, 0])

        # edges between paired bases
        if self.config.add_edge_for_paired_nodes:
            for pair in paired_nodes.values():
                add_edges_between_paired_nodes(edges, efeat, pair[0], pair[1])

        # one codon node per three bases, chained together and linked to their bases
        if self.config.add_codon_nodes:
            codon_node_idx = seq_length - 1
            for j in range(seq_length):
                if j % 3 == 0:
                    # add codon node
                    add_codon_node(nfeat)
                    codon_node_idx += 1
                    train_mask.append(False)
                    test_mask.append(False)
                    if self.mode != "test":
                        labels.append([0, 0, 0])
                    if codon_node_idx > seq_length:
                        # add edges between adjacent codon nodes
                        add_edges_between_codon_nodes(
                            edges, efeat, codon_node_idx - 1, codon_node_idx)
                # add edges between codon node and base node
                add_edges_between_codon_and_base_node(
                    edges, efeat, j, codon_node_idx)

        gdata = {}
        gdata['nfeat'] = np.array(nfeat, dtype="float32")
        gdata['edges'] = np.array(edges, dtype="int64")
        gdata['efeat'] = np.array(efeat, dtype="float32")
        if self.mode != "test":
            gdata['labels'] = np.array(labels, dtype="float32")
            gdata['mask'] = np.array(train_mask, dtype=bool).reshape(-1, 1)
        else:
            # fake labels
            gdata['labels'] = np.zeros((self.config.batch_size, self.config.num_class))
            gdata['mask'] = np.array(test_mask, dtype=bool).reshape(-1, 1)
        return gdata
```
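`match_pair` itself is not shown. parse() needs both ends of a base pair to land under the same key in paired_nodes, because otherwise `pair[1]` would not exist. One implementation consistent with that is a stack-based dot-bracket parser that gives both ends of a pair the same id. The sketch below is hypothetical (`match_pair_sketch`) and is not the project's code. It uses the index of the opening bracket as the shared id.

```python
from typing import List


def match_pair_sketch(structure: str) -> List[int]:
    """Hypothetical stand-in for match_pair.

    Returns, for every position, the index of the '(' that opens its base pair,
    so both ends of a pair share the same key; unpaired '.' positions get -1.
    """
    pair_id = [-1] * len(structure)
    stack = []  # indices of '(' still waiting for their ')'
    for j, ch in enumerate(structure):
        if ch == "(":
            stack.append(j)
            pair_id[j] = j
        elif ch == ")":
            if not stack:
                raise ValueError(f"unbalanced ')' at position {j}")
            pair_id[j] = stack.pop()
    if stack:
        raise ValueError(f"unbalanced '(' at positions {stack}")
    return pair_id


# Group positions by pair id, the same way GraphParser.parse() does.
structure = "((..)).(.)"
paired_nodes = {}
for j, key in enumerate(match_pair_sketch(structure)):
    if key != -1:
        paired_nodes.setdefault(key, []).append(j)
print(list(paired_nodes.values()))  # [[0, 5], [1, 4], [7, 9]]
```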
Visualize the data
```python
fig = plt.figure(figsize=(24, 12))
nx_G = nx.Graph()
nx_G.add_nodes_from([i for i in range(len(gdata['nfeat']))])
nx_G.add_edges_from(gdata['edges'])
# green: base nodes (the first seq_length nodes); yellow: codon nodes
node_color = ['g' for _ in range(sample['seq_length'])] + \
    ['y' for _ in range(len(gdata['nfeat']) - sample['seq_length'])]
options = {
    "node_color": node_color,
}
pos = nx.spring_layout(nx_G, iterations=400, k=0.2)
nx.draw(nx_G, pos, **options)
plt.show()
```
Output:
3: A Simplified GNN Model (Message Sending and Receiving)
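```python
import paddle
import paddle.fluid as fluid


def copy_send(src_feat, dst_feat, edge_feat):
    """Message function: each edge sends its source node's feature "h"."""
    return src_feat["h"]


def mean_recv(feat):
    """Reduce function: average of the incoming messages."""
    return fluid.layers.sequence_pool(feat, pool_type="average")


def sum_recv(feat):
    """Reduce function: sum of the incoming messages."""
    return fluid.layers.sequence_pool(feat, pool_type="sum")


def max_recv(feat):
    """Reduce function: element-wise max of the incoming messages."""
    return fluid.layers.sequence_pool(feat, pool_type="max")


def simple_gnn(gw, feature, hidden_size, act, name):
    """One GNN layer: fc(own feature + sum of neighbour features)."""
    # send each node's feature along its edges, then sum the messages per node
    msg = gw.send(copy_send, nfeat_list=[("h", feature)])
    neigh_feature = gw.recv(msg, sum_recv)
    # combine the node's own feature with its aggregated neighbourhood
    self_feature = feature
    output = self_feature + neigh_feature
    output = fluid.layers.fc(output,
                             hidden_size,
                             act=act,
                             param_attr=fluid.ParamAttr(name=name + '_w'),
                             bias_attr=fluid.ParamAttr(name=name + '_b'))
    return output
```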
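Without the Paddle machinery, simple_gnn computes act(W(h_i + Σ_{j→i} h_j) + b) for every node i. copy_send puts the source feature on each edge, sum_recv adds up what arrives at each destination, and fc applies the linear map. Below is a NumPy sketch of the same computation, useful for checking shapes and intuition. `simple_gnn_numpy` is an illustrative name. The leaky-ReLU slope of 0.02 is an assumption, chosen to match the default alpha of fluid.layers.leaky_relu.

```python
import numpy as np


def simple_gnn_numpy(h, edges, weight, bias, act=None):
    """NumPy sketch of simple_gnn.

    h:      (num_nodes, in_dim) node features
    edges:  (num_edges, 2) array of (src, dst) pairs; messages flow src -> dst
    weight: (in_dim, out_dim); bias: (out_dim,)
    """
    # copy_send: every edge carries its source node's features
    messages = h[edges[:, 0]]
    # sum_recv: scatter-add the messages onto their destination nodes
    neigh = np.zeros_like(h)
    np.add.at(neigh, edges[:, 1], messages)
    # own feature + neighbourhood sum, then the fully connected layer
    out = (h + neigh) @ weight + bias
    if act == "leaky_relu":
        out = np.where(out > 0, out, 0.02 * out)  # assumed negative slope 0.02
    return out


# Path graph 0 - 1 - 2, each undirected edge stored in both directions.
h = np.array([[1.0], [2.0], [4.0]])
edges = np.array([[0, 1], [1, 0], [1, 2], [2, 1]])
out = simple_gnn_numpy(h, edges, weight=np.eye(1), bias=np.zeros(1))
print(out.ravel())  # [3. 7. 6.]
```

Node 1 receives from both neighbours (1 + 4), so its output is 2 + 5 = 7, while the end nodes only add their single neighbour.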
4: The GNN Model Class
```python
import paddle.fluid as F
import paddle.fluid.layers as L

# Not shown in this article: `propeller`, `BatchGraphWrapper`, `paddle_helper`
# and `GNNlayers` (presumably the project's layers.py, which holds simple_gnn)
# are imported elsewhere in model.py.


class GNNModel(propeller.train.Model):
    """Stack of simple_gnn layers that predicts the 3 targets for each node."""

    def __init__(self, hparam, mode, run_config):
        self.hparam = hparam
        self.mode = mode
        self.is_test = True if self.mode != propeller.RunMode.TRAIN else False
        self.run_config = run_config

    def forward(self, input_dict):
        # batched graph built from the collated node/edge tensors
        gw = BatchGraphWrapper(input_dict['num_nodes'],
                               input_dict['num_edges'],
                               input_dict['edges'],
                               edge_feats={'efeat': input_dict['edge_feat']})

        # project the raw node features to hidden_size
        feature = L.fc(input_dict['node_feat'],
                       size=self.hparam.hidden_size,
                       act=None,
                       bias_attr=F.ParamAttr(name='embed_b'),
                       param_attr=F.ParamAttr(name="embed_w"))

        for layer in range(self.hparam.num_layers):
            # no activation on the last GNN layer
            if layer == self.hparam.num_layers - 1:
                act = None
            else:
                act = 'leaky_relu'
            feature = GNNlayers.simple_gnn(
                gw,
                feature,
                self.hparam.hidden_size,
                act,
                name="%s_%s" % (self.hparam.layer_type, layer))
            # is_test turns dropout off outside of training
            feature = L.dropout(
                feature,
                self.hparam.dropout_prob,
                is_test=self.is_test,
                dropout_implementation="upscale_in_train")

        # one output per target: reactivity, deg_Mg_pH10, deg_Mg_50C
        logits = L.fc(feature,
                      size=self.hparam.num_class,
                      act=None,
                      bias_attr=F.ParamAttr(name='final_b'),
                      param_attr=F.ParamAttr(name="final_w"))
        # keep only the nodes selected by the mask
        mask = input_dict['mask']
        logits = paddle_helper.masked_select(logits, mask)
        return [logits, mask]

    def loss(self, predictions, label):
        logits = predictions[0]
        mask = predictions[1]
        label = paddle_helper.masked_select(label, mask)
        loss = L.mse_loss(input=logits, label=label)
        loss = L.reduce_mean(loss)
        return loss

    def backward(self, loss):
        optimizer = F.optimizer.Adam(learning_rate=self.hparam.lr)
        optimizer.minimize(loss)

    def metrics(self, predictions, label):
        result = {}
        logits = predictions[0]
        mask = predictions[1]
        label = paddle_helper.masked_select(label, mask)
        result["MCRMSE"] = propeller.metrics.MCRMSE(label, logits)
        return result
```
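The mask is applied before the loss, so codon nodes and bases beyond seq_scored never influence training. Below is a NumPy sketch of the same masked mean-squared error. It assumes that paddle_helper.masked_select keeps exactly the rows whose mask entry is True. `masked_mse` is an illustrative name.

```python
import numpy as np


def masked_mse(logits, labels, mask):
    """Mean squared error over the rows where mask is True.

    logits, labels: (num_nodes, num_class) arrays
    mask:           (num_nodes, 1) booleans, the same shape as gdata['mask']
    """
    keep = mask.reshape(-1)            # (num_nodes,) row selector
    diff = logits[keep] - labels[keep]
    return float(np.mean(diff ** 2))


logits = np.array([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [100.0, 100.0, 100.0]])
labels = np.zeros((3, 3))
mask = np.array([[True], [True], [False]])   # third row: e.g. a codon node
print(masked_mse(logits, labels, mask))  # 0.5
```

The huge error on the masked-out third row has no effect on the result.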
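The evaluation metric is MCRMSE (mean columnwise root mean squared error). It is the RMSE of each scored target column, averaged over the $N_t$ columns:

$$\mathrm{MCRMSE} = \frac{1}{N_{t}}\sum_{j=1}^{N_{t}}\sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(y_{ij}-\hat{y}_{ij}\right)^{2}}$$

Here $n$ is the number of scored positions, and $y_{ij}$ and $\hat{y}_{ij}$ are the ground truth and the prediction for position $i$ and target $j$. In this model $N_t = 3$ (reactivity, deg_Mg_pH10, deg_Mg_50C).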
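MCRMSE computes a root-mean-squared error separately for each target column and then averages those values over the columns. The article does not show propeller.metrics.MCRMSE. Below is an independent NumPy sketch of the metric. It applies to the (num_scored_bases, 3) arrays that the model's mask produces, and `mcrmse` is an illustrative name.

```python
import numpy as np


def mcrmse(y_true, y_pred):
    """Mean columnwise RMSE for arrays of shape (num_positions, num_targets)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    col_rmse = np.sqrt(np.mean((y_true - y_pred) ** 2, axis=0))  # one RMSE per target
    return float(np.mean(col_rmse))                              # average over targets


y_true = np.zeros((2, 3))
y_pred = np.array([[3.0, 0.0, 0.0],
                   [4.0, 0.0, 1.0]])
# column RMSEs: sqrt(12.5), 0, sqrt(0.5); their mean is sqrt(2)
print(round(mcrmse(y_true, y_pred), 4))  # 1.4142
```

Unlike a single RMSE over all values, the per-column averaging keeps one target with large errors from being diluted by the others.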
5: Training the GNN Model
```python
import argparse
import os

import numpy as np
import pandas as pd
# ......  (remaining project imports and helpers elided)


def train(args):
    train_ds = CovidDataset(data_file=args.train_file, config=args, mode="train")
    valid_ds = CovidDataset(data_file=args.valid_file, config=args, mode="valid")
    log.info("train examples: %s" % len(train_ds))
    log.info("valid examples: %s" % len(valid_ds))

    train_loader = Dataloader(train_ds,
                              batch_size=args.batch_size,
                              shuffle=args.shuffle,
                              collate_fn=CollateFn())
    # repeat the training data for args.epochs epochs
    train_loader = multi_epoch_dataloader(train_loader, args.epochs)
    train_loader = PDataset.from_generator_func(train_loader)

    valid_loader = Dataloader(valid_ds,
                              batch_size=1,
                              shuffle=False,
                              collate_fn=CollateFn())
    valid_loader = PDataset.from_generator_func(valid_loader)

    # warmup start setting
    ws = None
    propeller.train.train_and_eval(
        model_class_or_model_fn=GNNModel,
        params=args,
        run_config=args,
        train_dataset=train_loader,
        eval_dataset={"eval": valid_loader},
        warm_start_setting=ws,
    )


def infer(args):
    # predict for test data
    log.info('Reading %s' % args.test_file)
    test_ds = CovidDataset(args.test_file, args, 'test')
    test_loader = Dataloader(test_ds,
                             batch_size=args.batch_size,
                             shuffle=False,
                             collate_fn=CollateFn(mode='test'))
    test_loader = PDataset.from_generator_func(test_loader)

    est = propeller.Learner(GNNModel, args, args)

    output_path = args.model_path_for_infer.replace("checkpoints/", "outputs/")
    make_dir(output_path)
    filename = os.path.join(output_path, "submission.csv")

    id_seqpos = build_id_seqpos(args.test_file)
    preds = []
    for predictions in est.predict(test_loader,
                                   ckpt_path=args.model_path_for_infer,
                                   split_batch=False):
        preds.append(predictions[0])
    preds = np.concatenate(preds)

    # the model predicts 3 targets; deg_pH10 and deg_50C are filled with 0
    df_sub = pd.DataFrame({'id_seqpos': id_seqpos,
                           'reactivity': preds[:, 0],
                           'deg_Mg_pH10': preds[:, 1],
                           'deg_pH10': 0,
                           'deg_Mg_50C': preds[:, 2],
                           'deg_50C': 0})
    log.info("saving result to %s" % filename)
    df_sub.to_csv(filename, index=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='gnn')
    parser.add_argument("--config", type=str, default="./config.yaml")
    parser.add_argument("--mode", type=str, default="train")
    args = parser.parse_args()

    if args.mode == "infer":
        config = prepare_config(args.config, isCreate=False, isSave=False)
        infer(config)
    else:
        config = prepare_config(args.config, isCreate=True, isSave=True)
        log_to_file(log, config.log_dir)
        train(config)
```
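The article does not show build_id_seqpos. The Kaggle submission keys each row as `<id>_<position>`, for example `id_2a7a4496f_0`. Below is a hypothetical sketch (`build_id_seqpos_sketch`) that reads JSON-lines text and emits one key per base in file order. That order should line up with the predictions, because the test loader does not shuffle and, in test mode, the mask keeps every base node and drops the codon nodes.

```python
import json
from typing import Iterable, List


def build_id_seqpos_sketch(lines: Iterable[str]) -> List[str]:
    """Hypothetical stand-in for build_id_seqpos: one "<id>_<pos>" key per base."""
    keys = []
    for line in lines:
        if not line.strip():
            continue  # tolerate blank lines, e.g. at the end of the file
        record = json.loads(line)
        keys.extend(f"{record['id']}_{pos}" for pos in range(record["seq_length"]))
    return keys


# Usage with two toy records; with the real file you would pass
# open(args.test_file) instead of this list.
toy_lines = [
    json.dumps({"id": "id_a", "seq_length": 3}),
    json.dumps({"id": "id_b", "seq_length": 2}),
]
print(build_id_seqpos_sketch(toy_lines))
# ['id_a_0', 'id_a_1', 'id_a_2', 'id_b_0', 'id_b_1']
```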
Run the GNN training script:
```shell
python main.py
```
Output:
........
********** Start Loop ************
train loop has hook
train loop has hook
train loop has hook
train loop has hook
[training] step: 20 steps/sec: -1.00000 loss: 305.70929
[training] step: 40 steps/sec: 141.66859 loss: 103.75951
[training] step: 60 steps/sec: 139.14825 loss: 44.02296
END: epoch 0 ...
BEGIN: epoch 1 ...
[training] step: 80 steps/sec: 145.89391 loss: 35.04534
[training] step: 100 steps/sec: 145.03201 loss: 24.50893
[training] step: 120 steps/sec: 143.29997 loss: 22.46627
[training] step: 140 steps/sec: 150.44456 loss: 17.93622
[training] step: 160 steps/sec: 148.98990 loss: 15.78190
[training] step: 180 steps/sec: 144.50660 loss: 14.56967
END: epoch 1 ...
(240,)
{'MCRMSE': 0.54948246, 'loss': 0.30230046494398266}
write to tensorboard ../checkpoints/covid19/eval_history/eval
write to tensorboard ../checkpoints/covid19/eval_history/eval
[Eval:eval]:MCRMSE:0.5494824647903442 loss:0.30230046494398266
[training] step: 24980 steps/sec: 5.72037 loss: 0.37832
********** Stop Loop ************
saving step 25000 to ../checkpoints/covid19/model_25000
6: GNN Model Prediction
Run the inference script:
```shell
python main.py --mode infer
```
Output:
aistudio@jupyter-112853-1529805:~/src$ python main.py --mode infer
[WARNING] 2021-02-15 20:46:38,028 [ __init__.py: 41]: enabling old_styled_ckpt
[DEBUG] 2021-02-15 20:46:38,029 [distribution.py: 130]: no PROPELLER_DISCONFIG found, try paddlestype setting
[DEBUG] 2021-02-15 20:46:38,030 [distribution.py: 133]: no paddle stype setting found
[WARNING] 2021-02-15 20:46:38,031 [monitored_executor.py: 392]: textone not found in ['/home/aistudio/src', '/opt/conda/envs/python35-paddle120-env/lib/python37.zip', '/opt/conda/envs/python35-paddle120-env/lib/python3.7', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/lib-dynload', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages', '../']! will not load encrepted model
[INFO] 2021-02-15 20:46:38,041 [ main.py: 85]: Reading ../data/data60987/valid.json
[INFO] 2021-02-15 20:46:38,488 [ trainer.py: 220]: Building Predict Graph
[INFO] 2021-02-15 20:46:38,488 [functional.py: 417]: Try to infer data shapes & types from generator
[INFO] 2021-02-15 20:46:38,489 [functional.py: 434]: Dataset `predict` has data_shapes: [[-1, 5], [-1, 2], [-1], [-1, 14], [-1], [-1]] data_types: ['float32', 'int64', 'bool', 'float32', 'int64', 'int64']
[INFO] 2021-02-15 20:46:38,541 [ trainer.py: 225]: Done
[INFO] 2021-02-15 20:46:38,542 [ trainer.py: 240]: Predict with:
> Run_config: {'task_name': 'covid19', 'use_cuda': False, 'warm_start_from': '', 'model_path_for_infer': '../checkpoints/covid19/model_810', 'train_file': '../data/data60987/train.json', 'valid_file': '../data/data60987/valid.json', 'test_file': '../data/data60987/valid.json', 'percentage': 0.9, 'add_edge_for_paired_nodes': True, 'add_codon_nodes': True, 'num_layers': 5, 'layer_type': 'simple_gnn', 'emb_size': 64, 'hidden_size': 64, 'num_class': 3, 'dropout_prob': 0.1, 'epochs': 200, 'batch_size': 16, 'lr': 0.001, 'shuffle': True, 'save_steps': 200000000, 'log_steps': 20, 'max_ckpt': 8, 'skip_steps': 0, 'eval_steps': 320, 'eval_max_steps': 10000, 'stdout': True, 'log_dir': '../logs', 'log_filename': 'log.txt', 'save_dir': '../checkpoints', 'output_dir': '../outputs', 'files2saved': ['layers.py', 'data_parser.py', 'config.yaml', 'main.py', 'dataset.py', 'model.py'], 'model_dir': '../checkpoints'}
> Params: {'task_name': 'covid19', 'use_cuda': False, 'warm_start_from': '', 'model_path_for_infer': '../checkpoints/covid19/model_810', 'train_file': '../data/data60987/train.json', 'valid_file': '../data/data60987/valid.json', 'test_file': '../data/data60987/valid.json', 'percentage': 0.9, 'add_edge_for_paired_nodes': True, 'add_codon_nodes': True, 'num_layers': 5, 'layer_type': 'simple_gnn', 'emb_size': 64, 'hidden_size': 64, 'num_class': 3, 'dropout_prob': 0.1, 'epochs': 200, 'batch_size': 16, 'lr': 0.001, 'shuffle': True, 'save_steps': 200000000, 'log_steps': 20, 'max_ckpt': 8, 'skip_steps': 0, 'eval_steps': 320, 'eval_max_steps': 10000, 'stdout': True, 'log_dir': '../logs', 'log_filename': 'log.txt', 'save_dir': '../checkpoints', 'output_dir': '../outputs', 'files2saved': ['layers.py', 'data_parser.py', 'config.yaml', 'main.py', 'dataset.py', 'model.py'], 'model_dir': '../checkpoints'}
> Train_model_spec: ModelSpec(loss=None, predictions=[name: "gather_7.tmp_0"
......
[INFO] 2021-02-15 20:46:41,787 [monitored_executor.py: 540]: propeller runs in CUDA mode
[INFO] 2021-02-15 20:46:41,787 [monitored_executor.py: 547]: ********** Start Loop ************
[INFO] 2021-02-15 20:46:41,787 [ trainer.py: 390]: Runining predict from dir: {'gstep': 0, 'step': 0, 'time': 1613393198.543097}
[INFO] 2021-02-15 20:46:42,180 [monitored_executor.py: 606]: ********** Stop Loop ************
[INFO] 2021-02-15 20:46:42,195 [ main.py: 114]: saving result to ../outputs/covid19/model_810/submission.csv
Note: The text and images in this article come from AIStudio, an AI learning and training community.