Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfixes #95

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ var/
*.lprof
*.prof
*.out
.coverage*

# pytest
.pytest*
Expand Down
2 changes: 1 addition & 1 deletion scripts/predict_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
def predict_fingerprints(data, path_parameters: Path, path_scales: Path):
"""Predict data using a previously trained fingerprint model."""
# FullyConnected NN
net = FingerprintFullyConnected(hidden_cells=100, num_labels=NUMLABELS)
net = FingerprintFullyConnected(hidden_units=(100, 100), output_units=NUMLABELS)
return call_modeller(net, data, data.fingerprints, path_parameters, path_scales)


Expand Down
4 changes: 2 additions & 2 deletions scripts/predict_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from swan.dataset import (DGLGraphData, FingerprintsData,
TorchGeometricGraphData)
from swan.dataset.dgl_graph_data import dgl_data_loader
from swan.modeller import Modeller
from swan.modeller import TorchModeller as Modeller
from swan.modeller.models import MPNN, FingerprintFullyConnected
from swan.modeller.models.se3_transformer import SE3Transformer
from swan.utils.plot import create_scatter_plot
Expand All @@ -26,7 +26,7 @@ def predict_fingerprints():
"""Predict data using a previously trained fingerprint model."""
data = FingerprintsData(PATH_DATA, sanitize=True)
# FullyConnected NN
net = FingerprintFullyConnected(hidden_cells=100, num_labels=NUMLABELS)
net = FingerprintFullyConnected(hidden_units=(100, 100), output_units=NUMLABELS)
return call_modeller(net, data, data.fingerprints)


Expand Down
2 changes: 1 addition & 1 deletion scripts/run_torch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
# path_data, properties=properties, file_geometries=path_geometries, sanitize=False)
# data = TorchGeometricGraphData(path_data, properties=properties, file_geometries=path_geometries, sanitize=False)
# FullyConnected NN
net = FingerprintFullyConnected(hidden_cells=100, num_labels=num_labels)
net = FingerprintFullyConnected(hidden_units=(100, 100), output_units=num_labels)

# # Graph NN configuration
# net = MPNN(batch_size=batch_size, output_channels=40, num_labels=num_labels)
Expand Down
6 changes: 3 additions & 3 deletions swan/dataset/data_graph_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import List, Optional, Union


from .geometry import guess_positions
from .swan_data_base import SwanDataBase
from ..type_hints import PathLike
from swan.dataset.geometry import guess_positions
from swan.dataset.swan_data_base import SwanDataBase
from swan.type_hints import PathLike


__all__ = ["SwanGraphData"]
Expand Down
6 changes: 3 additions & 3 deletions swan/dataset/dgl_graph_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import torch
from torch.utils.data import DataLoader, Dataset

from .data_graph_base import SwanGraphData
from .graph.molecular_graph import create_molecular_dgl_graph
from ..type_hints import PathLike
from swan.dataset.data_graph_base import SwanGraphData
from swan.dataset.graph.molecular_graph import create_molecular_dgl_graph
from swan.type_hints import PathLike

try:
import dgl
Expand Down
9 changes: 6 additions & 3 deletions swan/dataset/fingerprints_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
import torch
from torch.utils.data import Dataset

from .features.featurizer import generate_fingerprints
from .swan_data_base import SwanDataBase
from ..type_hints import PathLike
from rdkit import RDLogger

from swan.dataset.features.featurizer import generate_fingerprints
from swan.dataset.swan_data_base import SwanDataBase
from swan.type_hints import PathLike

__all__ = ["FingerprintsData"]
RDLogger.DisableLog('rdApp.*') # disable rdkit messages, preventing spam of 'Molecule does not have explicit Hs. Consider calling AddHs()'


class FingerprintsData(SwanDataBase):
Expand Down
40 changes: 40 additions & 0 deletions swan/dataset/fingerprints_data_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from pathlib import Path

import pandas as pd
import torch

from swan.dataset import FingerprintsData

path_files = Path("data")

# Compute full dataset of fingerprints and all properties, for both ligands
paths = {
'carboxylics': path_files / 'Carboxylic_acids/CDFT/all_carboxylics.csv',
'amines': path_files / 'Amines/CDFT/all_amines.csv'
}
frames = {ligand_type: pd.read_csv(path, index_col='Unnamed: 0') for ligand_type, path in paths.items()}
properties = list(frames['carboxylics'].columns[2:])
fp_data = {ligand_type: FingerprintsData(path, sanitize=False, properties=properties)
for ligand_type, path in paths.items()}
Xs = {ligand_type: fp.fingerprints for ligand_type, fp in fp_data.items()}
ys = {ligand_type: fp.labels for ligand_type, fp in fp_data.items()}

# shuffle data for both ligands
torch.manual_seed(42) # to make it deterministic
indices = {ligand_type: torch.randperm(X.shape[0]) for ligand_type, X in Xs.items()}
Xs_shuffled = {ligand_type: X[indices] for (ligand_type, X), indices in zip(Xs.items(), indices.values())}
ys_shuffled = {ligand_type: y[indices] for (ligand_type, y), indices in zip(ys.items(), indices.values())}

