-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsave_graphs.py
executable file
·112 lines (93 loc) · 3.68 KB
/
save_graphs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/usr/bin/env python
import click as ck
import numpy as np
import pandas as pd
import gzip
import logging
import torch as th
#DGL imports
import dgl
def _index_proteins(*splits):
    """Map every protein of ``splits`` to a consecutive integer node ID.

    IDs follow split order, then row order: the first split gets
    ``0 .. len(first) - 1``, the next split continues from there, and so on.

    Args:
        *splits: DataFrames with a ``proteins`` column.

    Returns:
        dict mapping protein identifier -> node ID.

    Raises:
        ValueError: if a protein occurs more than once, within one split or
            across splits. The per-split node-ID ranges are derived from row
            counts, so a repeated protein would make the ranges overlap and
            the graph's node count disagree with the IDs actually used.
    """
    prot_idx = {}
    for df in splits:
        for protein in df['proteins']:
            if protein in prot_idx:
                raise ValueError(
                    f'Protein {protein!r} occurs more than once across the '
                    'train/valid/test splits; node IDs would collide')
            prot_idx[protein] = len(prot_idx)
    return prot_idx


def _add_edges(df, prot_idx, targets, rels, src, dst, edge_types):
    """Append one edge per interaction of ``df`` whose partner is in ``targets``.

    Each row of ``df`` must provide ``proteins`` (the row's own protein) and
    ``interactions`` (an iterable of ``(relation, partner)`` pairs). A
    relation name is registered in ``rels`` (name -> integer type ID) the
    first time it is seen, even if that interaction's edge is skipped, so
    type IDs follow first-encounter order.

    ``rels``, ``src``, ``dst`` and ``edge_types`` are updated in place.
    """
    for row in df.itertuples():
        p_id = prot_idx[row.proteins]
        for rel, partner in row.interactions:
            rel_id = rels.setdefault(rel, len(rels))
            if partner in targets:
                src.append(p_id)
                dst.append(prot_idx[partner])
                edge_types.append(rel_id)


@ck.command()
@ck.option(
    '--data-root', '-dr', default='data',
    help='Directory holding one sub-folder per ontology (mf, bp, cc)')
def main(data_root):
    """Build and save a protein-protein interaction graph per ontology.

    For each of ``mf``, ``bp`` and ``cc``, the train, valid and test
    DataFrames are loaded from ``<data_root>/<ont>/``. Proteins become nodes,
    numbered train first, then valid, then test. Every ``(relation, partner)``
    interaction whose partner is one of those proteins becomes a directed
    edge, and its integer relation type is stored in ``edata['etypes']``.
    Self loops are then added. The graph is saved to
    ``<data_root>/<ont>/ppi.bin`` together with each split's node-ID range.

    Raises:
        ValueError: if a protein appears more than once across the splits.
    """
    for ont in ['mf', 'bp', 'cc']:
        ont_dir = f'{data_root}/{ont}'
        train_df = pd.read_pickle(f'{ont_dir}/train_data_int.pkl')
        valid_df = pd.read_pickle(f'{ont_dir}/valid_data_int.pkl')
        # NOTE(review): the test split is read from ``test_data.pkl`` while
        # the others use the ``*_int.pkl`` variant -- confirm this is intended.
        test_df = pd.read_pickle(f'{ont_dir}/test_data.pkl')

        # Sizes come from row counts so that the nid ranges saved below match
        # the IDs handed out by _index_proteins (which enforces uniqueness).
        train_n, valid_n, test_n = len(train_df), len(valid_df), len(test_df)
        num_nodes = train_n + valid_n + test_n
        prot_idx = _index_proteins(train_df, valid_df, test_df)

        train_set = set(train_df['proteins'])
        valid_set = set(valid_df['proteins'])
        test_set = set(test_df['proteins'])

        src, dst, edge_types = [], [], []
        rels = {}
        # Each ordered (source split, target split) pair is covered exactly
        # once; the order of these calls determines the edge IDs.
        _add_edges(train_df, prot_idx, train_set,
                   rels, src, dst, edge_types)
        print(len(src), train_n)
        _add_edges(train_df, prot_idx, valid_set,
                   rels, src, dst, edge_types)
        _add_edges(valid_df, prot_idx, train_set | valid_set,
                   rels, src, dst, edge_types)
        _add_edges(train_df, prot_idx, test_set,
                   rels, src, dst, edge_types)
        _add_edges(valid_df, prot_idx, test_set,
                   rels, src, dst, edge_types)
        _add_edges(test_df, prot_idx, prot_idx,
                   rels, src, dst, edge_types)
        print(len(prot_idx))

        graph = dgl.graph((src, dst), num_nodes=num_nodes)
        graph.edata['etypes'] = th.LongTensor(edge_types)
        # NOTE(review): the 'etypes' value assigned to the new self-loop
        # edges is whatever dgl.add_self_loop fills in for this DGL version;
        # confirm it does not collide with a real relation type ID.
        graph = dgl.add_self_loop(graph)
        dgl.save_graphs(
            f'{ont_dir}/ppi.bin', graph,
            {
                'train_nids': th.arange(train_n, dtype=th.long),
                'valid_nids': th.arange(
                    train_n, train_n + valid_n, dtype=th.long),
                'test_nids': th.arange(
                    train_n + valid_n, num_nodes, dtype=th.long),
            })
if __name__ == "__main__":
    # Script entry point: invoke the click command, which reads its
    # arguments from the command line.
    main()