-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdataset.py
118 lines (71 loc) · 2.54 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import time
import json
import loading
import gc
import csv
import random
import math
import numpy as np
import scipy as sp
from tqdm import tqdm
import logging
from data import BondType
# Only let CRITICAL records from the "pysmiles" logger through.
logging.getLogger("pysmiles").setLevel(logging.CRITICAL)

# Load the run configuration from the current working directory. The context
# manager guarantees the handle is closed even if parsing fails.
with open("config.json") as _config_file:
    config = json.load(_config_file)

print("Reading the dataset...")

# The score file is tab-separated. newline="" lets the csv module do its own
# line-ending handling, as the csv documentation requires for file objects.
with open(config["scores"], newline="") as _scores_file:
    all_smiles = list(csv.reader(_scores_file, delimiter="\t"))

# A dedicated Random instance keeps the shuffle reproducible for a given seed
# without touching the global random state. A null seed means "keep file order".
if config["shuffle-seed"] is not None:
    random.Random(config["shuffle-seed"]).shuffle(all_smiles)

print("Initialising the dataset...")
def get_data(smiles):
    """Convert raw score rows into ``[(x, a, e), target]`` entries.

    Each row is read as: column 0 holds the SMILES string, column 1 holds a
    comma-separated record whose first field is a string of digits (one
    fingerprint value per character) and whose remaining fields are
    regression values. Empty regression fields are treated as 0.

    Rows are skipped when they are empty (``csv.reader`` produces ``[]`` for
    blank lines), when the SMILES column is empty, or when the first element
    returned by ``loading.get_data`` is longer than ``config["node-cutoff"]``.

    Args:
        smiles: Iterable of rows (sequences of strings), e.g. from csv.reader.

    Returns:
        A list of ``[(x, a, e), target]`` entries, where ``(x, a, e)`` is the
        result of ``loading.convert`` and ``target`` is the fingerprint digits
        followed by the regression values.
    """
    returnable = []
    for smile in tqdm(smiles):
        # Blank lines in the score file come through as empty rows; indexing
        # them would raise IndexError, so skip them like empty SMILES entries.
        if not smile or smile[0] == "":
            continue

        molecule_data = smile[1].split(",")
        fingerprint_data = molecule_data[0]
        regression_data = [0 if x == "" else float(x) for x in molecule_data[1:]]
        # The fingerprint is stored as a run of digits, one value per character.
        fingerprint = [int(x) for x in fingerprint_data]

        # Cis/trans parsing and path application are both switched off here.
        graph = loading.get_data(smile[0], apply_paths=False, parse_cis_trans=False)
        # presumably graph[0] is the node list, making this a size filter
        # on the molecule - TODO confirm against loading.get_data
        if len(graph[0]) > config["node-cutoff"]:
            continue

        x, a, e = loading.convert(*graph, bonds=[
            BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE,
            BondType.AROMATIC, BondType.NOT_CONNECTED])

        # Target vector: fingerprint digits first, regression values after.
        returnable.append([(x, a, e), fingerprint + regression_data])
    return returnable
# Build the module-level dataset from the leading config["truncation"] rows
# of the (optionally shuffled) score file.
data = get_data(all_smiles[:config["truncation"]])
# Run a garbage-collection pass once parsing has finished.
gc.collect()
print("Dataset initialised! {} graphs have been loaded.".format(len(data)))
from spektral.data import Dataset, Graph
class SMILESDataset(Dataset):
    """Dataset backed by the module-level ``data`` list.

    Depending on the constructor flags, an instance exposes the training
    split, the validation split, or every entry. The split point is derived
    from ``config["validation-ratio"]``.
    """

    def __init__(self, training=False, all_data=False, **kwargs):
        """Store the split selection, then hand over to the base constructor.

        Args:
            training: Select the training split (leading part of ``data``).
            all_data: Select every entry; must not be combined with ``training``.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        # Asking for everything and for the training split is contradictory.
        assert not (all_data and training)
        # Set the flags before calling the base constructor; presumably it
        # triggers read(), which needs them - confirm against spektral.
        self.training = training
        self.all_data = all_data
        super().__init__(**kwargs)

    def download(self):
        """Do nothing: the data is already held in memory by this module."""
        pass

    def get_dataset(self):
        """Return the slice of ``data`` selected by the instance flags."""
        # Leading (1 - validation ratio) share is training, the rest validation.
        split_at = math.floor((1 - config["validation-ratio"]) * len(data))
        if self.all_data:
            return data
        if self.training:
            return data[:split_at]
        return data[split_at:]

    def read(self):
        """Wrap each selected entry in a ``Graph`` and return them as a list."""
        graphs = []
        for entry in self.get_dataset():
            # entry[0] is the (x, a, e) triple, entry[1] the target vector.
            nodes, adjacency, edges = entry[0]
            target = entry[1]
            graphs.append(
                Graph(
                    x=np.array(nodes),
                    a=np.array(adjacency),
                    e=np.array(edges),
                    y=np.array(target),
                )
            )
        return graphs