# set aside 1000 data points of carboxylics as the test set
n_test = 1_000
test_data = Xs_shuffled['carboxylics'][:n_test], ys_shuffled['carboxylics'][:n_test]

# the remaining carboxylics, in addition to potentially all amines, are the training set
# these are to be split into training and validation sets during usage
train_data_carboxylics = Xs_shuffled['carboxylics'][n_test:], ys_shuffled['carboxylics'][n_test:]
train_data_amines = Xs_shuffled['amines'], ys_shuffled['amines']

# save
torch.save(test_data, path_files / 'Carboxylic_acids/CDFT/fingerprint_test')
torch.save(train_data_carboxylics, path_files / 'Carboxylic_acids/CDFT/fingerprint_train')
torch.save(train_data_amines, path_files / 'Amines/CDFT/fingerprint_train')
2 changes: 1 addition & 1 deletion swan/dataset/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from rdkit import Chem
from rdkit.Chem import AllChem

from ..type_hints import PathLike
from swan.type_hints import PathLike


def read_geometries_from_files(file_geometries: PathLike) -> Tuple[List[Chem.rdchem.Mol], List[np.ndarray]]:
Expand Down
4 changes: 2 additions & 2 deletions swan/dataset/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import numpy as np
import torch

from ..state import StateH5
from ..type_hints import PathLike
from swan.state import StateH5
from swan.type_hints import PathLike


class SplitDataset(NamedTuple):
Expand Down
6 changes: 3 additions & 3 deletions swan/dataset/swan_data_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from sklearn.preprocessing import RobustScaler
from torch.utils.data import DataLoader, Dataset, Subset

from ..type_hints import PathLike
from .geometry import read_geometries_from_files
from .sanitize_data import sanitize_data
from swan.type_hints import PathLike
from swan.dataset.geometry import read_geometries_from_files
from swan.dataset.sanitize_data import sanitize_data

__all__ = ["SwanDataBase"]

Expand Down
6 changes: 3 additions & 3 deletions swan/dataset/torch_geometric_graph_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import torch_geometric as tg
from torch_geometric.data import Data

from ..type_hints import PathLike
from .data_graph_base import SwanGraphData
from .graph.molecular_graph import create_molecular_torch_geometric_graph
from swan.type_hints import PathLike
from swan.dataset.data_graph_base import SwanGraphData
from swan.dataset.graph.molecular_graph import create_molecular_torch_geometric_graph


class TorchGeometricGraphData(SwanGraphData):
Expand Down
6 changes: 3 additions & 3 deletions swan/modeller/base_modeller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import numpy as np
import torch

from ..dataset.swan_data_base import SwanDataBase
from ..state import StateH5
from ..type_hints import PathLike
from swan.dataset.swan_data_base import SwanDataBase
from swan.state import StateH5
from swan.type_hints import PathLike

# `bound` preserves all sub-type information, which might be useful
T_co = TypeVar('T_co', bound=Union[np.ndarray, torch.Tensor], covariant=True)
Expand Down
6 changes: 3 additions & 3 deletions swan/modeller/gp_modeller.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import torch
from torch import Tensor

from ..dataset.fingerprints_data import FingerprintsData
from ..dataset.splitter import SplitDataset
from .torch_modeller import TorchModeller
from swan.dataset.fingerprints_data import FingerprintsData
from swan.dataset.splitter import SplitDataset
from swan.modeller.torch_modeller import TorchModeller

# Starting logger
LOGGER = logging.getLogger(__name__)
Expand Down
41 changes: 32 additions & 9 deletions swan/modeller/models/fingerprint_models.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,47 @@
"""Statistical models."""

from torch import Tensor, nn
from itertools import chain
from typing import Tuple

__all__ = ["FingerprintFullyConnected"]


class FingerprintFullyConnected(nn.Module):
"""Fully connected network for non-linear regression."""

def __init__(self, input_features: int = 2048, hidden_cells: int = 100, num_labels: int = 1):
def __init__(self, input_units: int = 2048, hidden_units: Tuple[int, ...] = (100, 100), output_units: int = 1,
activation: nn.Module = nn.ReLU):
"""Create a deep feed foward network."""
super().__init__()
self.seq = nn.Sequential(
nn.Linear(input_features, hidden_cells),
nn.ReLU(),
nn.Linear(hidden_cells, hidden_cells),
nn.ReLU(),
nn.Linear(hidden_cells, num_labels),
)
self.activation = activation
self.input_units = input_units
self.hidden_units = hidden_units
self.output_units = output_units

self.layers = self._construct_layers()
self.net = nn.Sequential(*self.layers)

def forward(self, tensor: Tensor) -> Tensor:
"""Run the model."""
return self.seq(tensor)
return self.net(tensor)

def _construct_layers(self):
in_sizes = [self.input_units] + list(self.hidden_units)
out_sizes = list(self.hidden_units) + [self.output_units]

