forked from tbwxmu/SAMPN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
features.py
63 lines (49 loc) · 2.2 KB
/
features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import csv
import os
import pickle
from typing import Callable, List, Union
import numpy as np
from rdkit import Chem
# A molecule is passed around either as a string or as an RDKit Chem.Mol object.
# NOTE(review): the string form is presumably a SMILES string - confirm against callers.
Molecule = Union[str, Chem.Mol]
# A features generator is any callable that maps one molecule to a numpy array.
FeaturesGenerator = Callable[[Molecule], np.ndarray]
# Module-level registry; its keys are the names reported by
# get_available_features_generators(). Starts out empty at import time.
FEATURES_GENERATOR_REGISTRY = {}
def save_features(path: str, features: List[np.ndarray]):
    """
    Writes molecule features to disk as a compressed numpy .npz archive.

    The archive holds a single entry, stored under the key "features".

    :param path: Destination of the .npz archive.
    :param features: A list of 1D numpy arrays containing the features for molecules.
    """
    # Build the keyword mapping explicitly so the archive key is easy to spot.
    named_arrays = {'features': features}
    np.savez_compressed(path, **named_arrays)
def load_features(path: str) -> np.ndarray:
    """
    Loads features saved in a variety of formats.

    The format is chosen from the file extension of ``path`` (exact, case-sensitive match).

    Supported formats:

    - .npz compressed (assumes features are saved with name "features")
    - .npz (assumes features are saved with name "features")
    - .npy
    - .csv/.txt (assumes comma-separated features with a header and with one line per molecule;
      completely blank lines are ignored)
    - .pkl/.pckl/.pickle containing an iterable of sparse rows that expose ``.todense()``
      (TODO: remove this option once we are no longer dependent on it)

    All formats assume that the SMILES strings loaded elsewhere in the code are in the same
    order as the features loaded here.

    :param path: Path to a file containing features.
    :return: A 2D numpy array of size (num_molecules, features_size) containing the features.
    :raises ValueError: If the extension of ``path`` is not one of the supported ones.
    """
    extension = os.path.splitext(path)[1]

    if extension == '.npz':
        # np.load returns an NpzFile that keeps the archive open until it is closed;
        # the context manager releases the file handle once the array has been read.
        with np.load(path) as archive:
            features = archive['features']
    elif extension == '.npy':
        features = np.load(path)
    elif extension in ['.csv', '.txt']:
        # newline='' lets the csv module do its own newline handling, as its docs recommend.
        with open(path, newline='') as f:
            reader = csv.reader(f)
            next(reader)  # skip header
            # csv.reader yields [] for blank lines (e.g. a trailing newline); skip them so
            # the result stays a rectangular float array.
            features = np.array([[float(value) for value in row] for row in reader if row])
    elif extension in ['.pkl', '.pckl', '.pickle']:
        # SECURITY: unpickling can execute arbitrary code - only load pickle files from
        # trusted sources.
        with open(path, 'rb') as f:
            features = np.array([np.squeeze(np.array(feat.todense())) for feat in pickle.load(f)])
    else:
        raise ValueError(f'Features path extension {extension} not supported.')

    return features
def get_available_features_generators() -> List[str]:
    """Lists the names currently present as keys in the features generator registry."""
    return [name for name in FEATURES_GENERATOR_REGISTRY]