From 221b29b0f37481ef6bfa4e6e2e64c3cd45d5287e Mon Sep 17 00:00:00 2001 From: Steffen Hirte Date: Wed, 25 Sep 2024 08:39:26 +0200 Subject: [PATCH 01/11] Cache calls to Configuration.get_dict * implement key-based access: config['description'] * add __repr__ method --- nerdd_module/config/configuration.py | 49 ++++++++++++++++------------ 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/nerdd_module/config/configuration.py b/nerdd_module/config/configuration.py index d57a190..db59d89 100644 --- a/nerdd_module/config/configuration.py +++ b/nerdd_module/config/configuration.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Dict, List +from functools import lru_cache +from typing import List __all__ = ["Configuration"] @@ -8,45 +9,53 @@ class Configuration(ABC): def __init__(self): pass - def get_dict(self) -> Dict: - config_dict = self._get_dict() + @lru_cache + def get_dict(self) -> dict: + config = self._get_dict() - return config_dict + # TODO: validate + + return config @abstractmethod - def _get_dict(self) -> Dict: + def _get_dict(self) -> dict: pass def molecular_property_columns(self) -> List[str]: return [ c["name"] - for c in self.get_dict().get("result_properties", []) - if "level" not in c or c["level"] == "molecule" + for c in self["result_properties"] + if c.get("level", "molecule") == "molecule" ] def atom_property_columns(self) -> List[str]: return [ - c["name"] - for c in self.get_dict().get("result_properties", []) - if "level" in c and c["level"] == "atom" + c["name"] for c in self["result_properties"] if c.get("level") == "atom" ] def derivative_property_columns(self) -> List[str]: return [ c["name"] - for c in self.get_dict().get("result_properties", []) - if "level" in c and c["level"] == "derivative" + for c in self["result_properties"] + if c.get("level") == "derivative" ] - def get_module_type(self) -> str: + def get_task(self) -> str: + num_atom_properties = len(self.atom_property_columns()) + num_derivative_properties = len(self.derivative_property_columns()) assert ( - len(self.atom_property_columns()) == 0 - or len(self.derivative_property_columns()) == 0 + num_atom_properties == 0 or num_derivative_properties == 0 ), "A module can only predict atom or derivative properties, not both." - if len(self.atom_property_columns()) > 0: - return "atom_property_predictor" - elif len(self.derivative_property_columns()) > 0: - return "derivative_property_predictor" + if num_atom_properties > 0: + return "atom_property_prediction" + elif num_derivative_properties > 0: + return "derivative_property_prediction" else: - return "molecule_property_predictor" + return "molecular_property_prediction" + + def __getitem__(self, key): + return self.get_dict()[key] + + def __repr__(self): + return f"{self.__class__.__name__}({self._get_dict()})" From c45a747668636135f121c4212b78c6e46d37973f Mon Sep 17 00:00:00 2001 From: Steffen Hirte Date: Wed, 25 Sep 2024 08:46:03 +0200 Subject: [PATCH 02/11] Decouple AutoConfiguration from Model * avoid that AutoConfiguration calls the _get_config method of the Model class (instead let _get_config returns an AutoConfiguration) * make package and file-based configuration two exclusive options in AutoConfiguration --- nerdd_module/config/auto_configuration.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/nerdd_module/config/auto_configuration.py b/nerdd_module/config/auto_configuration.py index ff21283..a11ac3c 100644 --- a/nerdd_module/config/auto_configuration.py +++ b/nerdd_module/config/auto_configuration.py @@ -47,14 +47,10 @@ def __init__(self, nerdd_module): if default_config_file is not None: configs.append(YamlConfiguration(default_config_file)) - - # 2.b search for nerdd.yml in the package (submodule package_name.data) - data_module = f"{root_module}.data" - configs.append(PackageConfiguration(data_module)) - - # 3. module can be configured via the method _get_config in the module - if hasattr(nerdd_module, "_get_config"): - configs.append(DictConfiguration(nerdd_module._get_config())) + else: + # 2.b search for nerdd.yml in the package (submodule package_name.data) + data_module = f"{root_module}.data" + configs.append(PackageConfiguration(data_module)) self.delegate = MergedConfiguration(*configs) From 7cdb39ba6ce50cbef01cc4d24b3a39fdfdfae9a2 Mon Sep 17 00:00:00 2001 From: Steffen Hirte Date: Wed, 25 Sep 2024 14:25:09 +0200 Subject: [PATCH 03/11] Refactor configuration classes * add types * derive DefaultConfiguration from DictConfiguration * move yaml search code from AutoConfiguration to SearchYamlConfiguration * implement proper merge mechanism in MergedConfiguration * move code from AutoConfiguration to model classes --- nerdd_module/config/__init__.py | 2 +- nerdd_module/config/auto_configuration.py | 58 ------------------- nerdd_module/config/configuration.py | 54 ++++++++++------- nerdd_module/config/default_configuration.py | 18 +++--- nerdd_module/config/dict_configuration.py | 9 ++- nerdd_module/config/merged_configuration.py | 42 +++++++++++--- .../config/search_yaml_configuration.py | 32 ++++++++++ nerdd_module/config/yaml_configuration.py | 4 +- 8 files changed, 113 insertions(+), 106 deletions(-) delete mode 100644 nerdd_module/config/auto_configuration.py create mode 100644 nerdd_module/config/search_yaml_configuration.py diff --git a/nerdd_module/config/__init__.py b/nerdd_module/config/__init__.py index cb67abd..fcf5da9 100644 --- a/nerdd_module/config/__init__.py +++ b/nerdd_module/config/__init__.py @@ -1,7 +1,7 @@ -from .auto_configuration import * from .configuration import * from .default_configuration import * from .dict_configuration import * from .merged_configuration import * from .package_configuration import * +from .search_yaml_configuration import * from .yaml_configuration import * diff --git a/nerdd_module/config/auto_configuration.py b/nerdd_module/config/auto_configuration.py deleted file mode 100644 index a11ac3c..0000000 --- a/nerdd_module/config/auto_configuration.py +++ /dev/null @@ -1,58 +0,0 @@ -import os -import sys - -from .configuration import Configuration -from .default_configuration import DefaultConfiguration -from .dict_configuration import DictConfiguration -from .merged_configuration import MergedConfiguration -from .package_configuration import PackageConfiguration -from .yaml_configuration import YamlConfiguration - -__all__ = ["AutoConfiguration"] - - -class AutoConfiguration(Configuration): - def __init__(self, nerdd_module): - super().__init__() - - # get the class of the nerdd module, e.g. - nerdd_module_class = nerdd_module.__class__ - - # get the module name of the nerdd module class - # e.g. "cypstrate.cypstrate_model" - python_module = nerdd_module_class.__module__ - - # get the root module name, e.g. "cypstrate" - root_module = python_module.split(".")[0] - - # collect configurations that are used - configs = [] - - # 1. module has a default configuration (containing default values) - configs.append(DefaultConfiguration(nerdd_module)) - - # 2. module can be configured via a yaml file - # 2.a search for nerdd.yml in the file tree - # start at the directory containing the file where nerdd_module_class is - # defined and go up the directory tree until nerdd.yml is found - leaf = sys.modules[nerdd_module_class.__module__].__file__ or "" - while True: - if os.path.isfile(os.path.join(leaf, "nerdd.yml")): - default_config_file = os.path.join(leaf, "nerdd.yml") - break - elif leaf == os.path.dirname(leaf): - default_config_file = None - break - leaf = os.path.dirname(leaf) - - if default_config_file is not None: - configs.append(YamlConfiguration(default_config_file)) - else: - # 2.b search for nerdd.yml in the package (submodule package_name.data) - data_module = f"{root_module}.data" - configs.append(PackageConfiguration(data_module)) - - self.delegate = MergedConfiguration(*configs) - - def _get_dict(self): - return self.delegate._get_dict() diff --git a/nerdd_module/config/configuration.py b/nerdd_module/config/configuration.py index db59d89..ff0358c 100644 --- a/nerdd_module/config/configuration.py +++ b/nerdd_module/config/configuration.py @@ -5,6 +5,10 @@ __all__ = ["Configuration"] +def get_property_columns_of_type(config, t) -> List[dict]: + return [c for c in config["result_properties"] if c.get("level", "molecule") == t] + + class Configuration(ABC): def __init__(self): pass @@ -13,7 +17,17 @@ def __init__(self): def get_dict(self) -> dict: config = self._get_dict() - # TODO: validate + if "result_properties" not in config: + config["result_properties"] = [] + + # check that a module can only predict atom or derivative properties, not both + num_atom_properties = len(get_property_columns_of_type(config, "atom")) + num_derivative_properties = len( + get_property_columns_of_type(config, "derivative") + ) + assert ( + num_atom_properties == 0 or num_derivative_properties == 0 + ), "A module can only predict atom or derivative properties, not both." return config @@ -21,31 +35,27 @@ def get_dict(self) -> dict: def _get_dict(self) -> dict: pass - def molecular_property_columns(self) -> List[str]: - return [ - c["name"] - for c in self["result_properties"] - if c.get("level", "molecule") == "molecule" - ] - - def atom_property_columns(self) -> List[str]: - return [ - c["name"] for c in self["result_properties"] if c.get("level") == "atom" - ] - - def derivative_property_columns(self) -> List[str]: - return [ - c["name"] - for c in self["result_properties"] - if c.get("level") == "derivative" - ] + def is_empty(self) -> bool: + return self.get_dict() == {} + + def molecular_property_columns(self) -> List[dict]: + return get_property_columns_of_type(self, "molecule") + + def atom_property_columns(self) -> List[dict]: + return get_property_columns_of_type(self, "atom") + + def derivative_property_columns(self) -> List[dict]: + return get_property_columns_of_type(self, "derivative") def get_task(self) -> str: + # if task is specified in the config, use that + config = self.get_dict() + if "task" in config: + return config["task"] + + # try to derive the task from the result_properties num_atom_properties = len(self.atom_property_columns()) num_derivative_properties = len(self.derivative_property_columns()) - assert ( - num_atom_properties == 0 or num_derivative_properties == 0 - ), "A module can only predict atom or derivative properties, not both." if num_atom_properties > 0: return "atom_property_prediction" diff --git a/nerdd_module/config/default_configuration.py b/nerdd_module/config/default_configuration.py index 0786182..ad168d2 100644 --- a/nerdd_module/config/default_configuration.py +++ b/nerdd_module/config/default_configuration.py @@ -1,15 +1,13 @@ from stringcase import snakecase from ..polyfills import version -from .configuration import Configuration +from .dict_configuration import DictConfiguration __all__ = ["DefaultConfiguration"] -class DefaultConfiguration(Configuration): +class DefaultConfiguration(DictConfiguration): def __init__(self, nerdd_module): - super().__init__() - # generate a name from the module name class_name = nerdd_module.__class__.__name__ if class_name.endswith("Model"): @@ -25,17 +23,15 @@ def __init__(self, nerdd_module): try: module = nerdd_module.__module__ root_module = module.split(".", 1)[0] - version_ = version(root_module) + package_version = version(root_module) except ModuleNotFoundError: - version_ = "0.0.1" + package_version = "0.0.1" - self.config = dict( + config = dict( name=name, - version=version_, - task="molecular_property_prediction", + version=package_version, job_parameters=[], result_properties=[], ) - def _get_dict(self): - return self.config + super().__init__(config) diff --git a/nerdd_module/config/dict_configuration.py b/nerdd_module/config/dict_configuration.py index f4c7222..50e9265 100644 --- a/nerdd_module/config/dict_configuration.py +++ b/nerdd_module/config/dict_configuration.py @@ -4,10 +4,9 @@ class DictConfiguration(Configuration): - def __init__(self, config): + def __init__(self, config: dict) -> None: super().__init__() + self._config = config - self.config = config - - def _get_dict(self): - return self.config + def _get_dict(self) -> dict: + return self._config diff --git a/nerdd_module/config/merged_configuration.py b/nerdd_module/config/merged_configuration.py index 0225c96..ed9728a 100644 --- a/nerdd_module/config/merged_configuration.py +++ b/nerdd_module/config/merged_configuration.py @@ -1,18 +1,44 @@ +from collections import Counter + from .configuration import Configuration +from .dict_configuration import DictConfiguration __all__ = ["MergedConfiguration"] -class MergedConfiguration(Configuration): - def __init__(self, *configs): - super().__init__() +def merge(*args): + assert len(args) > 0 + + first_entry = args[0] + assert all(isinstance(d, type(first_entry)) for d in args) + + if isinstance(first_entry, list): + return [e for d in args for e in d] + if isinstance(first_entry, dict): + count_fields = Counter([k for d in args for k in d.keys()]) - self.config = dict() + # merge fields that occur in multiple dicts + overlapping_fields = [k for k, v in count_fields.items() if v > 1] + merged_overlapping_fields = { + k: merge(*[d[k] for d in args if k in d]) for k in overlapping_fields + } + # collect fields that occur in only one dict + non_overlapping_fields = [k for k, v in count_fields.items() if v == 1] + merged_non_overlapping_fields = { + k: v for d in args for k, v in d.items() if k in non_overlapping_fields + } + + return { + **merged_non_overlapping_fields, + **merged_overlapping_fields, + } + else: # merge all configurations starting from the first one # --> last configuration has the highest priority - for c in configs: - self.config.update(c._get_dict()) + return args[-1] + - def _get_dict(self): - return self.config +class MergedConfiguration(DictConfiguration): + def __init__(self, *configs: Configuration): + super().__init__(merge(*[c.get_dict() for c in configs])) diff --git a/nerdd_module/config/search_yaml_configuration.py b/nerdd_module/config/search_yaml_configuration.py new file mode 100644 index 0000000..de56373 --- /dev/null +++ b/nerdd_module/config/search_yaml_configuration.py @@ -0,0 +1,32 @@ +import os +import sys +from typing import Any, Optional + +from .configuration import Configuration +from .dict_configuration import DictConfiguration +from .yaml_configuration import YamlConfiguration + + +class SearchYamlConfiguration(DictConfiguration): + def __init__(self, start: str, base_path: Optional[str] = None) -> None: + # provide a default configuration if no configuration file is found + config: Configuration = DictConfiguration({}) + + if start is not None: + # start at the directory containing the file where nerdd_module_class is + # defined and go up the directory tree until nerdd.yml is found (or root is + # reached) + leaf = start + while True: + if os.path.isfile(os.path.join(leaf, "nerdd.yml")): + default_config_file = os.path.join(leaf, "nerdd.yml") + break + elif leaf == os.path.dirname(leaf): # reached root + default_config_file = None + break + leaf = os.path.dirname(leaf) + + if default_config_file is not None: + config = YamlConfiguration(default_config_file, base_path) + + super().__init__(config.get_dict()) diff --git a/nerdd_module/config/yaml_configuration.py b/nerdd_module/config/yaml_configuration.py index 2753782..1d5fca4 100644 --- a/nerdd_module/config/yaml_configuration.py +++ b/nerdd_module/config/yaml_configuration.py @@ -1,6 +1,8 @@ import base64 import os import pathlib +from os import PathLike +from typing import Optional import filetype import yaml @@ -26,7 +28,7 @@ def image_constructor(loader, node): class YamlConfiguration(Configuration): - def __init__(self, handle, base_path=None): + def __init__(self, handle: PathLike, base_path: Optional[PathLike] = None) -> None: super().__init__() if base_path is None: From 7795295bc880433109983b2b56ca1eced2bcfbd3 Mon Sep 17 00:00:00 2001 From: Steffen Hirte Date: Wed, 25 Sep 2024 14:45:15 +0200 Subject: [PATCH 04/11] Refactor input classes * introduce convenient Explorer._read method * add another criterion to find the best reader in DepthFirstExplorer * change type Generator[..., None, None] to Iterator[...] * make ReaderRegistry aware of potential constructor arguments in individual reader classes --- nerdd_module/input/__init__.py | 1 + nerdd_module/input/depth_first_explorer.py | 70 +++++++++++++--------- nerdd_module/input/explorer.py | 9 ++- nerdd_module/input/file_reader.py | 9 +-- nerdd_module/input/gzip_reader.py | 4 +- nerdd_module/input/inchi_reader.py | 4 +- nerdd_module/input/list_reader.py | 4 +- nerdd_module/input/mol_reader.py | 4 +- nerdd_module/input/reader.py | 4 +- nerdd_module/input/reader_registry.py | 67 +++++++-------------- nerdd_module/input/sdf_reader.py | 4 +- nerdd_module/input/smiles_reader.py | 11 +++- nerdd_module/input/string_reader.py | 4 +- nerdd_module/input/tar_reader.py | 4 +- nerdd_module/input/zip_reader.py | 4 +- 15 files changed, 103 insertions(+), 100 deletions(-) diff --git a/nerdd_module/input/__init__.py b/nerdd_module/input/__init__.py index f4180a2..2f4ca7e 100644 --- a/nerdd_module/input/__init__.py +++ b/nerdd_module/input/__init__.py @@ -1,4 +1,5 @@ from .depth_first_explorer import * +from .explorer import * from .file_reader import * from .gzip_reader import * from .inchi_reader import * diff --git a/nerdd_module/input/depth_first_explorer.py b/nerdd_module/input/depth_first_explorer.py index 64d0e92..6d39489 100644 --- a/nerdd_module/input/depth_first_explorer.py +++ b/nerdd_module/input/depth_first_explorer.py @@ -1,5 +1,5 @@ from itertools import chain, islice, repeat -from typing import Generator, Iterable, Optional +from typing import Iterable, Iterator, Optional from .explorer import Explorer from .reader import MoleculeEntry, Problem, Reader @@ -12,7 +12,7 @@ class InvalidInputReader(Reader): def __init__(self): super().__init__() - def read(self, input, explore) -> Generator[MoleculeEntry, None, None]: + def read(self, input, explore) -> Iterator[MoleculeEntry]: yield MoleculeEntry( raw_input=input, input_type="unknown", @@ -36,31 +36,31 @@ def __init__( super().__init__() if readers is None: - self.reader_registry = ReaderRegistry() + self._reader_registry = list(ReaderRegistry().get_readers()) else: - self.reader_registry = readers + self._reader_registry = list(readers) - self.num_test_entries = num_test_entries - self.threshold = threshold - self.state_stack = [self.empty_state()] - self.maximum_depth = maximum_depth + self._num_test_entries = num_test_entries + self._threshold = threshold + self._state_stack = [self._empty_state()] + self._maximum_depth = maximum_depth - def empty_state(self): + def _empty_state(self): return dict(first_guess=[]) - def explore(self, input) -> Generator[MoleculeEntry, None, None]: + def explore(self, input) -> Iterator[MoleculeEntry]: # create a new child node and set it as the current node - state = self.empty_state() - parent = self.state_stack[-1] - self.state_stack.append(state) + state = self._empty_state() + parent = self._state_stack[-1] + self._state_stack.append(state) - depth = len(self.state_stack) - if depth > self.maximum_depth: - raise ValueError(f"Maximum depth of {self.maximum_depth} reached") + depth = len(self._state_stack) + if depth > self._maximum_depth: + raise ValueError(f"Maximum depth of {self._maximum_depth} reached") readers_iter = chain( zip(parent["first_guess"], repeat("guess")), - zip(self.reader_registry, repeat("builtin")), + zip(self._reader_registry, repeat("builtin")), ) # try all readers and take a sample of the first num_test_entries @@ -69,40 +69,56 @@ def explore(self, input) -> Generator[MoleculeEntry, None, None]: best_mode = None best_score = 0 best_ratio = 0.0 + best_num_results = 0 generator = None sample = [] for reader, mode in readers_iter: try: # read at most num_test_entries entries - generator = reader.read(input, self.explore) - sample = list(islice(generator, self.num_test_entries)) + generator = self._read(reader, input) + sample = list(islice(generator, self._num_test_entries)) valid_entries = [entry for entry in sample if entry.mol is not None] score = len(valid_entries) ratio = len(valid_entries) / len(sample) - - if score > best_score or (score == best_score and ratio > best_ratio): + num_results = len(sample) + + if ( + score > best_score + # if the score is the same, prefer the reader with higher ratio + # of valid entries + or (score == best_score and ratio > best_ratio) + # if the ratio is the same, prefer the reader with more results + # (e.g. list with 10 x None is better than one invalid entry) + or ( + score == best_score + and ratio == best_ratio + and num_results > best_num_results + ) + ): best_reader = reader best_mode = mode best_score = score best_ratio = ratio + best_num_results = num_results - if score == self.num_test_entries: + if score == self._num_test_entries: break except Exception: pass # clean up tree - while len(self.state_stack) > depth: - self.state_stack.pop() + while len(self._state_stack) > depth: + self._state_stack.pop() generator = None if generator is None: if best_reader is None: - generator = InvalidInputReader().read(input, self.explore) + generator = self._read(InvalidInputReader(), input) + sample = [] else: - generator = best_reader.read(input, self.explore) - sample = list(islice(generator, self.num_test_entries)) + generator = self._read(best_reader, input) + sample = list(islice(generator, self._num_test_entries)) else: if best_mode is not None and best_mode != "guess": parent["first_guess"].append(best_reader) diff --git a/nerdd_module/input/explorer.py b/nerdd_module/input/explorer.py index 9b83420..2e93de9 100644 --- a/nerdd_module/input/explorer.py +++ b/nerdd_module/input/explorer.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod -from typing import Generator +from typing import Iterator -from .reader import MoleculeEntry +from .reader import MoleculeEntry, Reader class Explorer(ABC): @@ -9,5 +9,8 @@ def __init__(self): pass @abstractmethod - def explore(self, input) -> Generator[MoleculeEntry, None, None]: + def explore(self, input) -> Iterator[MoleculeEntry]: pass + + def _read(self, reader: Reader, input) -> Iterator[MoleculeEntry]: + return reader.read(input, self.explore) diff --git a/nerdd_module/input/file_reader.py b/nerdd_module/input/file_reader.py index 64cc917..52b39a8 100644 --- a/nerdd_module/input/file_reader.py +++ b/nerdd_module/input/file_reader.py @@ -1,5 +1,6 @@ +from os import PathLike from pathlib import Path -from typing import Generator, Tuple +from typing import Iterator, Optional, Tuple, Union from .reader import MoleculeEntry, Reader from .reader_registry import register_reader @@ -7,15 +8,15 @@ __all__ = ["FileReader"] -@register_reader("data_dir") +@register_reader class FileReader(Reader): - def __init__(self, data_dir=None): + def __init__(self, data_dir: Union[str, PathLike, None] = None): super().__init__() self.data_dir = data_dir if self.data_dir is not None: self.data_dir = Path(self.data_dir) - def read(self, filename, explore) -> Generator[MoleculeEntry, None, None]: + def read(self, filename, explore) -> Iterator[MoleculeEntry]: assert isinstance(filename, str), "input must be a string" # convert filename to path diff --git a/nerdd_module/input/gzip_reader.py b/nerdd_module/input/gzip_reader.py index 78d5bb0..18cf8ee 100644 --- a/nerdd_module/input/gzip_reader.py +++ b/nerdd_module/input/gzip_reader.py @@ -1,5 +1,5 @@ import gzip -from typing import Generator +from typing import Iterator from .reader import MoleculeEntry, Reader from .reader_registry import register_reader @@ -12,7 +12,7 @@ class GzipReader(Reader): def __init__(self): super().__init__() - def read(self, input_stream, explore) -> Generator[MoleculeEntry, None, None]: + def read(self, input_stream, explore) -> Iterator[MoleculeEntry]: if not hasattr(input_stream, "read") or not hasattr(input_stream, "seek"): raise TypeError("input must be a stream-like object") diff --git a/nerdd_module/input/inchi_reader.py b/nerdd_module/input/inchi_reader.py index 5c45d4d..c713fb5 100644 --- a/nerdd_module/input/inchi_reader.py +++ b/nerdd_module/input/inchi_reader.py @@ -1,5 +1,5 @@ from codecs import getreader -from typing import Generator +from typing import Iterator from rdkit.Chem import MolFromInchi from rdkit.rdBase import BlockLogs @@ -18,7 +18,7 @@ class InchiReader(Reader): def __init__(self): super().__init__() - def read(self, input_stream, explore) -> Generator[MoleculeEntry, None, None]: + def read(self, input_stream, explore) -> Iterator[MoleculeEntry]: if not hasattr(input_stream, "read") or not hasattr(input_stream, "seek"): raise TypeError("input must be a stream-like object") diff --git a/nerdd_module/input/list_reader.py b/nerdd_module/input/list_reader.py index 0464a41..fee7e78 100644 --- a/nerdd_module/input/list_reader.py +++ b/nerdd_module/input/list_reader.py @@ -1,5 +1,5 @@ from io import BytesIO, StringIO -from typing import BinaryIO, Generator, Iterable +from typing import BinaryIO, Iterable, Iterator from .reader import MoleculeEntry, Reader from .reader_registry import register_reader @@ -12,7 +12,7 @@ class ListReader(Reader): def __init__(self): super().__init__() - def read(self, input_iterable, explore) -> Generator[MoleculeEntry, None, None]: + def read(self, input_iterable, explore) -> Iterator[MoleculeEntry]: assert isinstance(input_iterable, Iterable) and not isinstance( input_iterable, (str, bytes, BytesIO, StringIO, BinaryIO) ), f"input must be an iterable, but is {type(input_iterable)}" diff --git a/nerdd_module/input/mol_reader.py b/nerdd_module/input/mol_reader.py index 3248563..428485b 100644 --- a/nerdd_module/input/mol_reader.py +++ b/nerdd_module/input/mol_reader.py @@ -1,4 +1,4 @@ -from typing import Generator +from typing import Iterator from rdkit.Chem import Mol @@ -11,7 +11,7 @@ class MolReader(Reader): def __init__(self): super().__init__() - def read(self, mol, explore) -> Generator[MoleculeEntry, None, None]: + def read(self, mol, explore) -> Iterator[MoleculeEntry]: assert isinstance(mol, Mol) yield MoleculeEntry( raw_input=mol, diff --git a/nerdd_module/input/reader.py b/nerdd_module/input/reader.py index 2d4bd1e..c1f0eb2 100644 --- a/nerdd_module/input/reader.py +++ b/nerdd_module/input/reader.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Generator, List, NamedTuple, Optional, Tuple +from typing import Iterator, List, NamedTuple, Optional, Tuple from rdkit.Chem import Mol @@ -21,5 +21,5 @@ def __init__(self): super().__init__() @abstractmethod - def read(self, input, explore) -> Generator[MoleculeEntry, None, None]: + def read(self, input, explore) -> Iterator[MoleculeEntry]: pass diff --git a/nerdd_module/input/reader_registry.py b/nerdd_module/input/reader_registry.py index b7eda14..d03d206 100644 --- a/nerdd_module/input/reader_registry.py +++ b/nerdd_module/input/reader_registry.py @@ -1,27 +1,20 @@ -from functools import lru_cache -from typing import Dict, Generator, List, Tuple, Type +from functools import lru_cache, partial +from typing import Callable, Iterator, List, Type +from ..util import call_with_mappings, class_decorator from .reader import Reader __all__ = ["ReaderRegistry", "register_reader"] +ReaderFactory = Callable[[dict], Reader] + + # lru_cache makes the registry a singleton @lru_cache(maxsize=1) class ReaderRegistry: - def __init__(self): - self._factories: List[Tuple[Type[Reader], Tuple[str, ...], Dict[str, str]]] = [] - self._config = {} - - def _create_reader(self, ReaderClass: Type[Reader], *args, **kwargs) -> Reader: - # translate all args - args = tuple(self._config.get(arg, None) for arg in args) - # translate all kwargs - kwargs = { - k: self._config.get(v, None) for k, v in kwargs.items() if v in self._config - } - - return ReaderClass(*args, **kwargs) + def __init__(self) -> None: + self._factories: List[ReaderFactory] = [] def register(self, ReaderClass: Type[Reader], *args: str, **kwargs: str): assert issubclass(ReaderClass, Reader) @@ -29,36 +22,20 @@ def register(self, ReaderClass: Type[Reader], *args: str, **kwargs: str): assert all( [isinstance(k, str) and isinstance(v, str) for k, v in kwargs.items()] ) - self._factories.append((ReaderClass, args, kwargs)) - - def readers(self) -> Generator[Reader, None, None]: - for reader, args, kwargs in self._factories: - yield self._create_reader(reader, *args, **kwargs) - - def __iter__(self): - return iter(self.readers()) - - -def register_reader(*args, **kwargs): - def wrapper(cls, *args, **kwargs): - ReaderRegistry().register(cls, *args, **kwargs) - return cls + self._factories.append( + partial( + call_with_mappings, + ReaderClass, + args_mapping=args, + kwargs_mapping=kwargs, + ) + ) - # Case 1: first argument is a class - # --> decorator is used without arguments - # @register_reader - # class F: - # ... - if len(args) > 0 and isinstance(args[0], type): - return wrapper(args[0], *args[1:], **kwargs) + def get_readers(self, **kwargs) -> Iterator[Reader]: + for factory in self._factories: + yield factory(kwargs) - # Case 2: first argument is a not a class - # --> decorator is used with arguments - # @register_reader("blah") - # class F: - # ... - def inner(cls): - assert isinstance(cls, type), "Decorator must be used with a class" - return wrapper(cls, *args, **kwargs) - return inner +@class_decorator +def register_reader(cls, *args, **kwargs): + ReaderRegistry().register(cls, *args, **kwargs) diff --git a/nerdd_module/input/sdf_reader.py b/nerdd_module/input/sdf_reader.py index 9dbdd23..c79cc93 100644 --- a/nerdd_module/input/sdf_reader.py +++ b/nerdd_module/input/sdf_reader.py @@ -1,5 +1,5 @@ from codecs import getreader -from typing import Generator +from typing import Iterator from rdkit.Chem import MolFromMolBlock from rdkit.rdBase import BlockLogs @@ -19,7 +19,7 @@ def __init__(self, max_num_lines_mol_block: int = 10000): super().__init__() self.max_num_lines_mol_block = max_num_lines_mol_block - def read(self, input_stream, explore) -> Generator[MoleculeEntry, None, None]: + def read(self, input_stream, explore) -> Iterator[MoleculeEntry]: if not hasattr(input_stream, "read") or not hasattr(input_stream, "seek"): raise TypeError("input must be a stream-like object") diff --git a/nerdd_module/input/smiles_reader.py b/nerdd_module/input/smiles_reader.py index 762a566..7f14876 100644 --- a/nerdd_module/input/smiles_reader.py +++ b/nerdd_module/input/smiles_reader.py @@ -1,5 +1,5 @@ from codecs import getreader -from typing import Generator +from typing import Iterator from rdkit.Chem import MolFromSmiles from rdkit.rdBase import BlockLogs @@ -18,7 +18,7 @@ class SmilesReader(Reader): def __init__(self): super().__init__() - def read(self, input_stream, explore) -> Generator[MoleculeEntry, None, None]: + def read(self, input_stream, explore) -> Iterator[MoleculeEntry]: if not hasattr(input_stream, "read") or not hasattr(input_stream, "seek"): raise TypeError("input must be a stream-like object") @@ -43,7 +43,12 @@ def read(self, input_stream, explore) -> Generator[MoleculeEntry, None, None]: mol = None if mol is None: - errors = [Problem("invalid_smiles", "Invalid SMILES")] + display_line = line + if len(display_line) > 100: + display_line = display_line[:100] + "..." + errors = [ + Problem("invalid_smiles", f"Invalid SMILES {display_line}") + ] else: # old versions of RDKit do not parse the name # --> get name from smiles manually diff --git a/nerdd_module/input/string_reader.py b/nerdd_module/input/string_reader.py index 01b77d7..3c2d971 100644 --- a/nerdd_module/input/string_reader.py +++ b/nerdd_module/input/string_reader.py @@ -1,5 +1,5 @@ from io import BytesIO -from typing import Generator +from typing import Iterator from .reader import MoleculeEntry, Reader from .reader_registry import register_reader @@ -12,7 +12,7 @@ class StringReader(Reader): def __init__(self): super().__init__() - def read(self, input, explore) -> Generator[MoleculeEntry, None, None]: + def read(self, input, explore) -> Iterator[MoleculeEntry]: assert isinstance(input, str) with BytesIO(input.encode("utf-8")) as f: diff --git a/nerdd_module/input/tar_reader.py b/nerdd_module/input/tar_reader.py index 309946d..3696d1e 100644 --- a/nerdd_module/input/tar_reader.py +++ b/nerdd_module/input/tar_reader.py @@ -1,5 +1,5 @@ import tarfile -from typing import Generator +from typing import Iterator from .reader import MoleculeEntry, Reader from .reader_registry import register_reader @@ -12,7 +12,7 @@ class TarReader(Reader): def __init__(self): super().__init__() - def read(self, input_stream, explore) -> Generator[MoleculeEntry, None, None]: + def read(self, input_stream, explore) -> Iterator[MoleculeEntry]: if not hasattr(input_stream, "read") or not hasattr(input_stream, "seek"): raise TypeError("input must be a stream-like object") diff --git a/nerdd_module/input/zip_reader.py b/nerdd_module/input/zip_reader.py index 5e119c5..e31a7b9 100644 --- a/nerdd_module/input/zip_reader.py +++ b/nerdd_module/input/zip_reader.py @@ -1,5 +1,5 @@ import zipfile -from typing import Generator +from typing import Iterator from .reader import MoleculeEntry, Reader from .reader_registry import register_reader @@ -12,7 +12,7 @@ class ZipReader(Reader): def __init__(self): super().__init__() - def read(self, input_stream, explore) -> Generator[MoleculeEntry, None, None]: + def read(self, input_stream, explore) -> Iterator[MoleculeEntry]: if not hasattr(input_stream, "read") or not hasattr(input_stream, "seek"): raise TypeError("input must be a stream-like object") From 89ea00c2bcdf9e2d57488f700d0867197cb60ee4 Mon Sep 17 00:00:00 2001 From: Steffen Hirte Date: Wed, 25 Sep 2024 15:13:14 +0200 Subject: [PATCH 05/11] Refactor output writing modules * move file-based code from Writer to a new class FileWriter * implement new writer classes: IteratorWriter, PandasWriter, and RecordListWriter. * update the WriterRegistry to support registering and retrieving writers based on output format --- nerdd_module/output/__init__.py | 5 ++ nerdd_module/output/csv_writer.py | 20 +++--- nerdd_module/output/file_writer.py | 41 +++++++++++++ nerdd_module/output/iterator_writer.py | 13 ++++ nerdd_module/output/pandas_writer.py | 16 +++++ nerdd_module/output/record_list_writer.py | 13 ++++ nerdd_module/output/sdf_writer.py | 26 ++++---- nerdd_module/output/writer.py | 33 +--------- nerdd_module/output/writer_registry.py | 74 +++++++++++++---------- 9 files changed, 155 insertions(+), 86 deletions(-) create mode 100644 nerdd_module/output/file_writer.py create mode 100644 nerdd_module/output/iterator_writer.py create mode 100644 nerdd_module/output/pandas_writer.py create mode 100644 nerdd_module/output/record_list_writer.py diff --git a/nerdd_module/output/__init__.py b/nerdd_module/output/__init__.py index bee604a..6c80465 100644 --- a/nerdd_module/output/__init__.py +++ b/nerdd_module/output/__init__.py @@ -1 +1,6 @@ +from .csv_writer import * +from .iterator_writer import * +from .pandas_writer import * +from .record_list_writer import * +from .writer import * from .writer_registry import * diff --git a/nerdd_module/output/csv_writer.py b/nerdd_module/output/csv_writer.py index 86a48c5..8a07bd6 100644 --- a/nerdd_module/output/csv_writer.py +++ b/nerdd_module/output/csv_writer.py @@ -1,20 +1,21 @@ import csv from itertools import chain -from typing import Dict, Iterable +from typing import IO, Any, Dict, Iterable from rdkit.Chem import Mol, MolToSmiles -from .writer import Writer +from .file_writer import FileLike, FileWriter +from .writer_registry import register_writer +__all__ = ["CsvWriter"] -class CsvWriter(Writer): - def __init__(self): - super().__init__(writes_bytes=False) - def _output_type(self) -> str: - return "csv" +@register_writer("csv") +class CsvWriter(FileWriter): + def __init__(self, output_file: FileLike): + super().__init__(output_file, writes_bytes=False) - def _write(self, output, entries: Iterable[Dict]): + def _write(self, output: IO[Any], entries: Iterable[Dict]) -> None: entry_iter = iter(entries) # get the first entry to extract the fieldnames @@ -24,7 +25,4 @@ def _write(self, output, entries: Iterable[Dict]): # write header, first entry, and remaining entries writer.writeheader() for entry in chain([first_entry], entry_iter): - for key, value in entry.items(): - if isinstance(value, Mol): - entry[key] = MolToSmiles(value, canonical=False) writer.writerow(entry) diff --git a/nerdd_module/output/file_writer.py b/nerdd_module/output/file_writer.py new file mode 100644 index 0000000..2d53cb2 --- /dev/null +++ b/nerdd_module/output/file_writer.py @@ -0,0 +1,41 @@ +import codecs +from abc import ABC, abstractmethod +from pathlib import Path +from typing import IO, Any, BinaryIO, Iterable, List, TextIO, Union + +from .writer import Writer + +StreamWriter = codecs.getwriter("utf-8") + +__all__ = ["FileWriter", "FileLike"] + + +FileLike = Union[str, Path, TextIO, BinaryIO] + + +class FileWriter(Writer): + """Abstract class for writers.""" + + def __init__(self, output_file: FileLike, writes_bytes: bool = False): + self._output_file = output_file + self._writes_bytes = writes_bytes + + def write(self, entries: Iterable[dict]): + """Write entries to output.""" + if isinstance(self._output_file, (str, Path)): + mode = "wb" if self._writes_bytes else "w" + with open(self._output_file, mode) as f: + self._write(f, entries) + else: + self._write(self._output_file, entries) + self._output_file.flush() + + @abstractmethod + def _write(self, output: IO[Any], entries: Iterable[dict]) -> None: + """Write entries to output.""" + pass + + @property + def writes_bytes(self) -> bool: + """Whether the writer writes bytes.""" + return self._writes_bytes diff --git a/nerdd_module/output/iterator_writer.py b/nerdd_module/output/iterator_writer.py new file mode 100644 index 0000000..8548570 --- /dev/null +++ b/nerdd_module/output/iterator_writer.py @@ -0,0 +1,13 @@ +from .writer import Writer +from .writer_registry import register_writer + +__all__ = ["IteratorWriter"] + + +@register_writer("iterator") +class IteratorWriter(Writer): + def __init__(self) -> None: + pass + + def write(self, records): + return records diff --git a/nerdd_module/output/pandas_writer.py b/nerdd_module/output/pandas_writer.py new file mode 100644 index 0000000..f962ce4 --- /dev/null +++ b/nerdd_module/output/pandas_writer.py @@ -0,0 +1,16 @@ +import pandas as pd + +from .writer import Writer +from .writer_registry import register_writer + +__all__ = ["PandasWriter"] + + +@register_writer("pandas") +class PandasWriter(Writer): + def __init__(self) -> None: + pass + + def write(self, records): + df = pd.DataFrame(records) + return df diff --git a/nerdd_module/output/record_list_writer.py b/nerdd_module/output/record_list_writer.py new file mode 100644 index 0000000..f279d8c --- /dev/null +++ b/nerdd_module/output/record_list_writer.py @@ -0,0 +1,13 @@ +from .writer import Writer +from .writer_registry import register_writer + +__all__ = ["RecordListWriter"] + + +@register_writer("record_list") +class RecordListWriter(Writer): + def __init__(self) -> None: + pass + + def write(self, records): + return list(records) diff --git a/nerdd_module/output/sdf_writer.py b/nerdd_module/output/sdf_writer.py index 6ed74cd..e9d93a5 100644 --- a/nerdd_module/output/sdf_writer.py +++ b/nerdd_module/output/sdf_writer.py @@ -1,18 +1,19 @@ -from typing import BinaryIO, Dict, Iterable, TextIO, Union +from typing import IO, Any, Dict, Iterable -from rdkit.Chem import Mol, MolToSmiles, SDWriter +from rdkit.Chem import SDWriter -from .writer import Writer +from .file_writer import FileLike, FileWriter +from .writer_registry import register_writer +__all__ = ["SdfWriter"] -class SdfWriter(Writer): - def __init__(self): - super().__init__(writes_bytes=False) - def _output_type(self) -> str: - return "sdf" +@register_writer("sdf") +class SdfWriter(FileWriter): + def __init__(self, output_file: FileLike) -> None: + super().__init__(output_file, writes_bytes=False) - def _write(self, output, entries: Iterable[Dict]): + def _write(self, output: IO[Any], entries: Iterable[Dict]) -> None: writer = SDWriter(output) try: for entry in entries: @@ -21,13 +22,12 @@ def _write(self, output, entries: Iterable[Dict]): # write (almost) all properties to the mol object for key, value in entry.items(): - if isinstance(value, Mol): - value = MolToSmiles(value) - elif isinstance(value, str) and "\n" in value: + value_as_str = str(value) + if "\n" in value_as_str: # SDF can't write multi-line strings continue - mol.SetProp(key, str(value)) + mol.SetProp(key, value_as_str) # write molecule writer.write(mol) diff --git a/nerdd_module/output/writer.py b/nerdd_module/output/writer.py index 91d222d..8dafb3f 100644 --- a/nerdd_module/output/writer.py +++ b/nerdd_module/output/writer.py @@ -1,7 +1,6 @@ import codecs from abc import ABC, abstractmethod -from io import BufferedWriter, TextIOWrapper -from typing import BinaryIO, Dict, Iterable, TextIO, Union +from typing import Any, Iterable StreamWriter = codecs.getwriter("utf-8") @@ -11,35 +10,9 @@ class Writer(ABC): """Abstract class for writers.""" - def __init__(self, writes_bytes: bool = False): - self._writes_bytes = writes_bytes - - @property - def output_type(self) -> str: - """The output type of the writer.""" - return self._output_type() - - @abstractmethod - def _output_type(self) -> str: - """The output type of the writer.""" + def __init__(self): pass - def write(self, output, entries: Iterable[Dict]): - """Write entries to output.""" - if isinstance(output, str): - mode = "wb" if self._writes_bytes else "w" - with open(output, mode) as f: - self._write(f, entries) - else: - self._write(output, entries) - output.flush() - @abstractmethod - def _write(self, output, entries: Iterable[Dict]): - """Write entries to output.""" + def write(self, records: Iterable[dict]) -> Any: pass - - @property - def writes_bytes(self) -> bool: - """Whether the writer writes bytes.""" - return self._writes_bytes diff --git a/nerdd_module/output/writer_registry.py b/nerdd_module/output/writer_registry.py index 5d32436..ea5fc3b 100644 --- a/nerdd_module/output/writer_registry.py +++ b/nerdd_module/output/writer_registry.py @@ -1,40 +1,50 @@ -from functools import lru_cache +from functools import lru_cache, partial +from typing import Callable, Dict, Type -from .csv_writer import CsvWriter -from .sdf_writer import SdfWriter +from ..util import call_with_mappings, class_decorator from .writer import Writer -__all__ = ["WriterRegistry"] +__all__ = [ + "WriterRegistry", + "register_writer", +] + + +WriterFactory = Callable[[dict], Writer] # lru_cache makes the registry a singleton @lru_cache(maxsize=1) class WriterRegistry: - def __init__(self): - self._writers = [] - - def register(self, writer: Writer): - self._writers.append(writer) - - def get_writer(self, output_type: str) -> Writer: - for writer in self._writers: - if writer.output_type == output_type: - return writer - - raise ValueError(f"Unsupported output type: {output_type}") - - @property - def supported_formats(self) -> frozenset: - return frozenset([writer.output_type for writer in self._writers]) - - @property - def writers(self): - return frozenset(self._writers) - - def __iter__(self): - return iter(self._writers) - - -registry = WriterRegistry() -registry.register(CsvWriter()) -registry.register(SdfWriter()) + def __init__(self) -> None: + self._factories: Dict[str, WriterFactory] = {} + + def register( + self, + output_format: str, + WriterClass: Type[Writer], + *args: str, + **kwargs: str, + ): + assert issubclass(WriterClass, Writer) + assert all([isinstance(arg, str) for arg in args]) + assert all( + [isinstance(k, str) and isinstance(v, str) for k, v in kwargs.items()] + ) + + self._factories[output_format] = partial( + call_with_mappings, WriterClass, args_mapping=args, kwargs_mapping=kwargs + ) + + def get_writer(self, output_format: str, **kwargs) -> Writer: + if output_format not in self._factories: + raise ValueError(f"Unknown output format: {output_format}") + return self._factories[output_format](kwargs) + + def get_output_formats(self) -> frozenset: + return frozenset(self._factories.keys()) + + +@class_decorator +def register_writer(cls: Type[Writer], output_format: str, *args, **kwargs): + WriterRegistry().register(output_format, cls, *args, **kwargs) From be11962b48da5369d2bf8886ab874f47c2bedbac Mon Sep 17 00:00:00 2001 From: Steffen Hirte Date: Wed, 25 Sep 2024 15:20:42 +0200 Subject: [PATCH 06/11] Adapt preprocessing pipelines to dict streaming code * implement special case of filtering "H" in FilterByElement * remove pipeline classes (they are replaced by list of preprocessing steps) * remove pipeline registry since it is not needed anymore * move code to Step / MapStep classes --- nerdd_module/preprocessing/__init__.py | 5 +- .../preprocessing/check_valid_smiles.py | 12 +-- .../chembl_structure_pipeline.py | 79 +++++++------------ nerdd_module/preprocessing/empty_pipeline.py | 8 -- .../preprocessing/filter_by_element.py | 44 +++++++++-- .../preprocessing/filter_by_weight.py | 17 ++-- nerdd_module/preprocessing/pipeline.py | 53 ------------- .../preprocessing/preprocessing_step.py | 44 +++++++++++ nerdd_module/preprocessing/registry.py | 20 ----- .../preprocessing/remove_stereochemistry.py | 16 ++-- nerdd_module/preprocessing/sanitize.py | 6 +- nerdd_module/preprocessing/step.py | 26 ------ 12 files changed, 136 insertions(+), 194 deletions(-) delete mode 100644 nerdd_module/preprocessing/empty_pipeline.py delete mode 100644 nerdd_module/preprocessing/pipeline.py create mode 100644 nerdd_module/preprocessing/preprocessing_step.py delete mode 100644 nerdd_module/preprocessing/registry.py delete mode 100644 nerdd_module/preprocessing/step.py diff --git a/nerdd_module/preprocessing/__init__.py b/nerdd_module/preprocessing/__init__.py index d5e107e..ec1d6ce 100644 --- a/nerdd_module/preprocessing/__init__.py +++ b/nerdd_module/preprocessing/__init__.py @@ -1,10 +1,7 @@ from .check_valid_smiles import * from .chembl_structure_pipeline import * -from .empty_pipeline import * from .filter_by_element import * from .filter_by_weight import * -from .pipeline import * -from .registry import * +from .preprocessing_step import * from .remove_stereochemistry import * from .sanitize import * -from .step import * diff --git a/nerdd_module/preprocessing/check_valid_smiles.py b/nerdd_module/preprocessing/check_valid_smiles.py index e1a60b3..61bedaf 100644 --- a/nerdd_module/preprocessing/check_valid_smiles.py +++ b/nerdd_module/preprocessing/check_valid_smiles.py @@ -3,24 +3,24 @@ from rdkit.Chem import Mol, MolFromSmiles, MolToSmiles from ..problem import InvalidSmiles, Problem -from .step import Step +from .preprocessing_step import PreprocessingStep __all__ = ["CheckValidSmiles"] -class CheckValidSmiles(Step): +class CheckValidSmiles(PreprocessingStep): """Checks if the molecule can be converted to SMILES and back.""" def __init__(self): super().__init__() - def _run(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: - errors = [] + def _preprocess(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: + problems = [] smi = MolToSmiles(mol, True) check_mol = MolFromSmiles(smi) if check_mol is None: - errors.append(InvalidSmiles()) + problems.append(InvalidSmiles()) mol = None - return mol, errors + return mol, problems diff --git a/nerdd_module/preprocessing/chembl_structure_pipeline.py b/nerdd_module/preprocessing/chembl_structure_pipeline.py index a4b642e..20ab1f9 100644 --- a/nerdd_module/preprocessing/chembl_structure_pipeline.py +++ b/nerdd_module/preprocessing/chembl_structure_pipeline.py @@ -5,12 +5,7 @@ from rdkit.rdBase import BlockLogs from ..problem import Problem -from .check_valid_smiles import CheckValidSmiles -from .filter_by_element import FilterByElement -from .filter_by_weight import FilterByWeight -from .pipeline import Pipeline -from .remove_stereochemistry import RemoveStereochemistry -from .step import Step +from .preprocessing_step import PreprocessingStep # before importing chembl_structure_pipeline, we need to suppress RDKit warnings warnings.filterwarnings( @@ -31,10 +26,10 @@ # --> this allows to use the rest of the package without chembl_structure_pipeline import_error = e -__all__ = ["ChemblStructurePipeline", "GetParentMol", "StandardizeWithCsp"] +__all__ = ["GetParentMol", "StandardizeWithCsp"] -class StandardizeWithCsp(Step): +class StandardizeWithCsp(PreprocessingStep): def __init__(self): super().__init__() @@ -58,7 +53,7 @@ def _run(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: return preprocessed_mol, errors -class GetParentMol(Step): +class GetParentMol(PreprocessingStep): def __init__(self): super().__init__() @@ -82,43 +77,29 @@ def _run(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: return preprocessed_mol, errors -class ChemblStructurePipeline(Pipeline): - def __init__( - self, - min_weight=150, - max_weight=1500, - allowed_elements=[ - "H", - "B", - "C", - "N", - "O", - "F", - "Si", - "P", - "S", - "Cl", - "Se", - "Br", - "I", - ], - remove_stereochemistry=False, - remove_invalid_molecules=False, - ): - super().__init__( - steps=[ - StandardizeWithCsp(), - FilterByWeight( - min_weight=min_weight, - max_weight=max_weight, - remove_invalid_molecules=remove_invalid_molecules, - ), - FilterByElement( - allowed_elements, remove_invalid_molecules=remove_invalid_molecules - ), - GetParentMol(), - ] - + ([RemoveStereochemistry()] if remove_stereochemistry else []) - + [CheckValidSmiles()], - name="chembl_structure_pipeline", - ) +# class ChemblStructurePipeline(Pipeline): +# def __init__( +# self, +# min_weight=150, +# max_weight=1500, +# allowed_elements=, +# remove_stereochemistry=False, +# remove_invalid_molecules=False, +# ): +# super().__init__( +# steps=[ +# StandardizeWithCsp(), +# FilterByWeight( +# min_weight=min_weight, +# max_weight=max_weight, +# remove_invalid_molecules=remove_invalid_molecules, +# ), +# FilterByElement( +# allowed_elements, remove_invalid_molecules=remove_invalid_molecules +# ), +# GetParentMol(), +# ] +# + ([RemoveStereochemistry()] if remove_stereochemistry else []) +# + [CheckValidSmiles()], +# name="chembl_structure_pipeline", +# ) diff --git a/nerdd_module/preprocessing/empty_pipeline.py b/nerdd_module/preprocessing/empty_pipeline.py deleted file mode 100644 index 29d47c4..0000000 --- a/nerdd_module/preprocessing/empty_pipeline.py +++ /dev/null @@ -1,8 +0,0 @@ -from .pipeline import Pipeline - -__all__ = ["EmptyPipeline"] - - -class EmptyPipeline(Pipeline): - def __init__(self): - super().__init__(steps=[], name="no_preprocessing") diff --git a/nerdd_module/preprocessing/filter_by_element.py b/nerdd_module/preprocessing/filter_by_element.py index 0e8de12..daa9662 100644 --- a/nerdd_module/preprocessing/filter_by_element.py +++ b/nerdd_module/preprocessing/filter_by_element.py @@ -3,24 +3,52 @@ from rdkit.Chem import Mol from ..problem import Problem -from .step import Step +from .preprocessing_step import PreprocessingStep +__all__ = ["FilterByElement", "ORGANIC_SUBSET"] -class FilterByElement(Step): +ORGANIC_SUBSET = [ + "H", + "B", + "C", + "N", + "O", + "F", + "Si", + "P", + "S", + "Cl", + "Se", + "Br", + "I", +] + + +class FilterByElement(PreprocessingStep): def __init__( self, allowed_elements: Iterable[str], remove_invalid_molecules: bool = False ): super().__init__() - self.allowed_elements = set(allowed_elements) + self.allowed_elements = set(a.upper() for a in allowed_elements) + self.hydrogen_in_allowed_elements = "H" in self.allowed_elements self.remove_invalid_molecules = remove_invalid_molecules - def _run(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: - errors = [] + def _preprocess(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: + problems = [] result_mol = mol elements = set(atom.GetSymbol() for atom in mol.GetAtoms()) invalid_elements = elements - self.allowed_elements - if len(elements - self.allowed_elements) > 0: + + # special case: hydrogens are not recognized by mol.GetAtoms() + if not self.hydrogen_in_allowed_elements: + # get the number of hydrogens in mol + for a in mol.GetAtoms(): + if a.GetTotalNumHs() > 0: + invalid_elements.add("H") + break + + if len(invalid_elements) > 0: if self.remove_invalid_molecules: result_mol = None @@ -29,11 +57,11 @@ def _run(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: else: invalid_elements_str = ", ".join(list(invalid_elements)) - errors.append( + problems.append( Problem( "invalid_elements", f"Molecule contains invalid elements {invalid_elements_str}", ) ) - return result_mol, errors + return result_mol, problems diff --git a/nerdd_module/preprocessing/filter_by_weight.py b/nerdd_module/preprocessing/filter_by_weight.py index d6f84e2..633b7b9 100644 --- a/nerdd_module/preprocessing/filter_by_weight.py +++ b/nerdd_module/preprocessing/filter_by_weight.py @@ -4,26 +4,25 @@ from rdkit.Chem.Descriptors import MolWt from ..problem import Problem -from .step import Step +from .preprocessing_step import PreprocessingStep -class FilterByWeight(Step): +class FilterByWeight(PreprocessingStep): def __init__(self, min_weight, max_weight, remove_invalid_molecules=False): super().__init__() self.min_weight = min_weight self.max_weight = max_weight self.remove_invalid_molecules = remove_invalid_molecules - def _run(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: - errors = [] + def _preprocess(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: + problems = [] + result_mol = mol weight = MolWt(mol) if weight < self.min_weight or weight > self.max_weight: if self.remove_invalid_molecules: result_mol = None - else: - result_mol = mol - errors.append( + problems.append( Problem( type="invalid_weight", message=( @@ -32,7 +31,5 @@ def _run(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: ), ) ) - else: - result_mol = mol - return result_mol, errors + return result_mol, problems diff --git a/nerdd_module/preprocessing/pipeline.py b/nerdd_module/preprocessing/pipeline.py deleted file mode 100644 index 0674519..0000000 --- a/nerdd_module/preprocessing/pipeline.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Iterable, List, Optional, Tuple - -from rdkit.Chem import Mol -from stringcase import snakecase - -from ..problem import Problem -from .step import Step - -__all__ = ["Pipeline", "make_pipeline"] - - -class Pipeline: - def __init__(self, steps: Iterable[Step], name: Optional[str] = None): - assert all(isinstance(step, Step) for step in steps) - self._steps = steps - self._name = name - - def run(self, mol: Mol) -> Tuple[Mol, List[Problem]]: - errors = [] - - if mol is None: - errors.append(Problem("no_molecule", "No molecule to process")) - - for step in self._steps: - if mol is None: - break - - mol, additional_errors = step.run(mol) - - errors.extend(additional_errors) - - return mol, errors - - def __call__(self, mol: Mol) -> Tuple[Mol, List[Problem]]: - return self.run(mol) - - @property - def name(self) -> str: - if self._name is None: - if type(self) is Pipeline: - # class is an instantiation of this class (no subclass) - # --> getting the name of this class raises an error - # (usually no problem except if the pipeline should be registered) - raise ValueError("Pipeline has no name") - else: - # class is deriving from Pipeline - # return type of subclass deriving from this class - return snakecase(type(self).__name__) - return self._name - - -def make_pipeline(*steps: Step): - return Pipeline(steps=steps) diff --git a/nerdd_module/preprocessing/preprocessing_step.py b/nerdd_module/preprocessing/preprocessing_step.py new file mode 100644 index 0000000..175cd95 --- /dev/null +++ b/nerdd_module/preprocessing/preprocessing_step.py @@ -0,0 +1,44 @@ +from abc import abstractmethod +from typing import Iterable, Iterator, List, Optional, Tuple, Union + +from rdkit.Chem import Mol + +from ..problem import Problem +from ..steps import MapStep + +__all__ = ["PreprocessingStep"] + + +class PreprocessingStep(MapStep): + def __init__(self): + super().__init__() + + def _process(self, record: dict) -> Union[dict, Iterable[dict], Iterator[dict]]: + # If "preprocessed_mol" is not present, then this is the first preprocessing + # step. + if "preprocessed_mol" not in record: + mol = record.get("input_mol") + record["preprocessed_mol"] = mol + + mol = record["preprocessed_mol"] + + # We don't preprocess invalid molecules. + if mol is None: + return record + + mol, problems = self._preprocess(mol) + record["preprocessed_mol"] = mol + + if "problems" in record: + record["problems"].extend(problems) + else: + record["problems"] = problems + + return record + + @abstractmethod + def _preprocess(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: + """ + Runs the preprocesing step on a molecule. + """ + pass diff --git a/nerdd_module/preprocessing/registry.py b/nerdd_module/preprocessing/registry.py deleted file mode 100644 index ece7afe..0000000 --- a/nerdd_module/preprocessing/registry.py +++ /dev/null @@ -1,20 +0,0 @@ -from .chembl_structure_pipeline import ChemblStructurePipeline -from .empty_pipeline import EmptyPipeline -from .pipeline import Pipeline - -__all__ = ["registry", "register_pipeline"] - -registry = {} - - -def register_pipeline(pipeline: Pipeline): - pipeline_name = pipeline.name - registry[pipeline_name] = pipeline - - -register_pipeline(EmptyPipeline()) - -try: - register_pipeline(ChemblStructurePipeline()) -except ImportError: - pass diff --git a/nerdd_module/preprocessing/remove_stereochemistry.py b/nerdd_module/preprocessing/remove_stereochemistry.py index fdb6bbd..ae4a19a 100644 --- a/nerdd_module/preprocessing/remove_stereochemistry.py +++ b/nerdd_module/preprocessing/remove_stereochemistry.py @@ -4,21 +4,23 @@ from rdkit.Chem import RemoveStereochemistry as remove_stereochemistry from ..problem import Problem -from .step import Step +from .preprocessing_step import PreprocessingStep -class RemoveStereochemistry(Step): +class RemoveStereochemistry(PreprocessingStep): def __init__(self): super().__init__() - def _run(self, mol: Mol) -> Tuple[Mol, List[Problem]]: - errors = [] + def _preprocess(self, mol: Mol) -> Tuple[Mol, List[Problem]]: + problems = [] try: remove_stereochemistry(mol) except Exception: - errors.append( - Problem("remove_stereochemistry", "Cannot remove stereochemistry") + problems.append( + Problem( + "remove_stereochemistry_failed", "Cannot remove stereochemistry" + ) ) - return mol, errors + return mol, problems diff --git a/nerdd_module/preprocessing/sanitize.py b/nerdd_module/preprocessing/sanitize.py index daf4e73..402f347 100644 --- a/nerdd_module/preprocessing/sanitize.py +++ b/nerdd_module/preprocessing/sanitize.py @@ -1,15 +1,15 @@ from rdkit.Chem import SanitizeMol -from .step import Step +from .preprocessing_step import PreprocessingStep __all__ = ["Sanitize"] -class Sanitize(Step): +class Sanitize(PreprocessingStep): def __init__(self): super().__init__() - def _run(self, mol): + def _preprocess(self, mol): errors = [] # sanitize molecule diff --git a/nerdd_module/preprocessing/step.py b/nerdd_module/preprocessing/step.py deleted file mode 100644 index 9d4448e..0000000 --- a/nerdd_module/preprocessing/step.py +++ /dev/null @@ -1,26 +0,0 @@ -from abc import ABC, abstractmethod -from typing import List, Optional, Tuple - -from rdkit.Chem import Mol - -from ..problem import Problem - -__all__ = ["Step"] - - -class Step(ABC): - def __init__(self): - pass - - def run(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: - """ - Runs the step on a molecule. - """ - return self._run(mol) - - @abstractmethod - def _run(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: - """ - Runs the step on a molecule. - """ - pass From 4f439251cb4683fd76232585319edbc394502bdd Mon Sep 17 00:00:00 2001 From: Steffen Hirte Date: Wed, 25 Sep 2024 15:21:44 +0200 Subject: [PATCH 07/11] Introduce new base class to implement prediction steps --- nerdd_module/steps/__init__.py | 2 ++ nerdd_module/steps/map_step.py | 38 ++++++++++++++++++++++++++++++++++ nerdd_module/steps/step.py | 27 ++++++++++++++++++++++++ 3 files changed, 67 insertions(+) create mode 100644 nerdd_module/steps/__init__.py create mode 100644 nerdd_module/steps/map_step.py create mode 100644 nerdd_module/steps/step.py diff --git a/nerdd_module/steps/__init__.py b/nerdd_module/steps/__init__.py new file mode 100644 index 0000000..1186113 --- /dev/null +++ b/nerdd_module/steps/__init__.py @@ -0,0 +1,2 @@ +from .map_step import * +from .step import * diff --git a/nerdd_module/steps/map_step.py b/nerdd_module/steps/map_step.py new file mode 100644 index 0000000..e430b88 --- /dev/null +++ b/nerdd_module/steps/map_step.py @@ -0,0 +1,38 @@ +from abc import abstractmethod +from typing import Iterable, Iterator, Optional, Union + +from .step import Step + +__all__ = ["MapStep"] + + +class MapStep(Step): + def __init__(self, is_source=False) -> None: + super().__init__(is_source=is_source) + + def _run(self, source: Iterator[dict]) -> Iterator[dict]: + # The _process method might return a single result or a list of results and we + # define a wrapper function to handle both cases. In the first case, we yield + # the result, in the second case we yield each element of the list. + def _wrapper(result): + if isinstance(result, dict): + yield result + elif hasattr(result, "__iter__"): + # this can be a list or a generator + yield from result + else: + # anything that is not a dict or an iterable / generator + yield result + + # If this transform has no source, then it is the first transform in the chain + # (i.e. it generates data without input). We call _process with the empty dict + # as input to start the generation process. + if self.is_source: + yield from _wrapper(self._process(dict())) + else: + for record in source: + yield from _wrapper(self._process(record)) + + @abstractmethod + def _process(self, record: dict) -> Union[dict, Iterable[dict], Iterator[dict]]: + pass diff --git a/nerdd_module/steps/step.py b/nerdd_module/steps/step.py new file mode 100644 index 0000000..d7485b7 --- /dev/null +++ b/nerdd_module/steps/step.py @@ -0,0 +1,27 @@ +from abc import ABC, abstractmethod +from typing import Iterator, Optional + +__all__ = ["Step"] + + +class Step(ABC): + def __init__(self, is_source=False) -> None: + self._is_source = is_source + + @property + def is_source(self) -> bool: + return self._is_source + + def __call__(self, source: Optional[Iterator[dict]] = None) -> Iterator[dict]: + assert self.is_source == ( + source is None + ), "No source was given and this step is not a source." + + if source is not None: + return self._run(source) + else: + return self._run(iter([])) + + @abstractmethod + def _run(self, source: Iterator[dict]) -> Iterator[dict]: + pass From 43cae68aa8bdf6eb044969509cfaf599ea3a06a1 Mon Sep 17 00:00:00 2001 From: Steffen Hirte Date: Wed, 25 Sep 2024 15:24:08 +0200 Subject: [PATCH 08/11] Add util submodule with convenient functions * move class_decorator in utils submodule * implement call_with_mappings that calls a method and assigns function arguments by name * implement method to retrieve the file of a given class instance --- nerdd_module/util/__init__.py | 3 ++ nerdd_module/util/call_with_mappings.py | 53 +++++++++++++++++++++++++ nerdd_module/util/class_decorator.py | 29 ++++++++++++++ nerdd_module/util/package.py | 24 +++++++++++ 4 files changed, 109 insertions(+) create mode 100644 nerdd_module/util/__init__.py create mode 100644 nerdd_module/util/call_with_mappings.py create mode 100644 nerdd_module/util/class_decorator.py create mode 100644 nerdd_module/util/package.py diff --git a/nerdd_module/util/__init__.py b/nerdd_module/util/__init__.py new file mode 100644 index 0000000..5ef6ad3 --- /dev/null +++ b/nerdd_module/util/__init__.py @@ -0,0 +1,3 @@ +from .call_with_mappings import * +from .class_decorator import * +from .package import * diff --git a/nerdd_module/util/call_with_mappings.py b/nerdd_module/util/call_with_mappings.py new file mode 100644 index 0000000..75d6259 --- /dev/null +++ b/nerdd_module/util/call_with_mappings.py @@ -0,0 +1,53 @@ +import inspect +from typing import Callable, Dict, Tuple, Type, TypeVar, Union + +__all__ = ["call_with_mappings"] + +T = TypeVar("T") + + +def call_with_mappings( + class_or_function: Union[Type[T], Callable[..., T]], + config: dict, + args_mapping: Tuple[str, ...] = (), + kwargs_mapping: Dict[str, str] = {}, +) -> T: + # translate all args + translated_args = tuple(config.get(arg) for arg in args_mapping) + # translate all kwargs + translated_kwargs = { + k: config["v"] for k, v in kwargs_mapping.items() if v in config + } + + # copy config + config = config.copy() + + # we check what arguments the constructor of the writer class can take + spec = inspect.getfullargspec(class_or_function) + parameter_names = [a for a in spec.args if a != "self"] + accept_any_args = spec.varargs is not None + accept_any_kwargs = spec.varkw is not None + + args = [] + if accept_any_args and len(parameter_names) == 0: + args = list(translated_args) + else: + for i, arg in enumerate(parameter_names): + if i < len(translated_args): + args.append(translated_args[i]) + elif arg in translated_kwargs: + args.append(translated_kwargs[arg]) + del translated_kwargs[arg] + elif arg in config: + args.append(config[arg]) + del config[arg] + elif i >= len(parameter_names) - len(spec.defaults or []): + pass + else: + raise ValueError(f"Missing required argument: {arg}") + + kwargs = {} + if accept_any_kwargs: + kwargs = config + + return class_or_function(*args, **kwargs) diff --git a/nerdd_module/util/class_decorator.py b/nerdd_module/util/class_decorator.py new file mode 100644 index 0000000..f775ee4 --- /dev/null +++ b/nerdd_module/util/class_decorator.py @@ -0,0 +1,29 @@ +__all__ = ["class_decorator"] + + +def class_decorator(f): + def result_decorator(*args, **kwargs): + # Case 1: first argument is a class -> decorator is used without arguments + # + # Example: + # + # @my_decorator + # class F: + # ... + if len(args) > 0 and isinstance(args[0], type): + return f(args[0], *args[1:], **kwargs) + + # Case 2: first argument is a not a class --> decorator is used with arguments + # + # Example: + # + # @my_decorator(42) + # class F: + # ... + def inner(cls): + assert isinstance(cls, type), "Decorator must be used with a class" + return f(cls, *args, **kwargs) + + return inner + + return result_decorator diff --git a/nerdd_module/util/package.py b/nerdd_module/util/package.py new file mode 100644 index 0000000..5599af7 --- /dev/null +++ b/nerdd_module/util/package.py @@ -0,0 +1,24 @@ +import sys + +__all__ = ["get_file_path_to_instance"] + + +def get_file_path_to_instance(instance): + # get the class of the provided class instance, e.g. + instance_class = instance.__class__ + + # get the module name of the class + # e.g. "cypstrate.cypstrate_model" + instance_module_name = instance_class.__module__ + + # get the file where the nerdd module class is defined + # e.g. "/path/to/cypstrate/cypstrate_model.py" + module = sys.modules[instance_module_name] + + if hasattr(module, "__file__") and module.__file__ is not None: + path = module.__file__ + else: + # if the module has no __file__ attribute, return None + path = None + + return path From e073129be0f58f6d476ead5fb179dbc229c1bcf1 Mon Sep 17 00:00:00 2001 From: Steffen Hirte Date: Wed, 25 Sep 2024 15:37:18 +0200 Subject: [PATCH 09/11] Adapt models to new dict streaming workflow * move code for assigning mol ids and names to separate classes (AssignMolId, AssignName) * move code that takes care of adding missing columns to a separate class (EnforceSchema) * let the base class Model handle prediction of molecules * add all additional convenience code to SimpleModel (configuration, input, output) * implement wrapper classes for input reading (ReadInput) and output writing (WriteOutput) --- nerdd_module/__init__.py | 8 +- nerdd_module/abstract_model.py | 271 --------------------------- nerdd_module/model/__init__.py | 5 + nerdd_module/model/assign_mol_id.py | 15 ++ nerdd_module/model/assign_name.py | 19 ++ nerdd_module/model/enforce_schema.py | 13 ++ nerdd_module/model/model.py | 187 ++++++++++++++++++ nerdd_module/model/read_input.py | 24 +++ nerdd_module/model/simple_model.py | 135 +++++++++++++ nerdd_module/model/write_output.py | 31 +++ 10 files changed, 433 insertions(+), 275 deletions(-) delete mode 100644 nerdd_module/abstract_model.py create mode 100644 nerdd_module/model/__init__.py create mode 100644 nerdd_module/model/assign_mol_id.py create mode 100644 nerdd_module/model/assign_name.py create mode 100644 nerdd_module/model/enforce_schema.py create mode 100644 nerdd_module/model/model.py create mode 100644 nerdd_module/model/read_input.py create mode 100644 nerdd_module/model/simple_model.py create mode 100644 nerdd_module/model/write_output.py diff --git a/nerdd_module/__init__.py b/nerdd_module/__init__.py index 53ed374..96912ba 100644 --- a/nerdd_module/__init__.py +++ b/nerdd_module/__init__.py @@ -1,10 +1,10 @@ -from .abstract_model import * from .cli import * -from .config import * +from .input import ReaderRegistry +from .model import * +from .output import WriterRegistry +from .polyfills import get_entry_points from .problem import * from .version import * -from .polyfills import get_entry_points - for entry_point in get_entry_points("nerdd-module.plugins"): entry_point.load() diff --git a/nerdd_module/abstract_model.py b/nerdd_module/abstract_model.py deleted file mode 100644 index 79b647f..0000000 --- a/nerdd_module/abstract_model.py +++ /dev/null @@ -1,271 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Callable, Iterable, List, Tuple, Union - -import pandas as pd -from rdkit.Chem import Mol - -from .config import AutoConfiguration, Configuration -from .input import DepthFirstExplorer, MoleculeEntry -from .preprocessing import Pipeline, Step, registry -from .problem import Problem, UnknownProblem - -__all__ = ["AbstractModel"] - - -class CustomPreprocessingStep(Step): - def __init__(self, fn: Callable[[Mol], Tuple[Mol, List[Problem]]]): - super().__init__() - self.fn = fn - - def _run(self, mol: Mol) -> Tuple[Mol, List[Problem]]: - return self.fn(mol) - - -class AbstractModel(ABC): - def __init__( - self, - preprocessing_pipeline: Union[str, Pipeline, Iterable[Step], None], - num_processes: int = 1, - ): - # - # preprocessing pipeline - # - if preprocessing_pipeline is None or preprocessing_pipeline == "custom": - self.preprocessing_pipeline = Pipeline( - steps=[CustomPreprocessingStep(self._preprocess_single_mol)] - ) - elif isinstance(preprocessing_pipeline, Pipeline): - self.preprocessing_pipeline = preprocessing_pipeline - elif isinstance(preprocessing_pipeline, str): - if preprocessing_pipeline in registry: - self.preprocessing_pipeline = registry[preprocessing_pipeline] - else: - raise ValueError( - "Invalid preprocessing pipeline. Choose one of the following: " - ", ".join(list(registry.keys()) + ["custom"]) - ) - elif isinstance(preprocessing_pipeline, Iterable) and all( - isinstance(step, Step) for step in preprocessing_pipeline - ): - # mypy assumes that preprocessing_pipeline might be a string (although we - # checked this case above) and complains about that when constructing the - # pipeline - # --> explicitly assert that preprocessing_pipeline is not a string - assert not isinstance(preprocessing_pipeline, str) - self.preprocessing_pipeline = Pipeline(steps=preprocessing_pipeline) - else: - raise ValueError( - f"Invalid preprocessing pipeline {preprocessing_pipeline}." - ) - - # - # reading molecules - # - - # add methods for all supported formats - # TODO - - # - # other parameters - # - self.num_processes = num_processes - - def _preprocess_single_mol(self, mol: Mol) -> Tuple[Mol, List[Problem]]: - # if this method is called, the preprocessing_pipeline was set to "custom" - # and this method has to be overwritten - raise NotImplementedError() - - @abstractmethod - def _predict_mols(self, mols: List[Mol], **kwargs) -> pd.DataFrame: - pass - - def _predict_entries( - self, - inputs: Iterable[MoleculeEntry], - **kwargs, - ) -> pd.DataFrame: - """ - 'preprocessed_mol', 'mol_id', 'input_mol', 'input_type', 'name', - 'input_smiles', 'preprocessed_smiles', 'atom_id', 'mass', 'errors', - 'input' - """ - # - # LOAD MOLECULES - # - df_load = pd.DataFrame( - inputs, - columns=["input", "input_type", "source", "mol", "load_errors"], - ) - df_load["mol_id"] = range(len(df_load)) - - # - # PREPROCESS ALL MOLECULES - # - df_preprocess = pd.DataFrame( - [self.preprocessing_pipeline.run(mol) for mol in df_load.mol], - columns=["preprocessed_mol", "preprocessing_errors"], - ) - - # necessary for models that create multiple (or zero) entries per molecule - df_preprocess["mol_id"] = range(len(df_preprocess)) - - # add raw molecules to dataframe - df_preprocess["input_mol"] = df_load.mol - - # add name to dataframe - df_preprocess["name"] = [ - (mol.GetProp("_Name") if mol is not None and mol.HasProp("_Name") else "") - for mol in df_preprocess.input_mol - ] - - # - # PREPARE PREDICTION OF MOLECULES - # - - # each molecule gets its unique id (0, 1, ..., n) as its name - for id, mol in zip(df_preprocess.mol_id, df_preprocess.preprocessed_mol): - if mol is not None: - mol.SetProp("_Name", str(id)) - - # do the prediction on molecules that are not None - df_valid_subset = df_preprocess[df_preprocess.preprocessed_mol.notnull()] - - # - # PREDICTION - # - df_predictions = self._predict_mols( - df_valid_subset.preprocessed_mol.tolist(), **kwargs - ) - - # - # POST PROCESSING AND ERROR HANDLING - # - - # make sure that reserved column names do not appear in the output dataframe - reserved_column_names = ["input", "name", "input_mol"] - assert ( - set(df_predictions.columns).intersection(reserved_column_names) == set() - ), f"Do not use reserved column names {', '.join(reserved_column_names)}!" - - # during prediction, molecules might have been removed / reordered - # there are three ways to connect the predictions to the original molecules: - # 1. df_prediction contains a column "mol_id" that contains the molecule ids - # 2. df_prediction contains a column "mol" that contains the molecules, which - # have the id as their name so that we can match them to the original - # 3. df_prediction has the same length as the number of valid molecules - # (and we assume that the order of the molecules is the same) - if "mol_id" in df_predictions.columns: - # check that mol_id contains only valid ids - assert set(df_predictions.mol_id).issubset(set(df_valid_subset.mol_id)), ( - f"The mol_id column contains invalid ids: " - f"{set(df_predictions.mol_id).difference(set(df_valid_subset.mol_id))}." - ) - - # use mol_id as index - df_predictions.set_index("mol_id", drop=True, inplace=True) - elif "mol" in df_predictions.columns: - # check that molecule names contain only valid ids - names = df_predictions.mol.apply(lambda mol: int(mol.GetProp("_Name"))) - assert set(names).issubset(set(df_preprocess.mol_id)), ( - f"The mol_id column contains invalid ids: " - f"{set(df_predictions.mol_id).difference(set(df_valid_subset.mol_id))}." - ) - - # use mol_id as index - df_predictions.set_index( - names, - inplace=True, - ) - df_predictions.drop(columns="mol", inplace=True) - else: - assert len(df_predictions) == len(df_valid_subset), ( - "The number of predicted molecules must be equal to the number of " - "valid input molecules." - ) - # use index from input series (type cast if series was empty) - df_predictions.set_index( - df_valid_subset.index.astype("int64"), inplace=True - ) - - # TODO: check derivative_id or atom_id - - # add column that indicates whether a molecule was missing - missing_mol_ids = set(df_preprocess.mol_id).difference(df_predictions.index) - df_preprocess["missing"] = df_preprocess.mol_id.isin(missing_mol_ids) - - # merge the preprocessed molecules with the predictions - df_result = df_preprocess.merge( - df_predictions, left_on="mol_id", right_index=True, how="left" - ) - - # if the result has multiple entries per mol_id, check that atom_id or - # derivative_id is present - if len(df_result) > df_result.mol_id.nunique(): - assert ( - "atom_id" in df_result.columns or "derivative_id" in df_result.columns - ), ( - "The result contains multiple entries per molecule, but does not " - "contain atom_id or derivative_id." - ) - - # merge errors from preprocessing and prediction - if "prediction_errors" in df_result.columns: - df_result["errors"] = ( - df_result.preprocessing_errors + df_result.prediction_errors - ) - df_result.drop(columns=["prediction_errors"], inplace=True) - else: - df_result["errors"] = df_result.preprocessing_errors - df_result["errors"] = df_result.errors + df_result.missing.map( - lambda x: [UnknownProblem()] if x else [] - ) - df_result.drop(columns=["missing", "preprocessing_errors"], inplace=True) - - # convert errors to string - if "errors" not in df_result.columns: - df_result["errors"] = [] - - # delete mol column (not needed anymore) - df_load.drop(columns=["mol"], inplace=True) - - # merge load and prediction - df_result = df_result.merge(df_load, on="mol_id", how="left") - - # merge errors from loading and prediction - df_result["errors"] = [ - load_errors + prediction_errors - for load_errors, prediction_errors in zip( - df_result.load_errors, df_result.errors - ) - ] - - df_result.drop(columns=["load_errors"], inplace=True) - - # reorder columns - mandatory_columns = [ - "mol_id", - "input", - "input_type", - "source", - "name", - "input_mol", - "preprocessed_mol", - "errors", - ] - remaining_columns = [c for c in df_result.columns if c not in mandatory_columns] - df_result = df_result[mandatory_columns + remaining_columns] - - return df_result - - def predict( - self, - inputs: Union[Iterable[str], Iterable[Mol], str, Mol], - input_type=None, - **kwargs, - ): - entries = DepthFirstExplorer().explore(inputs) - - return self._predict_entries(entries, **kwargs) - - def get_config(self) -> Configuration: - return AutoConfiguration(self) diff --git a/nerdd_module/model/__init__.py b/nerdd_module/model/__init__.py new file mode 100644 index 0000000..25e5854 --- /dev/null +++ b/nerdd_module/model/__init__.py @@ -0,0 +1,5 @@ +from .assign_mol_id import * +from .assign_name import * +from .model import * +from .read_input import * +from .simple_model import * diff --git a/nerdd_module/model/assign_mol_id.py b/nerdd_module/model/assign_mol_id.py new file mode 100644 index 0000000..edffbdd --- /dev/null +++ b/nerdd_module/model/assign_mol_id.py @@ -0,0 +1,15 @@ +from ..steps import Step + +__all__ = ["AssignMolId"] + + +class AssignMolId(Step): + def __init__(self): + super().__init__() + + def _run(self, source): + mol_id = 0 + for record in source: + record["mol_id"] = mol_id + mol_id += 1 + yield record diff --git a/nerdd_module/model/assign_name.py b/nerdd_module/model/assign_name.py new file mode 100644 index 0000000..d2d46ed --- /dev/null +++ b/nerdd_module/model/assign_name.py @@ -0,0 +1,19 @@ +from typing import Iterable, Iterator, Union + +from ..steps import MapStep + +__all__ = ["AssignName"] + + +class AssignName(MapStep): + def __init__(self): + super().__init__() + + def _process(self, record: dict) -> Union[dict, Iterable[dict], Iterator[dict]]: + mol = record.get("input_mol") + + record["name"] = ( + mol.GetProp("_Name") if mol is not None and mol.HasProp("_Name") else "" + ) + + return record diff --git a/nerdd_module/model/enforce_schema.py b/nerdd_module/model/enforce_schema.py new file mode 100644 index 0000000..fad4e10 --- /dev/null +++ b/nerdd_module/model/enforce_schema.py @@ -0,0 +1,13 @@ +from typing import Iterator, Optional + +from ..steps import Step + + +class EnforceSchema(Step): + def __init__(self, config): + super().__init__() + self._properties = [p["name"] for p in config["result_properties"]] + + def _run(self, source: Iterator[dict]) -> Iterator[dict]: + for record in source: + yield {k: record.get(k) for k in self._properties} diff --git a/nerdd_module/model/model.py b/nerdd_module/model/model.py new file mode 100644 index 0000000..e9be957 --- /dev/null +++ b/nerdd_module/model/model.py @@ -0,0 +1,187 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Iterator, List, Optional + +from rdkit.Chem import Mol + +from ..problem import UnknownProblem +from ..steps import Step +from ..util import call_with_mappings +from .write_output import WriteOutput + + +class Model(ABC): + def __init__(self): + super().__init__() + + @abstractmethod + def _predict_mols(self, mols: List[Mol], **kwargs) -> List[dict]: + pass + + @abstractmethod + def _get_input_steps( + self, input: Any, input_format: Optional[str], **kwargs + ) -> List[Step]: + pass + + @abstractmethod + def _get_output_steps(self, output_format: Optional[str], **kwargs) -> List[Step]: + pass + + def predict( + self, + input, + input_format=None, + output_format=None, + **kwargs, + ) -> Any: + input_steps = self._get_input_steps(input, input_format, **kwargs) + output_steps = self._get_output_steps(output_format, **kwargs) + + steps = [ + *input_steps, + PredictionStep(self, batch_size=1, **kwargs), + *output_steps, + ] + + # build the pipeline from the list of transforms + pipeline = None + for t in steps: + pipeline = t(pipeline) + + # the last pipeline step holds the result + last_step = steps[-1] + assert isinstance(last_step, WriteOutput) + return last_step.get_result() + + +class PredictionStep(Step): + def __init__(self, model: Model, batch_size: int, **kwargs): + super().__init__() + self.model = model + self.batch_size = batch_size + self.kwargs = kwargs + + def _run(self, source: Iterator[dict]) -> Iterator[dict]: + # We need to process the molecules in batches, because most ML models perform + # better when predicting multiple molecules at once. Additionally, we want to + # filter out molecules that could not be preprocessed. + def _batch_and_filter(source, n): + batch = [] + none_batch = [] + for record in source: + if record["preprocessed_mol"] is None: + none_batch.append(record) + else: + batch.append(record) + if len(batch) == n: + yield batch, none_batch + batch = [] + none_batch = [] + if len(batch) > 0 or len(none_batch) > 0: + yield batch, none_batch + + for batch, none_batch in _batch_and_filter(source, self.batch_size): + # return the records where mols are None + yield from none_batch + + # process the batch + yield from self._process_batch(batch) + + def _process_batch(self, batch: List[dict]) -> Iterator[dict]: + # each molecule gets a unique id (0, 1, ..., n) as its temporary id + mol_ids = [record["mol_id"] for record in batch] + mols = [record["preprocessed_mol"] for record in batch] + temporary_mol_ids = range(len(batch)) + for id, mol in zip(temporary_mol_ids, mols): + mol.SetProp("_TempId", str(id)) + + # do the actual prediction + predictions = call_with_mappings( + self.model._predict_mols, + {**self.kwargs, "mols": mols}, + ) + + # During prediction, molecules might have been removed / reordered. + # There are three ways to connect the predictions to the original molecules: + # 1. predictions have a key "mol_id" that contains the molecule ids + # 2. predictions have a key "mol" that contains the molecules that were passed + # to the _predict_mols method (they have a secret _TempId property that we + # can use for the matching) + # 3. the list of predictions has as many records as the batch (and we assume + # that the order of the molecules stayed the same) + if all("mol_id" in record for record in predictions): + pass + elif all("mol" in record for record in predictions): + # check that molecule names contain only valid ids + for record in predictions: + mol_id_from_mol = int(record["mol"].GetProp("_TempId")) + record["mol_id"] = mol_id_from_mol + + # we don't need the molecule anymore (we have it in the batch) + del record["mol"] + else: + assert len(predictions) == len(batch), ( + "The number of predicted molecules must be equal to the number of " + "valid input molecules." + ) + for i, record in enumerate(predictions): + record["mol_id"] = i + + # check that mol_id contains only valid ids + mol_id_set = set(temporary_mol_ids) + for record in predictions: + assert ( + record["mol_id"] in mol_id_set + ), f"The mol_id {record['mol_id']} is not in the batch." + + # create a mapping from mol_id to record (for quick access) + mol_id_to_record = defaultdict(list) + for record in predictions: + mol_id_to_record[record["mol_id"]].append(record) + + # add all records that are missing in the predictions + for mol_id, record in zip(temporary_mol_ids, batch): + if mol_id not in mol_id_to_record: + # notify the user that the molecule could not be predicted + record["problems"].append(UnknownProblem()) + + # add the record to the mapping + mol_id_to_record[mol_id].append(record) + + # If the result has multiple entries per mol_id, check that atom_id or + # derivative_id is present in multi-entry results. + if len(predictions) > len(batch): + for _, records in mol_id_to_record.items(): + if len(records) > 1: + has_atom_id = all("atom_id" in record for record in records) + has_derivative_id = all( + "derivative_id" in record for record in records + ) + assert has_atom_id or has_derivative_id, ( + "The result contains multiple entries per molecule, but does " + "not contain atom_id or derivative_id." + ) + + # TODO: check range and completeness of atom ids and derivative ids + + for key, records in mol_id_to_record.items(): + for record in records: + # merge the prediction with the original record + result = { + **batch[key], + **record, + } + + # remove the temporary id + result["preprocessed_mol"].ClearProp("_TempId") + + # add the original mol id + result["mol_id"] = mol_ids[key] + + # merge problems from preprocessing and prediction + preprocessing_problems = batch[key].get("problems", []) + prediction_problems = record.get("problems", []) + result["problems"] = preprocessing_problems + prediction_problems + + yield result diff --git a/nerdd_module/model/read_input.py b/nerdd_module/model/read_input.py new file mode 100644 index 0000000..2160a30 --- /dev/null +++ b/nerdd_module/model/read_input.py @@ -0,0 +1,24 @@ +from typing import Iterator + +from ..input.explorer import Explorer +from ..steps import Step + +__all__ = ["ReadInput"] + + +class ReadInput(Step): + def __init__(self, explorer: Explorer, input) -> None: + super().__init__(is_source=True) + self._explorer = explorer + self._input = input + + def _run(self, source: Iterator[dict]) -> Iterator[dict]: + for entry in self._explorer.explore(self._input): + record = dict( + raw_input=entry.raw_input, + source=entry.source, + input_type=entry.input_type, + input_mol=entry.mol, + problems=entry.errors, + ) + yield record diff --git a/nerdd_module/model/simple_model.py b/nerdd_module/model/simple_model.py new file mode 100644 index 0000000..a4703f5 --- /dev/null +++ b/nerdd_module/model/simple_model.py @@ -0,0 +1,135 @@ +import sys +from abc import abstractmethod +from typing import Any, Iterable, List, Optional, Tuple, Union + +from rdkit.Chem import Mol + +from ..config import ( + Configuration, + DefaultConfiguration, + DictConfiguration, + MergedConfiguration, + PackageConfiguration, + SearchYamlConfiguration, +) +from ..input import DepthFirstExplorer +from ..preprocessing import PreprocessingStep +from ..problem import Problem +from ..steps import Step +from ..util import get_file_path_to_instance +from .assign_mol_id import AssignMolId +from .assign_name import AssignName +from .enforce_schema import EnforceSchema +from .model import Model +from .read_input import ReadInput +from .write_output import WriteOutput + +__all__ = ["SimpleModel"] + + +class SimpleModel(Model): + def __init__(self, preprocessing_steps: Iterable[Step] = []): + super().__init__() + self._preprocessing_steps = preprocessing_steps + + def _get_input_steps( + self, input: Any, input_format: Optional[str], **kwargs + ) -> List[Step]: + return [ + ReadInput(DepthFirstExplorer(), input), + AssignMolId(), + AssignName(), + *self._preprocessing_steps, + # the following step ensures that the column preprocessed_mol is created + # (even is self._preprocessing_steps is empty) + CustomPreprocessingStep(self), + ] + + def _get_output_steps(self, output_format: Optional[str], **kwargs) -> List[Step]: + output_format = output_format or "pandas" + + return [ + EnforceSchema(self.get_config()), + WriteOutput(output_format, **kwargs), + ] + + def _preprocess(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: + return mol, [] + + @abstractmethod + def _predict_mols(self, mols: List[Mol], **kwargs) -> List[dict]: + pass + + def _get_config(self) -> Union[Configuration, dict]: + return {} + + def get_config(self) -> Configuration: + # get base configuration specified in this class + base_config = self._get_config() + if isinstance(base_config, dict): + base_config = DictConfiguration(base_config) + + # get the class of the nerdd module, e.g. + nerdd_module_class = self.__class__ + + # get the module name of the nerdd module class + # e.g. "cypstrate.cypstrate_model" + python_module = nerdd_module_class.__module__ + + # get the root module name, e.g. "cypstrate" + root_module = python_module.split(".")[0] + + configs = [ + DefaultConfiguration(self), + SearchYamlConfiguration(get_file_path_to_instance(self)), + PackageConfiguration(f"{root_module}.data"), + # base config comes last -> highest priority + base_config, + ] + + # add default properties mol_id, raw_input, etc. + task = MergedConfiguration(*configs).get_task() + + # check whether we need to add to add a property "atom_id" or "derivative_id" + task_based_property = [] + if task == "atom_property_prediction": + task_based_property = [ + {"name": "atom_id", "type": "integer"}, + ] + elif task == "derivative_property_prediction": + task_based_property = [ + {"name": "derivative_id", "type": "integer"}, + ] + + default_properties_start = [ + {"name": "mol_id", "type": "integer"}, + *task_based_property, + {"name": "raw_input", "type": "string"}, + {"name": "input_type", "type": "string"}, + {"name": "name", "type": "string"}, + {"name": "input_mol", "type": "mol"}, + {"name": "input_smiles", "type": "string"}, + {"name": "preprocessed_mol", "type": "mol"}, + {"name": "preprocessed_smiles", "type": "string"}, + ] + + default_properties_end = [ + {"name": "problems", "type": "problem_list"}, + ] + + configs = [ + DictConfiguration({"result_properties": default_properties_start}), + *configs, + DictConfiguration({"result_properties": default_properties_end}), + ] + + return MergedConfiguration(*configs) + + +class CustomPreprocessingStep(PreprocessingStep): + def __init__(self, model: SimpleModel): + super().__init__() + self.model = model + + def _preprocess(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]: + return self.model._preprocess(mol) diff --git a/nerdd_module/model/write_output.py b/nerdd_module/model/write_output.py new file mode 100644 index 0000000..90d0a34 --- /dev/null +++ b/nerdd_module/model/write_output.py @@ -0,0 +1,31 @@ +from typing import Iterator, Optional + +from ..output import WriterRegistry +from ..steps import Step + +__all__ = ["WriteOutput"] + + +class WriteOutput(Step): + def __init__(self, output_format: str, **kwargs) -> None: + super().__init__() + self._output_format = output_format + self._kawrgs = kwargs + self._source: Optional[Iterator[dict]] = None + + def get_result(self): + assert ( + self._source is not None + ), "No source data to write. You might need to run the pipeline first." + + # get the correct output writer + writer = WriterRegistry().get_writer(self._output_format, **self._kawrgs) + result = writer.write(self._source) + + return result + + def _run(self, source: Iterator[dict]) -> Iterator[dict]: + self._source = source + + # return an empty iterator to satisfy method return type + return iter([]) From 8cc9a519426044f9a45e25e1b8d96e304b21725b Mon Sep 17 00:00:00 2001 From: Steffen Hirte Date: Wed, 25 Sep 2024 15:38:59 +0200 Subject: [PATCH 10/11] Update testing code * merge different model classes into one class * merge all scenario test cases into one file --- nerdd_module/tests/checks.py | 91 +++++++----- nerdd_module/tests/models/AtomicMassModel.py | 43 ++++-- nerdd_module/tests/models/MolWeightModel.py | 30 +++- .../MolWeightModelWithExplicitMolIds.py | 30 ---- .../models/MolWeightModelWithExplicitMols.py | 27 ---- nerdd_module/tests/models/__init__.py | 2 - nerdd_module/tests/predictions.py | 24 +-- nerdd_module/tests/predictors.py | 38 +++-- nerdd_module/tests/representations.py | 6 +- .../features/atom_property_prediction.feature | 59 -------- .../depth_first_explorer.feature} | 16 +- .../models/atom_property_prediction.feature | 48 ++++++ .../molecule_property_prediction.feature | 55 +++++++ .../molecule_property_prediction.feature | 57 -------- tests/features/preprocessing.feature | 7 - .../preprocessing/filter_by_element.feature | 80 ++++++++++ tests/steps/__init__.py | 1 + tests/steps/checks.py | 49 ++++--- tests/steps/input.py | 114 +++++++++++++++ tests/steps/preprocessing.py | 55 ++++++- tests/test_atom_property_prediction.py | 65 --------- tests/test_features.py | 6 + tests/test_molecule_property_prediction.py | 60 -------- tests/test_preprocessing.py | 12 -- tests/test_reading_formats.py | 137 ------------------ 25 files changed, 524 insertions(+), 588 deletions(-) delete mode 100644 nerdd_module/tests/models/MolWeightModelWithExplicitMolIds.py delete mode 100644 nerdd_module/tests/models/MolWeightModelWithExplicitMols.py delete mode 100644 tests/features/atom_property_prediction.feature rename tests/features/{representations.feature => input/depth_first_explorer.feature} (89%) create mode 100644 tests/features/models/atom_property_prediction.feature create mode 100644 tests/features/models/molecule_property_prediction.feature delete mode 100644 tests/features/molecule_property_prediction.feature delete mode 100644 tests/features/preprocessing.feature create mode 100644 tests/features/preprocessing/filter_by_element.feature create mode 100644 tests/steps/input.py delete mode 100644 tests/test_atom_property_prediction.py create mode 100644 tests/test_features.py delete mode 100644 tests/test_molecule_property_prediction.py delete mode 100644 tests/test_preprocessing.py delete mode 100644 tests/test_reading_formats.py diff --git a/nerdd_module/tests/checks.py b/nerdd_module/tests/checks.py index 744fd6e..a52d57a 100644 --- a/nerdd_module/tests/checks.py +++ b/nerdd_module/tests/checks.py @@ -1,4 +1,3 @@ -import json from ast import literal_eval import numpy as np @@ -6,13 +5,12 @@ from pytest_bdd import parsers, then -@then(parsers.parse("The result should contain the columns:\n{column_names}")) -def check_result_columns(predictions, column_names): - column_names = column_names.strip() - for c in column_names.split("\n"): - assert ( - c in predictions.columns - ), f"Column {c} not in predictions {predictions.columns.tolist()}" +@then(parsers.parse("The result should contain the columns:\n{expected_column_names}")) +def check_result_columns(predictions, expected_column_names): + expected_column_names = expected_column_names.strip().split("\n") + for c in expected_column_names: + for record in predictions: + assert c in record, f"Column {c} not in record {list(record.keys())}" @then( @@ -35,16 +33,23 @@ def check_column_range(subset, column_name, low, high): else: high = float(high) - assert (low <= subset[column_name]).all() - assert (subset[column_name] <= high).all() + values = [record[column_name] for record in subset] + + assert all( + low <= v <= high for v in values + ), f"Column {column_name} is assigned to {values} not in [{low}, {high}]" -@then(parsers.parse("the value in column '{column_name}' should be '{expected_value}'")) -def check_column_value(subset, column_name, expected_value): +@then( + parsers.parse( + "the value in column '{column_name}' should be equal to {expected_value}" + ) +) +def check_column_value_equality(subset, column_name, expected_value): if len(subset) == 0: return - value = subset[column_name].iloc[0] + values = [record[column_name] for record in subset] # expected value is always provided as string # try to convert to float if possible @@ -53,14 +58,46 @@ def check_column_value(subset, column_name, expected_value): except: pass - if expected_value == "(none)": + if expected_value == "(none)" or expected_value is None: # if expected_value is the magic string "(none)", we expect None - assert pd.isnull(value), f"Column {column_name} is assigned to {value} != None" + assert all( + v is None for v in values + ), f"Column {column_name} is assigned to {values} != None" else: # otherwise, we expect the value to be equal to the expected value - assert ( - value == expected_value - ), f"Column {column_name} is assigned to {value} != {expected_value}" + assert all( + v == expected_value for v in values + ), f"Column {column_name} is assigned to {values} != {expected_value}" + + +@then( + parsers.parse( + "the value in column '{column_name}' should not be equal to {forbidden_value}" + ) +) +def check_column_value_inequality(subset, column_name, forbidden_value): + if len(subset) == 0: + return + + values = [record[column_name] for record in subset] + + # expected value is always provided as string + # try to convert to float if possible + try: + forbidden_value = literal_eval(forbidden_value) + except: + pass + + if forbidden_value == "(none)" or forbidden_value is None: + # if expected_value is the magic string "(none)", we expect None + assert all( + v is not None for v in values + ), f"Column {column_name} is assigned to {values} == None" + else: + # otherwise, we expect the value to be equal to the expected value + assert all( + v != forbidden_value for v in values + ), f"Column {column_name} is assigned to {values} == {forbidden_value}" @then( @@ -69,7 +106,7 @@ def check_column_value(subset, column_name, expected_value): ) ) def check_column_subset(subset, column_name, superset): - superset = set(json.loads(superset)) + superset = set(literal_eval(superset)) assert all( set(value).issubset(superset) for value in subset[column_name] @@ -78,7 +115,7 @@ def check_column_subset(subset, column_name, superset): @then(parsers.parse("the value in column '{column_name}' should be one of {superset}")) def check_column_membership(subset, column_name, superset): - superset = json.loads(superset) + superset = literal_eval(superset) assert isinstance( superset, list @@ -99,20 +136,6 @@ def check_png_image(subset, column_name): ).all(), f"Column {column_name} does not contain a PNG image" -@then( - parsers.parse("the value in column '{column_name}' should contain only '{value}'") -) -def check_column_membership_single(predictions, column_name, value): - if value == "(none)": - assert all( - pd.isnull(predictions[column_name]) - ), f"Column {column_name} must be none" - else: - assert all( - value in values for values in predictions[column_name] - ), f"Column {column_name} contains value {value}" - - @then( parsers.parse( "the value in column '{column_name}' should have type '{expected_type}'" diff --git a/nerdd_module/tests/models/AtomicMassModel.py b/nerdd_module/tests/models/AtomicMassModel.py index 3ce75ec..2b6d938 100644 --- a/nerdd_module/tests/models/AtomicMassModel.py +++ b/nerdd_module/tests/models/AtomicMassModel.py @@ -1,21 +1,40 @@ -import pandas as pd -from nerdd_module import AbstractModel +from nerdd_module import SimpleModel +from nerdd_module.preprocessing import Sanitize __all__ = ["AtomicMassModel"] -class AtomicMassModel(AbstractModel): - def __init__(self, preprocessing_pipeline="no_preprocessing", **kwargs): - super().__init__(preprocessing_pipeline, **kwargs) +class AtomicMassModel(SimpleModel): + def __init__(self, preprocessing_steps=[Sanitize()], version="mol_ids", **kwargs): + assert version in [ + "mol_ids", + "mols", + ], f"version must be one of 'mol_ids', or 'mols', but got {version}." + + super().__init__(preprocessing_steps, **kwargs) + self._version = version def _predict_mols(self, mols, multiplier): - return pd.DataFrame( - { - "mol": [m for m in mols for _ in m.GetAtoms()], - "atom_id": [a.GetIdx() for m in mols for a in m.GetAtoms()], - "mass": [a.GetMass() * multiplier for m in mols for a in m.GetAtoms()], - } - ) + if self._version == "mol_ids": + return [ + { + "mol_id": i, + "atom_id": a.GetIdx(), + "mass": a.GetMass() * multiplier, + } + for i, m in enumerate(mols) + for a in m.GetAtoms() + ] + elif self._version == "mols": + return [ + { + "mol": m, + "atom_id": a.GetIdx(), + "mass": a.GetMass() * multiplier, + } + for m in mols + for a in m.GetAtoms() + ] def _get_config(self): return { diff --git a/nerdd_module/tests/models/MolWeightModel.py b/nerdd_module/tests/models/MolWeightModel.py index c374d2f..5554cf9 100644 --- a/nerdd_module/tests/models/MolWeightModel.py +++ b/nerdd_module/tests/models/MolWeightModel.py @@ -1,17 +1,33 @@ -import pandas as pd -from nerdd_module import AbstractModel -from nerdd_module.preprocessing import Sanitize from rdkit.Chem.Descriptors import MolWt +from nerdd_module import SimpleModel +from nerdd_module.preprocessing import Sanitize + __all__ = ["MolWeightModel"] -class MolWeightModel(AbstractModel): - def __init__(self, preprocessing_pipeline=[Sanitize()], **kwargs): - super().__init__(preprocessing_pipeline, **kwargs) +class MolWeightModel(SimpleModel): + def __init__( + self, preprocessing_steps=[Sanitize()], version="order_based", **kwargs + ): + assert version in ["order_based", "mol_ids", "mols"], ( + f"version must be one of 'order_based', 'mol_ids', or 'mols', " + f"but got {version}." + ) + + super().__init__(preprocessing_steps, **kwargs) + self._version = version def _predict_mols(self, mols, multiplier): - return pd.DataFrame({"weight": [MolWt(m) * multiplier for m in mols]}) + if self._version == "order_based": + return [{"weight": MolWt(m) * multiplier} for m in mols] + elif self._version == "mol_ids": + return [ + {"mol_id": i, "weight": MolWt(m) * multiplier} + for i, m in enumerate(mols) + ] + elif self._version == "mols": + return [{"mol": m, "weight": MolWt(m) * multiplier} for m in mols] def _get_config(self): return { diff --git a/nerdd_module/tests/models/MolWeightModelWithExplicitMolIds.py b/nerdd_module/tests/models/MolWeightModelWithExplicitMolIds.py deleted file mode 100644 index af8c308..0000000 --- a/nerdd_module/tests/models/MolWeightModelWithExplicitMolIds.py +++ /dev/null @@ -1,30 +0,0 @@ -import pandas as pd -from nerdd_module import AbstractModel -from nerdd_module.preprocessing import Sanitize -from rdkit.Chem.Descriptors import MolWt - -__all__ = ["MolWeightModelWithExplicitMolIds"] - - -class MolWeightModelWithExplicitMolIds(AbstractModel): - def __init__(self, preprocessing_pipeline=[Sanitize()], **kwargs): - super().__init__(preprocessing_pipeline, **kwargs) - - def _predict_mols(self, mols, multiplier): - return pd.DataFrame( - { - "mol_id": [int(m.GetProp("_Name")) for m in mols], - "weight": [MolWt(m) * multiplier for m in mols], - } - ) - - def _get_config(self): - return { - "name": "mol_weight_model_with_explicit_mol_ids", - "job_parameters": [ - {"name": "multiplier", "type": "float"}, - ], - "result_properties": [ - {"name": "weight", "type": "float"}, - ], - } diff --git a/nerdd_module/tests/models/MolWeightModelWithExplicitMols.py b/nerdd_module/tests/models/MolWeightModelWithExplicitMols.py deleted file mode 100644 index ec84734..0000000 --- a/nerdd_module/tests/models/MolWeightModelWithExplicitMols.py +++ /dev/null @@ -1,27 +0,0 @@ -import pandas as pd -from nerdd_module import AbstractModel -from nerdd_module.preprocessing import Sanitize -from rdkit.Chem.Descriptors import MolWt - -__all__ = ["MolWeightModelWithExplicitMols"] - - -class MolWeightModelWithExplicitMols(AbstractModel): - def __init__(self, preprocessing_pipeline=[Sanitize()], **kwargs): - super().__init__(preprocessing_pipeline, **kwargs) - - def _predict_mols(self, mols, multiplier): - return pd.DataFrame( - {"mol": mols, "weight": [MolWt(m) * multiplier for m in mols]} - ) - - def _get_config(self): - return { - "name": "mol_weight_model_with_explicit_mols", - "job_parameters": [ - {"name": "multiplier", "type": "float"}, - ], - "result_properties": [ - {"name": "weight", "type": "float"}, - ], - } diff --git a/nerdd_module/tests/models/__init__.py b/nerdd_module/tests/models/__init__.py index 847d1e6..91d9a19 100644 --- a/nerdd_module/tests/models/__init__.py +++ b/nerdd_module/tests/models/__init__.py @@ -1,4 +1,2 @@ from .AtomicMassModel import * from .MolWeightModel import * -from .MolWeightModelWithExplicitMolIds import * -from .MolWeightModelWithExplicitMols import * diff --git a/nerdd_module/tests/predictions.py b/nerdd_module/tests/predictions.py index 674f9cd..3cfb718 100644 --- a/nerdd_module/tests/predictions.py +++ b/nerdd_module/tests/predictions.py @@ -1,30 +1,10 @@ -import pandas as pd from pytest_bdd import parsers, then, when @when( - parsers.parse("the model generates predictions for the molecule representations"), - target_fixture="predictions", -) -def predictions( - representations, - model, - input_type, -): - return model.predict( - representations, - ) - - -@when( - "The subset of the result where the input was not None is considered", + "the subset of the result where the input was not None is considered", target_fixture="subset", ) def subset_without_none(predictions): # remove None entries - return predictions[predictions.preprocessed_mol.notnull()] - - -@then("the result should be a pandas DataFrame") -def check_result(predictions): - assert isinstance(predictions, pd.DataFrame) + return [p for p in predictions if p["input_mol"] is not None] diff --git a/nerdd_module/tests/predictors.py b/nerdd_module/tests/predictors.py index 3059703..2512e2b 100644 --- a/nerdd_module/tests/predictors.py +++ b/nerdd_module/tests/predictors.py @@ -1,11 +1,6 @@ from pytest_bdd import given, parsers, when -from .models import ( - AtomicMassModel, - MolWeightModel, - MolWeightModelWithExplicitMolIds, - MolWeightModelWithExplicitMols, -) +from .models import AtomicMassModel, MolWeightModel @given( @@ -21,12 +16,12 @@ def multiplier(multiplier): target_fixture="predictor", ) def molecule_property_predictor(version): - if version == "no_ids": - return MolWeightModel() - elif version == "with_ids": - return MolWeightModelWithExplicitMolIds() - elif version == "with_mols": - return MolWeightModelWithExplicitMols() + assert version in ["order_based", "mol_ids", "mols"], ( + f"version must be one of 'order_based', 'mol_ids', or 'mols', " + f"but got {version}." + ) + + return MolWeightModel(version=version) @given( @@ -34,19 +29,22 @@ def molecule_property_predictor(version): target_fixture="predictor", ) def atom_property_predictor(version): - # if version == "no_ids": - return AtomicMassModel() - # elif version == "with_ids": - # return MolWeightModelWithExplicitMolIds() - # elif version == "with_mols": - # return MolWeightModelWithExplicitMols() + assert version in [ + "mol_ids", + "mols", + ], f"version must be one of 'mol_ids', or 'mols', but got {version}." + + return AtomicMassModel(version=version) @when( - parsers.parse("the model is used on the molecules given as {input_type}"), + parsers.parse("the model generates predictions for the molecule representations"), target_fixture="predictions", ) def predictions(representations, predictor, input_type, multiplier): return predictor.predict( - representations, input_type=input_type, multiplier=multiplier + representations, + input_type=input_type, + multiplier=multiplier, + output_format="record_list", ) diff --git a/nerdd_module/tests/representations.py b/nerdd_module/tests/representations.py index aabb3a2..1766a5d 100644 --- a/nerdd_module/tests/representations.py +++ b/nerdd_module/tests/representations.py @@ -2,9 +2,9 @@ from hypothesis import given as hgiven from hypothesis import seed, settings from hypothesis import strategies as st -from hypothesis_rdkit import mols, smiles +from hypothesis_rdkit import mols from pytest_bdd import given, parsers -from rdkit.Chem import MolFromSmiles, MolToMolBlock, MolToSmiles +from rdkit.Chem import MolToInchi, MolToMolBlock, MolToSmiles @given(parsers.parse("a random seed set to {seed:d}"), target_fixture="random_seed") @@ -29,6 +29,8 @@ def representations_from_molecules(molecules, input_type): converter = MolToSmiles elif input_type == "mol_block": converter = MolToMolBlock + elif input_type == "inchi": + converter = MolToInchi elif input_type == "rdkit_mol": converter = lambda mol: mol else: diff --git a/tests/features/atom_property_prediction.feature b/tests/features/atom_property_prediction.feature deleted file mode 100644 index cdcac83..0000000 --- a/tests/features/atom_property_prediction.feature +++ /dev/null @@ -1,59 +0,0 @@ -Feature: Atom prediction - - Scenario Outline: Predicting a property for each atom - Given a list of random molecules, where entries are None - And the input type is '' - And the representations of the molecules - And an example model predicting atomic masses, version - And a prediction parameter 'multiplier' set to - - When the model is used on the molecules given as - And the subset of the result where the input was not None is considered - - Then the result should be a pandas DataFrame - And the result should contain as many rows as atoms in the input molecules - And the result should contain the columns: - mol_id - name - input_mol - preprocessed_mol - input_type - errors - atom_id - mass - And the input type column should be '' - And the name column should contain valid names - And the mass column should contain the (multiplied) atomic masses - And the input column should contain the input representation - And the number of unique atom ids should be the same as the number of atoms in the input - And the errors column should be a list of problem instances - - Examples: - | input_type | version | num_molecules | multiplier | num_none | - | rdkit_mol | no_ids | 10 | 3 | 0 | - | smiles | no_ids | 10 | 3 | 0 | - | mol_block | no_ids | 10 | 3 | 0 | - | rdkit_mol | no_ids | 10 | 3 | 5 | - | smiles | no_ids | 10 | 3 | 5 | - | mol_block | no_ids | 10 | 3 | 5 | - | rdkit_mol | no_ids | 0 | 3 | 0 | - | smiles | no_ids | 0 | 3 | 0 | - | mol_block | no_ids | 0 | 3 | 0 | - | rdkit_mol | with_ids | 10 | 3 | 0 | - | smiles | with_ids | 10 | 3 | 0 | - | mol_block | with_ids | 10 | 3 | 0 | - | rdkit_mol | with_ids | 10 | 3 | 5 | - | smiles | with_ids | 10 | 3 | 5 | - | mol_block | with_ids | 10 | 3 | 5 | - | rdkit_mol | with_ids | 0 | 3 | 0 | - | smiles | with_ids | 0 | 3 | 0 | - | mol_block | with_ids | 0 | 3 | 0 | - | rdkit_mol | with_mols | 10 | 3 | 0 | - | smiles | with_mols | 10 | 3 | 0 | - | mol_block | with_mols | 10 | 3 | 0 | - | rdkit_mol | with_mols | 10 | 3 | 5 | - | smiles | with_mols | 10 | 3 | 5 | - | mol_block | with_mols | 10 | 3 | 5 | - | rdkit_mol | with_mols | 0 | 3 | 0 | - | smiles | with_mols | 0 | 3 | 0 | - | mol_block | with_mols | 0 | 3 | 0 | \ No newline at end of file diff --git a/tests/features/representations.feature b/tests/features/input/depth_first_explorer.feature similarity index 89% rename from tests/features/representations.feature rename to tests/features/input/depth_first_explorer.feature index 05f5c6c..7576fd3 100644 --- a/tests/features/representations.feature +++ b/tests/features/input/depth_first_explorer.feature @@ -3,7 +3,8 @@ Feature: Reading molecule representations Scenario Outline: Read a single molecule from a valid representation Given a list of 1 random molecules, where 0 entries are None - And the representations of the molecules as + And the input type is '' + And the representations of the molecules When the reader gets the representations as input with input type Then the result should contain the same number of entries as the input And the result should contain the same number of non-null entries as the input @@ -18,7 +19,8 @@ Feature: Reading molecule representations Scenario Outline: Read lists of molecules from valid representations Given a list of random molecules, where entries are None - And the representations of the molecules as + And the input type is '' + And the representations of the molecules When the reader gets the representations as input with input type Then the result should contain the same number of entries as the input @@ -41,8 +43,9 @@ Feature: Reading molecule representations Scenario Outline: Read a single file containing valid representations Given a list of random molecules, where entries are None - And the representations of the molecules as - And a file containing the representations as + And the input type is '' + And the representations of the molecules + And a file containing the representations When the reader gets the file name(s) as input Then the result should contain the same number of entries as the input @@ -59,8 +62,9 @@ Feature: Reading molecule representations Scenario Outline: Read multiple files containing valid representations Given a list of random molecules, where entries are None - And the representations of the molecules as - And a list of files containing the representations as + And the input type is '' + And the representations of the molecules + And a list of files containing the representations When the reader gets the file name(s) as input Then the result should contain the same number of entries as the input And the source of each entry should be one of the file names diff --git a/tests/features/models/atom_property_prediction.feature b/tests/features/models/atom_property_prediction.feature new file mode 100644 index 0000000..ce4ab67 --- /dev/null +++ b/tests/features/models/atom_property_prediction.feature @@ -0,0 +1,48 @@ +Feature: Atom property prediction + + Scenario Outline: Predicting a property for each atom + Given a list of random molecules, where entries are None + And the input type is '' + And the representations of the molecules + And an example model predicting atomic masses, version + And a prediction parameter 'multiplier' set to + + When the model generates predictions for the molecule representations + And the subset of the result where the input was not None is considered + + Then the result should contain as many rows as atoms in the input molecules + And the result should contain the columns: + mol_id + name + input_mol + preprocessed_mol + input_type + problems + atom_id + mass + And the value in column 'input_type' should be equal to '' + And the value in column 'name' should not be equal to None + And the value in column 'mass' should be between 0 and infinity + And the number of unique atom ids should be the same as the number of atoms in the input + And the problems column should be a list of problem instances + + Examples: + | input_type | version | num_molecules | multiplier | num_none | + | rdkit_mol | mol_ids | 10 | 3 | 0 | + | smiles | mol_ids | 10 | 3 | 0 | + | mol_block | mol_ids | 10 | 3 | 0 | + | rdkit_mol | mol_ids | 10 | 3 | 5 | + | smiles | mol_ids | 10 | 3 | 5 | + | mol_block | mol_ids | 10 | 3 | 5 | + | rdkit_mol | mol_ids | 0 | 3 | 0 | + | smiles | mol_ids | 0 | 3 | 0 | + | mol_block | mol_ids | 0 | 3 | 0 | + | rdkit_mol | mols | 10 | 3 | 0 | + | smiles | mols | 10 | 3 | 0 | + | mol_block | mols | 10 | 3 | 0 | + | rdkit_mol | mols | 10 | 3 | 5 | + | smiles | mols | 10 | 3 | 5 | + | mol_block | mols | 10 | 3 | 5 | + | rdkit_mol | mols | 0 | 3 | 0 | + | smiles | mols | 0 | 3 | 0 | + | mol_block | mols | 0 | 3 | 0 | \ No newline at end of file diff --git a/tests/features/models/molecule_property_prediction.feature b/tests/features/models/molecule_property_prediction.feature new file mode 100644 index 0000000..c040a86 --- /dev/null +++ b/tests/features/models/molecule_property_prediction.feature @@ -0,0 +1,55 @@ +Feature: Molecular property prediction + + Scenario Outline: Predicting a molecular property + Given a list of random molecules, where entries are None + And the input type is '' + And the representations of the molecules + And an example model predicting molecular weight, version + And a prediction parameter 'multiplier' set to + + When the model generates predictions for the molecule representations + And the subset of the result where the input was not None is considered + + Then the result should contain the same number of rows as the input + And the result should contain the columns: + mol_id + name + input_mol + preprocessed_mol + input_type + weight + problems + And the value in column 'input_type' should be equal to '' + And the value in column 'name' should not be equal to None + And the value in column 'weight' should be between 0 and infinity + And the problems column should be a list of problem instances + + Examples: + | input_type | version | num_molecules | multiplier | num_none | + | rdkit_mol | order_based | 10 | 3 | 0 | + | smiles | order_based | 10 | 3 | 0 | + | mol_block | order_based | 10 | 3 | 0 | + | rdkit_mol | order_based | 10 | 3 | 5 | + | smiles | order_based | 10 | 3 | 5 | + | mol_block | order_based | 10 | 3 | 5 | + | rdkit_mol | order_based | 0 | 3 | 0 | + | smiles | order_based | 0 | 3 | 0 | + | mol_block | order_based | 0 | 3 | 0 | + | rdkit_mol | mol_ids | 10 | 3 | 0 | + | smiles | mol_ids | 10 | 3 | 0 | + | mol_block | mol_ids | 10 | 3 | 0 | + | rdkit_mol | mol_ids | 10 | 3 | 5 | + | smiles | mol_ids | 10 | 3 | 5 | + | mol_block | mol_ids | 10 | 3 | 5 | + | rdkit_mol | mol_ids | 0 | 3 | 0 | + | smiles | mol_ids | 0 | 3 | 0 | + | mol_block | mol_ids | 0 | 3 | 0 | + | rdkit_mol | mols | 10 | 3 | 0 | + | smiles | mols | 10 | 3 | 0 | + | mol_block | mols | 10 | 3 | 0 | + | rdkit_mol | mols | 10 | 3 | 5 | + | smiles | mols | 10 | 3 | 5 | + | mol_block | mols | 10 | 3 | 5 | + | rdkit_mol | mols | 0 | 3 | 0 | + | smiles | mols | 0 | 3 | 0 | + | mol_block | mols | 0 | 3 | 0 | \ No newline at end of file diff --git a/tests/features/molecule_property_prediction.feature b/tests/features/molecule_property_prediction.feature deleted file mode 100644 index 80b0844..0000000 --- a/tests/features/molecule_property_prediction.feature +++ /dev/null @@ -1,57 +0,0 @@ -Feature: Property prediction - - Scenario Outline: Predicting a molecular property - Given a list of random molecules, where entries are None - And the input type is '' - And the representations of the molecules - And an example model predicting molecular weight, version - And a prediction parameter 'multiplier' set to - - When the model is used on the molecules given as - And the subset of the result where the input was not None is considered - - Then the result should be a pandas DataFrame - And the result should contain the same number of rows as the input - And the result should contain the columns: - mol_id - name - input_mol - preprocessed_mol - input_type - errors - weight - And the input type column should be '' - And the name column should contain valid names - And the weight column should contain the (multiplied) molecule weights - And the input column should contain the input representation - And the errors column should be a list of problem instances - - Examples: - | input_type | version | num_molecules | multiplier | num_none | - | rdkit_mol | no_ids | 10 | 3 | 0 | - | smiles | no_ids | 10 | 3 | 0 | - | mol_block | no_ids | 10 | 3 | 0 | - | rdkit_mol | no_ids | 10 | 3 | 5 | - | smiles | no_ids | 10 | 3 | 5 | - | mol_block | no_ids | 10 | 3 | 5 | - | rdkit_mol | no_ids | 0 | 3 | 0 | - | smiles | no_ids | 0 | 3 | 0 | - | mol_block | no_ids | 0 | 3 | 0 | - | rdkit_mol | with_ids | 10 | 3 | 0 | - | smiles | with_ids | 10 | 3 | 0 | - | mol_block | with_ids | 10 | 3 | 0 | - | rdkit_mol | with_ids | 10 | 3 | 5 | - | smiles | with_ids | 10 | 3 | 5 | - | mol_block | with_ids | 10 | 3 | 5 | - | rdkit_mol | with_ids | 0 | 3 | 0 | - | smiles | with_ids | 0 | 3 | 0 | - | mol_block | with_ids | 0 | 3 | 0 | - | rdkit_mol | with_mols | 10 | 3 | 0 | - | smiles | with_mols | 10 | 3 | 0 | - | mol_block | with_mols | 10 | 3 | 0 | - | rdkit_mol | with_mols | 10 | 3 | 5 | - | smiles | with_mols | 10 | 3 | 5 | - | mol_block | with_mols | 10 | 3 | 5 | - | rdkit_mol | with_mols | 0 | 3 | 0 | - | smiles | with_mols | 0 | 3 | 0 | - | mol_block | with_mols | 0 | 3 | 0 | \ No newline at end of file diff --git a/tests/features/preprocessing.feature b/tests/features/preprocessing.feature deleted file mode 100644 index 364be3d..0000000 --- a/tests/features/preprocessing.feature +++ /dev/null @@ -1,7 +0,0 @@ -Feature: Preprocessing - - Scenario: Preprocessing molecules - Given a list of 10 random molecules, where 0 entries are None - And an example model predicting molecular weight, version no_ids - When the model preprocesses the molecules - Then the preprocessed molecules are valid \ No newline at end of file diff --git a/tests/features/preprocessing/filter_by_element.feature b/tests/features/preprocessing/filter_by_element.feature new file mode 100644 index 0000000..a22b5cf --- /dev/null +++ b/tests/features/preprocessing/filter_by_element.feature @@ -0,0 +1,80 @@ +Feature: Filter molecules by element + + Scenario Outline: Tag molecules with invalid elements + Given an input molecule specified by '' + And the list of allowed elements is ['C', 'H', 'O', 'N'] + And the parameter remove_invalid_molecules is set to False + When the molecules are filtered by element + And the subset of the result where the input was not None is considered + Then the subset should contain the problem 'invalid_elements' + And the value in column 'preprocessed_mol' should not be equal to None + Examples: + | input_smiles | + | CI | + + Scenario Outline: Do not tag molecules having only allowed elements + Given an input molecule specified by '' + And the list of allowed elements is ['C', 'H', 'O', 'N'] + And the parameter remove_invalid_molecules is set to False + When the molecules are filtered by element + And the subset of the result where the input was not None is considered + Then the subset should not contain the problem 'invalid_elements' + And the value in column 'preprocessed_mol' should not be equal to None + Examples: + | input_smiles | + | CCO | + | [H][H] | + + + Scenario Outline: Filter molecules containing hydrogen + Given an input molecule specified by '' + And the list of allowed elements is ['C', 'O', 'N'] + And the parameter remove_invalid_molecules is set to False + When the molecules are filtered by element + And the subset of the result where the input was not None is considered + Then the subset should contain the problem 'invalid_elements' + And the value in column 'preprocessed_mol' should not be equal to None + Examples: + | input_smiles | + | CCO | + | [H][H] | + + + Scenario Outline: Setting allowed elements to an empty list + Given an input molecule specified by '' + And the list of allowed elements is [] + And the parameter remove_invalid_molecules is set to False + When the molecules are filtered by element + And the subset of the result where the input was not None is considered + Then the subset should contain the problem 'invalid_elements' + And the value in column 'preprocessed_mol' should not be equal to None + Examples: + | input_smiles | + | CCO | + | [H][H] | + | O=O | + + Scenario Outline: Setting remove_invalid_molecules to True + Given an input molecule specified by '' + And the list of allowed elements is ['C', 'O', 'H'] + And the parameter remove_invalid_molecules is set to True + When the molecules are filtered by element + And the subset of the result where the input was not None is considered + Then the subset should contain the problem 'invalid_elements' + And the value in column 'preprocessed_mol' should be equal to None + Examples: + | input_smiles | + | CCCCCS | + + Scenario Outline: Using lowercase element symbols also works + Given an input molecule specified by '' + And the list of allowed elements is ['c', 'h', 'o', 'n'] + And the parameter remove_invalid_molecules is set to False + When the molecules are filtered by element + And the subset of the result where the input was not None is considered + Then the subset should not contain the problem 'invalid_elements' + And the value in column 'preprocessed_mol' should not be equal to None + Examples: + | input_smiles | + | CCO | + | [H][H] | \ No newline at end of file diff --git a/tests/steps/__init__.py b/tests/steps/__init__.py index 32fabb5..cf31f23 100644 --- a/tests/steps/__init__.py +++ b/tests/steps/__init__.py @@ -1,2 +1,3 @@ from .checks import * +from .input import * from .preprocessing import * diff --git a/tests/steps/checks.py b/tests/steps/checks.py index dae6127..7e49812 100644 --- a/tests/steps/checks.py +++ b/tests/steps/checks.py @@ -1,13 +1,9 @@ +from collections import defaultdict from typing import Iterable -import pandas as pd -from nerdd_module import Problem -from pytest_bdd import parsers, then - +from pytest_bdd import then -@then("the result should be a pandas DataFrame") -def check_result_type(predictions): - assert isinstance(predictions, pd.DataFrame) +from nerdd_module import Problem @then("the result should contain the same number of rows as the input") @@ -19,20 +15,24 @@ def check_result_length(representations, predictions): assert len(predictions) == len(representations) -@then(parsers.parse("the result should contain the columns:\n{column_names}")) -def check_result_columns(predictions, column_names): - column_names = column_names.strip() - for c in column_names.split("\n"): - assert ( - c in predictions.columns - ), f"Column {c} not in predictions {predictions.columns.tolist()}" +@then( + "the number of unique atom ids should be the same as the number of atoms in the " + "input" +) +def check_atom_ids(subset): + records_per_mol_id = defaultdict(list) + for record in subset: + records_per_mol_id[record["mol_id"]].append(record) -@then(parsers.parse("the input type column should be '{input_type}'")) -def check_input_type_column(subset, input_type): - assert ( - subset.input_type == input_type - ).all(), f"Not all predictions have the input_type {input_type}" + for mol_id, records in records_per_mol_id.items(): + mol = records[0]["preprocessed_mol"] + num_atom_ids = len(set([r["atom_id"] for r in records])) + num_atoms = mol.GetNumAtoms() + assert num_atom_ids == num_atoms, ( + f"Number of atom ids ({num_atom_ids}) does not match number of atoms " + f"({num_atoms})" + ) @then("the result should contain as many rows as atoms in the input molecules") @@ -48,11 +48,12 @@ def check_result_length_atom(molecules, predictions): assert len(predictions) == num_expected_predictions -@then("the errors column should be a list of problem instances") -def check_error_column(predictions): - for error_list in predictions.errors: - assert isinstance(error_list, Iterable) - for e in error_list: +@then("the problems column should be a list of problem instances") +def check_problem_column(predictions): + for record in predictions: + problems_list = record["problems"] + assert isinstance(problems_list, Iterable) + for e in problems_list: assert isinstance( e, Problem ), f"Expected Problem, got {e} of type {type(e)}" diff --git a/tests/steps/input.py b/tests/steps/input.py new file mode 100644 index 0000000..2bbfbf0 --- /dev/null +++ b/tests/steps/input.py @@ -0,0 +1,114 @@ +from tempfile import NamedTemporaryFile + +import numpy as np +from pytest_bdd import given, parsers, then, when + +from nerdd_module.input import DepthFirstExplorer + + +@given( + parsers.parse("a file containing the representations"), + target_fixture="representation_files", +) +def representation_file(representations, input_type): + with NamedTemporaryFile("w", delete=False) as f: + for representation in representations: + if representation is None: + f.write("None") + else: + f.write(representation) + if input_type in ["smiles", "inchi"]: + f.write("\n") + elif input_type == "mol_block": + f.write("\n$$$$\n") + f.flush() + return f.name + + +@given( + parsers.parse( + "a list of {num_files:d} files containing the representations", + ), + target_fixture="representation_files", +) +def representation_files(representations, input_type, num_files): + # choose num_files-1 numbers to split the representations into num_files parts + # the while loop makes sure that each part contains at least one valid molecule + while True: + split_indices = np.random.choice( + len(representations), size=num_files - 1, replace=False + ) + split_indices = np.sort(split_indices) + + # split the representations + split_representations = np.split(representations, split_indices) + + # check if each part contains at least one valid molecule + if all( + any(representation is not None for representation in split_representation) + for split_representation in split_representations + ): + break + + # write the representations to files + representations_files = [] + + for _, split_representation in enumerate(split_representations): + with NamedTemporaryFile("w", delete=False) as f: + for representation in split_representation: + if representation is None: + f.write("None") + else: + f.write(representation) + if input_type in ["smiles", "inchi"]: + f.write("\n") + elif input_type == "mol_block": + f.write("\n$$$$\n") + f.flush() + representations_files.append(f.name) + + return representations_files + + +@when( + parsers.parse( + "the reader gets the representations as input with input type {input_type}" + ), + target_fixture="entries", +) +def entries(representations, input_type): + if input_type == "unknown": + input_type = None + if len(representations) == 1: + return list(DepthFirstExplorer().explore(representations[0])) + else: + return list(DepthFirstExplorer().explore(representations)) + + +@when("the reader gets the file name(s) as input", target_fixture="entries") +def entries_from_file(representation_files): + return list(DepthFirstExplorer().explore(representation_files)) + + +@then("the result should contain the same number of entries as the input") +def check_predictions(representations, entries): + if len(representations) == 0: + # expect one entry saying that nothing could be read from this source + assert len(entries) == 1 + else: + assert len(entries) == len(representations) + + +@then("the result should contain the same number of non-null entries as the input") +def check_predictions_nonnull(representations, entries): + assert len([e for e in entries if e.mol is not None]) == len( + [e for e in representations if e is not None] + ) + + +@then("the source of each entry should be one of the file names") +def check_source(representation_files, entries): + for entry in entries: + assert ( + entry.source[0] in representation_files + ), f"source {entry.source[0]} not in {representation_files}" diff --git a/tests/steps/preprocessing.py b/tests/steps/preprocessing.py index 6d63435..bf02fac 100644 --- a/tests/steps/preprocessing.py +++ b/tests/steps/preprocessing.py @@ -1,9 +1,54 @@ -from pytest_bdd import parsers, when +from pytest_bdd import given, parsers, then, when + +from nerdd_module.input import DepthFirstExplorer +from nerdd_module.model import ReadInput +from nerdd_module.preprocessing import FilterByElement, Sanitize + + +@given( + parsers.parse("the list of allowed elements is {l}"), + target_fixture="allowed_elements", +) +def allowed_elements(l): + return eval(l) + + +@given( + parsers.parse("the parameter remove_invalid_molecules is set to {value}"), + target_fixture="remove_invalid_molecules", +) +def remove_invalid_molecules(value): + return eval(value) @when( - parsers.parse("the model preprocesses the molecules"), - target_fixture="preprocessed_molecules", + parsers.parse("the molecules are filtered by element"), + target_fixture="predictions", ) -def preprocessed_molecules(molecules, predictor): - return [predictor.preprocessing_pipeline.run(mol) for mol in molecules] +def preprocessed_molecules_filter_by_element( + representations, allowed_elements, remove_invalid_molecules +): + input_step = ReadInput(DepthFirstExplorer(), representations) + sanitize = Sanitize() + filter_by_element = FilterByElement( + allowed_elements, remove_invalid_molecules=remove_invalid_molecules + ) + return list(filter_by_element(sanitize(input_step()))) + + +@then(parsers.parse("the subset should contain the problem '{problem}'")) +def check_problem_in_list(subset, problem): + for record in subset: + problems = record.get("problems", []) + assert problem in [ + p.type for p in problems + ], f"Problem list lacks problem {problem} in record {record}" + + +@then(parsers.parse("the subset should not contain the problem '{problem}'")) +def check_problem_not_in_list(subset, problem): + for record in subset: + problems = record.get("problems", []) + assert problem not in [ + p.type for p in problems + ], f"Problem list contains problem {problem} in record {record}" diff --git a/tests/test_atom_property_prediction.py b/tests/test_atom_property_prediction.py deleted file mode 100644 index 583e769..0000000 --- a/tests/test_atom_property_prediction.py +++ /dev/null @@ -1,65 +0,0 @@ -import numpy as np -from pytest_bdd import parsers, scenario, then, when - - -@scenario( - "features/atom_property_prediction.feature", "Predicting a property for each atom" -) -def test_atom_property_prediction(): - pass - - -@when( - "the subset of the result where the input was not None is considered", - target_fixture="subset", -) -def subset_without_none(predictions): - # remove None entries - return predictions[predictions.input_mol.notnull()] - - -@then("the name column should contain valid names") -def check_name_column(subset): - if len(subset) > 0: - assert subset.name.notnull().all(), "Some molecules have no name" - - -@then("the input column should contain the input representation") -def check_input_column(representations, subset): - # if input is not a mol, then smiles / mol_blocks were provided - # --> there must be a column called "input" - if not subset.input_type.eq("rdkit_mol").all(): - assert "input" in subset.columns, "Column input not in predictions" - - valid_molecules = [m for m in representations if m is not None] - - # the input column must contain the input representation (e.g. smiles) - assert ( - subset.drop_duplicates("mol_id").input == valid_molecules - ).all(), "Input column contains wrong data" - - -@then(parsers.parse("the mass column should contain the (multiplied) atomic masses")) -def check_weight_column(subset, multiplier): - if len(subset) > 0: - expected_masses = [ - m.GetAtomWithIdx(int(atom_id)).GetMass() * multiplier - for (m, atom_id) in zip(subset.preprocessed_mol, subset.atom_id) - ] - assert np.allclose( - subset.mass, expected_masses - ), f"the provided weights do not match the expected weights" - - -@then( - "the number of unique atom ids should be the same as the number of atoms in the " - "input" -) -def check_atom_ids(subset): - for _, group in subset.groupby("mol_id"): - num_atom_ids = group.atom_id.nunique() - num_atoms = group.preprocessed_mol.iloc[0].GetNumAtoms() - assert num_atom_ids == num_atoms, ( - f"Number of atom ids ({num_atom_ids}) does not match number of atoms " - f"({num_atoms})" - ) diff --git a/tests/test_features.py b/tests/test_features.py new file mode 100644 index 0000000..7abaab3 --- /dev/null +++ b/tests/test_features.py @@ -0,0 +1,6 @@ +from pytest_bdd import scenarios + +scenarios("features/input/depth_first_explorer.feature") +scenarios("features/preprocessing/filter_by_element.feature") +scenarios("features/models/atom_property_prediction.feature") +scenarios("features/models/molecule_property_prediction.feature") diff --git a/tests/test_molecule_property_prediction.py b/tests/test_molecule_property_prediction.py deleted file mode 100644 index d023097..0000000 --- a/tests/test_molecule_property_prediction.py +++ /dev/null @@ -1,60 +0,0 @@ -import numpy as np -from pytest_bdd import parsers, scenario, then, when -from rdkit.Chem.Descriptors import MolWt - - -@scenario( - "features/molecule_property_prediction.feature", "Predicting a molecular property" -) -def test_molecule_property_prediction(): - pass - - -@when( - "the subset of the result where the input was not None is considered", - target_fixture="subset", -) -def subset_without_none(predictions): - # filter None entries - return predictions[predictions.input_mol.notnull()] - - -@then(parsers.parse("the result should contain the columns:\n{column_names}")) -def check_result_columns(predictions, column_names): - column_names = column_names.strip() - for c in column_names.split("\n"): - assert ( - c in predictions.columns - ), f"Column {c} not in predictions {predictions.columns.tolist()}" - - -@then("the name column should contain valid names") -def check_name_column(subset): - if len(subset) > 0: - assert subset.name.notnull().all(), "Some molecules have no name" - - -@then("the input column should contain the input representation") -def check_input_column(representations, subset): - # if input is not a mol, then smiles / mol_blocks were provided - # --> there must be a column called "input" - if not subset.input_type.eq("rdkit_mol").all(): - assert "input" in subset.columns, "Column input not in predictions" - - valid_molecules = [m for m in representations if m is not None] - - # the input column must contain the input representation (e.g. smiles) - assert ( - subset.input == valid_molecules - ).all(), "Input column contains wrong data" - - -@then( - parsers.parse("the weight column should contain the (multiplied) molecule weights") -) -def check_weight_column(subset, multiplier): - if len(subset) > 0: - expected_weights = subset.preprocessed_mol.map(MolWt) * multiplier - assert np.allclose( - subset.weight, expected_weights - ), f"the provided weights do not match the expected weights" diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py deleted file mode 100644 index 49415cb..0000000 --- a/tests/test_preprocessing.py +++ /dev/null @@ -1,12 +0,0 @@ -from pytest_bdd import scenario, then - - -@scenario("features/preprocessing.feature", "Preprocessing molecules") -def test_preprocessing(): - pass - - -@then("the preprocessed molecules are valid") -def preprocessed_molecules_are_valid(preprocessed_molecules): - for mol, errors in preprocessed_molecules: - assert mol is not None and mol.GetNumAtoms() > 0 diff --git a/tests/test_reading_formats.py b/tests/test_reading_formats.py deleted file mode 100644 index 13f1713..0000000 --- a/tests/test_reading_formats.py +++ /dev/null @@ -1,137 +0,0 @@ -# from tempfile import NamedTemporaryFile - -# import numpy as np -# from nerdd_module.io import DepthFirstExplorer -# from pytest_bdd import given, parsers, scenario, then, when - - -# @scenario( -# "features/representations.feature", -# "Read a single molecule from a valid representation", -# ) -# def test_reading_single_representation_with_given_input_type(): -# pass - - -# @scenario( -# "features/representations.feature", -# "Read lists of molecules from valid representations", -# ) -# def test_reading_multiple_representation_with_given_input_type(): -# pass - - -# @scenario( -# "features/representations.feature", -# "Read multiple files containing valid representations", -# ) -# def test_reading_multiple_files(): -# pass - - -# @given( -# parsers.parse("a file containing the representations as {input_type}"), -# target_fixture="representation_files", -# ) -# def representation_file(representations, input_type): -# with NamedTemporaryFile("w", delete=False) as f: -# for representation in representations: -# if representation is None: -# f.write("None") -# else: -# f.write(representation) -# if input_type in ["smiles", "inchi"]: -# f.write("\n") -# elif input_type == "mol_block": -# f.write("\n$$$$\n") -# f.flush() -# return f.name - - -# @given( -# parsers.parse( -# "a list of {num_files:d} files containing the representations as {input_type}", -# ), -# target_fixture="representation_files", -# ) -# def representation_files(representations, input_type, num_files): -# # choose num_files-1 numbers to split the representations into num_files parts -# # the while loop makes sure that each part contains at least one valid molecule -# while True: -# split_indices = np.random.choice( -# len(representations), size=num_files - 1, replace=False -# ) -# split_indices = np.sort(split_indices) - -# # split the representations -# split_representations = np.split(representations, split_indices) - -# # check if each part contains at least one valid molecule -# if all( -# any(representation is not None for representation in split_representation) -# for split_representation in split_representations -# ): -# break - -# # write the representations to files -# representations_files = [] - -# for _, split_representation in enumerate(split_representations): -# with NamedTemporaryFile("w", delete=False) as f: -# for representation in split_representation: -# if representation is None: -# f.write("None") -# else: -# f.write(representation) -# if input_type in ["smiles", "inchi"]: -# f.write("\n") -# elif input_type == "mol_block": -# f.write("\n$$$$\n") -# f.flush() -# representations_files.append(f.name) - -# return representations_files - - -# @when( -# parsers.parse( -# "the reader gets the representations as input with input type {input_type}" -# ), -# target_fixture="entries", -# ) -# def entries(representations, input_type): -# if input_type == "unknown": -# input_type = None -# if len(representations) == 1: -# return list(DepthFirstExplorer().explore(representations[0])) -# else: -# return list(DepthFirstExplorer().explore(representations)) - - -# @when("the reader gets the file name(s) as input", target_fixture="entries") -# def entries_from_file(representation_files): -# return list(DepthFirstExplorer().explore(representation_files)) - - -# @then("the result should contain the same number of entries as the input") -# def check_predictions(representations, entries): -# if len(representations) == 0: -# # expect one entry saying that nothing could be read from this source -# assert len(entries) == 1 -# else: -# assert len(entries) == len(representations) - - -# @then("the result should contain the same number of non-null entries as the input") -# def check_predictions_nonnull(representations, entries): -# assert len([e for e in entries if e.mol is not None]) == len( -# [e for e in representations if e is not None] -# ) - - -# @then("the source of each entry should be one of the file names") -# def check_source(representation_files, entries): -# for entry in entries: -# assert ( -# entry.source[0] in representation_files -# ), f"source {entry.source[0]} not in {representation_files}" From bcdafa1f22c00e06c261e0aa08269b9daf342c4d Mon Sep 17 00:00:00 2001 From: Steffen Hirte Date: Wed, 25 Sep 2024 15:39:26 +0200 Subject: [PATCH 11/11] Bump version --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index eb319d5..a536734 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup( name="nerdd-module", - version="0.2.6", + version="0.3.0", maintainer="Steffen Hirte", maintainer_email="steffen.hirte@univie.ac.at", packages=find_packages(), @@ -40,7 +40,6 @@ extras_require={ "dev": [ "black", - "isort", ], "csp": [ # note: version 1.0.0 of chembl_structure_pipeline is not available on pypi