linear_layers = [nn.Linear(n_in, n_out) for n_in, n_out in zip(in_sizes, out_sizes)]

activations = [self.activation() for _ in linear_layers]

all_layers = list(chain(*zip(linear_layers, activations)))[:-1] # no activation after last layer

return all_layers

def get_config(self):
return {
'input_units': self.input_units,
'hidden_units': self.hidden_units,
'output_units': self.output_units,
'activation': self.activation
}
8 changes: 4 additions & 4 deletions swan/modeller/scikit_modeller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import numpy as np
from sklearn import gaussian_process, svm, tree

from ..dataset.fingerprints_data import FingerprintsData
from ..dataset.splitter import split_dataset
from ..type_hints import PathLike
from .base_modeller import BaseModeller
from swan.dataset.fingerprints_data import FingerprintsData
from swan.dataset.splitter import split_dataset
from swan.type_hints import PathLike
from swan.modeller.base_modeller import BaseModeller

LOGGER = logging.getLogger(__name__)

Expand Down
22 changes: 11 additions & 11 deletions swan/modeller/torch_modeller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import torch
from torch import Tensor, nn

from ..dataset.swan_data_base import SwanDataBase
from ..type_hints import PathLike
from ..utils.early_stopping import EarlyStopping
from .base_modeller import BaseModeller
from swan.dataset.swan_data_base import SwanDataBase
from swan.type_hints import PathLike
from swan.utils.early_stopping import EarlyStopping
from swan.modeller.base_modeller import BaseModeller
import numpy as np
import sklearn

Expand Down Expand Up @@ -97,7 +97,7 @@ def set_loss(self, name: str, *args, **kwargs) -> None:
self.loss_func = getattr(nn, name)(*args, **kwargs)

def set_scheduler(self, name, *args, **kwargs) -> None:
"""Set the sceduler used for decreasing the LR
"""Set the scheduler used for decreasing the learning rate

Parameters
----------
Expand Down Expand Up @@ -134,11 +134,11 @@ def train_model(self,
Parameters
----------
nepoch : int
number of ecpoch to run
number of epochs to run
frac : List[int], optional
divide the dataset in train/valid, by default [0.8, 0.2]
batch_size : int, optional
batchsize, by default 64
batch size, by default 64
"""
LOGGER.info("TRAINING STEP")
self.split_data(frac, batch_size)
Expand Down Expand Up @@ -198,7 +198,7 @@ def train_batch(self, inp_data: Tensor, ground_truth: Tensor) -> Tuple[float, Te
inp_data : Tensor
input data of the network
ground_truth : Tensor
ground trurth of the data points in input
ground truth of the data points in input

Returns
-------
Expand All @@ -214,7 +214,7 @@ def train_batch(self, inp_data: Tensor, ground_truth: Tensor) -> Tuple[float, Te
return loss.item(), prediction

def validate_model(self) -> Tuple[Tensor, Tensor]:
"""compute the output of the model on the validation set
"""Compute the output of the model on the validation set

Returns
-------
Expand Down Expand Up @@ -243,7 +243,7 @@ def validate_model(self) -> Tuple[Tensor, Tensor]:
return tuple(self.inverse_transform(torch.cat(x)) for x in (results, expected))

def predict(self, inp_data: Tensor) -> Tensor:
"""compute output of the model for a given input
"""Compute output of the model for a given input

Parameters
----------
Expand All @@ -264,7 +264,7 @@ def save_model(self,
epoch: int,
loss: float,
filename: str = 'swan_chk.pt') -> None:
"""Save the modle current status."""
"""Save the model's current status."""
path = self.workdir / filename
torch.save(
{
Expand Down
16 changes: 11 additions & 5 deletions swan/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import h5py
import numpy as np

from ..type_hints import PathLike, ArrayLike
from swan.type_hints import PathLike, ArrayLike


class StateH5:
Expand Down Expand Up @@ -41,14 +41,18 @@ def has_data(self, data: Union[str, List[str]]) -> bool:
return data in f5

def store_array(self, node: str, data: ArrayLike, dtype: str = "float") -> None:
"""Store a tensor in the HDF5.
"""
Store a tensor as a dataset in the HDF5 file.
If the dataset already exists, it overwrites it.

Parameters
----------
paths
list of nodes where the data is going to be stored
node
node where the data is going to be stored
data
Numpy array or list of array to store
dtype
dtype of data
"""
supported_types = {'int': int, 'float': float, 'str': h5py.string_dtype(encoding='utf-8')}
if dtype in supported_types:
Expand All @@ -58,7 +62,9 @@ def store_array(self, node: str, data: ArrayLike, dtype: str = "float") -> None:
raise RuntimeError(msg)

with h5py.File(self.path, 'r+') as f5:
f5.require_dataset(node, shape=np.shape(data), data=data, dtype=dtype)
if node in f5:
del f5[node]
f5.create_dataset(node, shape=np.shape(data), data=data, dtype=dtype)

def retrieve_data(self, paths_to_prop: str) -> Any:
"""Read Numerical properties from ``paths_hdf5``.
Expand Down
Loading