Skip to content

Commit

Permalink
add abstract base class for scoring methods (#247)
Browse files Browse the repository at this point in the history
* change staticmethod to classmethod for `setup` method

* change `ScoringMethod` to abstract base class and rename it

* update the use of the new `ScoringBase` class
  • Loading branch information
CunliangGeng authored Jun 6, 2024
1 parent 7a96b1c commit 3de97c1
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 95 deletions.
14 changes: 7 additions & 7 deletions src/nplinker/nplinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from .metabolomics import MolecularFamily
from .metabolomics import Spectrum
from .pickler import save_pickled_data
from .scoring.abc import ScoringBase
from .scoring.link_collection import LinkCollection
from .scoring.metcalf_scoring import MetcalfScoring
from .scoring.methods import ScoringMethod
from .scoring.np_class_scoring import NPClassScoring
from .scoring.rosetta_scoring import RosettaScoring

Expand All @@ -37,9 +37,9 @@ class NPLinker:
# default set of enabled scoring methods
# TODO: ideally these shouldn't be hardcoded like this
SCORING_METHODS = {
MetcalfScoring.NAME: MetcalfScoring,
RosettaScoring.NAME: RosettaScoring,
NPClassScoring.NAME: NPClassScoring,
MetcalfScoring.name: MetcalfScoring,
RosettaScoring.name: RosettaScoring,
NPClassScoring.name: NPClassScoring,
}

def __init__(self, config_file: str | PathLike):
Expand Down Expand Up @@ -266,7 +266,7 @@ def get_links(

if not self._datalinks:
logger.debug("Creating internal datalinks object")
self._datalinks = self.scoring_method(MetcalfScoring.NAME).datalinks
self._datalinks = self.scoring_method(MetcalfScoring.name).datalinks
logger.debug("Created internal datalinks object")

if len(link_collection) == 0:
Expand Down Expand Up @@ -318,7 +318,7 @@ def get_common_strains(
and values are a list of shared Strain objects.
"""
if not self._datalinks:
self._datalinks = self.scoring_method(MetcalfScoring.NAME).datalinks
self._datalinks = self.scoring_method(MetcalfScoring.name).datalinks
common_strains = self._datalinks.get_common_strains(met, gcfs, filter_no_shared)
return common_strains

Expand Down Expand Up @@ -401,7 +401,7 @@ def class_matches(self):
"""ClassMatches with the matched classes and scoring tables from MIBiG."""
return self._class_matches

def scoring_method(self, name: str) -> ScoringMethod | None:
def scoring_method(self, name: str) -> ScoringBase | None:
"""Return an instance of a scoring method.
Args:
Expand Down
4 changes: 2 additions & 2 deletions src/nplinker/scoring/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .abc import ScoringBase
from .link_collection import LinkCollection
from .metcalf_scoring import MetcalfScoring
from .methods import ScoringMethod
from .object_link import ObjectLink


__all__ = ["LinkCollection", "MetcalfScoring", "ScoringMethod", "ObjectLink"]
__all__ = ["LinkCollection", "MetcalfScoring", "ScoringBase", "ObjectLink"]
56 changes: 56 additions & 0 deletions src/nplinker/scoring/abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations
import logging
from abc import ABC
from abc import abstractmethod
from typing import TYPE_CHECKING


if TYPE_CHECKING:
from nplinker import NPLinker
from . import LinkCollection

logger = logging.getLogger(__name__)


class ScoringBase(ABC):
"""Abstract base class of scoring methods.
Attributes:
name: The name of the scoring method.
npl: The NPLinker object.
"""

name: str = "ScoringBase"

def __init__(self, npl: NPLinker):
"""Initialize the scoring method.
Args:
npl: The NPLinker object.
"""
self.npl = npl

@classmethod
@abstractmethod
def setup(cls, npl: NPLinker):
"""Setup class level attributes."""

@abstractmethod
def get_links(self, *objects, link_collection: LinkCollection) -> LinkCollection:
"""Get links information for the given objects.
Args:
objects: A set of objects.
link_collection: The LinkCollection object.
Returns:
The LinkCollection object.
"""

@abstractmethod
def format_data(self, data) -> str:
"""Format the scoring data to a string."""

@abstractmethod
def sort(self, objects, reverse=True) -> list:
"""Sort the given objects based on the scoring data."""
39 changes: 18 additions & 21 deletions src/nplinker/scoring/metcalf_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from nplinker.metabolomics import Spectrum
from nplinker.pickler import load_pickled_data
from nplinker.pickler import save_pickled_data
from .abc import ScoringBase
from .linking import LINK_TYPES
from .linking import DataLinks
from .linking import LinkFinder
from .linking import isinstance_all
from .methods import ScoringMethod
from .object_link import ObjectLink


Expand All @@ -24,19 +24,19 @@
logger = logging.getLogger(__name__)


class MetcalfScoring(ScoringMethod):
class MetcalfScoring(ScoringBase):
"""Metcalf scoring method.
Attributes:
name: The name of this scoring method, set to a fixed value `metcalf`.
DATALINKS: The DataLinks object to use for scoring.
LINKFINDER: The LinkFinder object to use for scoring.
NAME: The name of the scoring method. This is set to 'metcalf'.
CACHE: The name of the cache file to use for storing the MetcalfScoring.
"""

name = "metcalf"
DATALINKS = None
LINKFINDER = None
NAME = "metcalf"
CACHE = "cache_metcalf_scoring.pckl"

def __init__(self, npl: NPLinker) -> None:
Expand All @@ -57,21 +57,20 @@ def __init__(self, npl: NPLinker) -> None:
self.cutoff = 1.0
self.standardised = True

# TODO CG: not sure why using staticmethod here. Check later and refactor if possible
# TODO CG: refactor this method and extract code for cache file to a separate method
@staticmethod
def setup(npl: NPLinker):
"""Setup the MetcalfScoring object.
@classmethod
def setup(cls, npl: NPLinker):
"""Setup the DataLinks and LinkFinder objects.
DataLinks and LinkFinder objects are created and cached for later use.
This method is only called once to setup the DataLinks and LinkFinder objects.
"""
logger.info(
"MetcalfScoring.setup (bgcs={}, gcfs={}, spectra={}, molfams={}, strains={})".format(
len(npl.bgcs), len(npl.gcfs), len(npl.spectra), len(npl.molfams), len(npl.strains)
)
)

cache_file = npl.output_dir / MetcalfScoring.CACHE
cache_file = npl.output_dir / cls.CACHE

# the metcalf preprocessing can take a long time for large datasets, so it's
# better to cache as the data won't change unless the number of objects does
Expand All @@ -97,27 +96,25 @@ def setup(npl: NPLinker):
break

if cache_ok:
MetcalfScoring.DATALINKS = datalinks
MetcalfScoring.LINKFINDER = linkfinder
cls.DATALINKS = datalinks
cls.LINKFINDER = linkfinder

if MetcalfScoring.DATALINKS is None:
if cls.DATALINKS is None:
logger.info("MetcalfScoring.setup preprocessing dataset (this may take some time)")
MetcalfScoring.DATALINKS = DataLinks(npl.gcfs, npl.spectra, npl.molfams, npl.strains)
MetcalfScoring.LINKFINDER = LinkFinder()
MetcalfScoring.LINKFINDER.calc_score(MetcalfScoring.DATALINKS, link_type=LINK_TYPES[0])
MetcalfScoring.LINKFINDER.calc_score(MetcalfScoring.DATALINKS, link_type=LINK_TYPES[1])
cls.DATALINKS = DataLinks(npl.gcfs, npl.spectra, npl.molfams, npl.strains)
cls.LINKFINDER = LinkFinder()
cls.LINKFINDER.calc_score(MetcalfScoring.DATALINKS, link_type=LINK_TYPES[0])
cls.LINKFINDER.calc_score(MetcalfScoring.DATALINKS, link_type=LINK_TYPES[1])
logger.info("MetcalfScoring.setup caching results")
save_pickled_data(
(dataset_counts, MetcalfScoring.DATALINKS, MetcalfScoring.LINKFINDER), cache_file
)
save_pickled_data((dataset_counts, cls.DATALINKS, cls.LINKFINDER), cache_file)

logger.info("MetcalfScoring.setup completed")

# TODO CG: is it needed? remove it if not
@property
def datalinks(self) -> DataLinks | None:
"""Get the DataLinks object used for scoring."""
return MetcalfScoring.DATALINKS
return self.DATALINKS

def get_links(
self, *objects: GCF | Spectrum | MolecularFamily, link_collection: LinkCollection
Expand Down
50 changes: 0 additions & 50 deletions src/nplinker/scoring/methods.py

This file was deleted.

12 changes: 6 additions & 6 deletions src/nplinker/scoring/np_class_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from nplinker.genomics import BGC
from nplinker.genomics import GCF
from nplinker.metabolomics import Spectrum
from nplinker.scoring.abc import ScoringBase
from nplinker.scoring.metcalf_scoring import MetcalfScoring
from nplinker.scoring.methods import ScoringMethod
from nplinker.scoring.object_link import ObjectLink


logger = logging.getLogger(__name__)


class NPClassScoring(ScoringMethod):
NAME = "npclassscore"
class NPClassScoring(ScoringBase):
name = "npclassscore"

def __init__(self, npl):
super().__init__(npl)
Expand Down Expand Up @@ -313,8 +313,8 @@ def _get_met_classes(self, spec_like, method="mix"):
)
return spec_like_classes, spec_like_classes_names_inds

@staticmethod
def setup(npl):
@classmethod
def setup(cls, npl):
"""Perform any one-off initialisation required (will only be called once)."""
logger.info("Set up NPClassScore scoring")
met_options = npl.chem_classes.class_predict_options
Expand Down Expand Up @@ -347,7 +347,7 @@ def get_links(self, objects, link_collection):
logger.info("Using Metcalf scoring to get shared strains")
# get mapping of shared strains
if not self.npl._datalinks:
self.npl._datalinks = self.npl.scoring_method(MetcalfScoring.NAME).datalinks
self.npl._datalinks = self.npl.scoring_method(MetcalfScoring.name).datalinks
if obj_is_gen:
common_strains = self.npl.get_common_strains(targets, objects)
else:
Expand Down
20 changes: 11 additions & 9 deletions src/nplinker/scoring/rosetta_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from nplinker.genomics.bgc import BGC
from nplinker.genomics.gcf import GCF
from nplinker.metabolomics import MolecularFamily
from nplinker.scoring.methods import ScoringMethod
from nplinker.scoring.abc import ScoringBase
from nplinker.scoring.object_link import ObjectLink
from nplinker.scoring.rosetta.rosetta import Rosetta


logger = logging.getLogger(__name__)


class RosettaScoring(ScoringMethod):
NAME = "rosetta"
class RosettaScoring(ScoringBase):
name = "rosetta"
ROSETTA_OBJ = None

def __init__(self, npl):
Expand All @@ -22,10 +22,14 @@ def __init__(self, npl):
self.spec_score_cutoff = 0.0
self.bgc_score_cutoff = 0.0

@staticmethod
def setup(npl):
@classmethod
def setup(cls, npl):
"""Setup the Rosetta object and run the scoring algorithm.
This method is only called once to setup the Rosetta object.
"""
logger.info("RosettaScoring setup")
RosettaScoring.ROSETTA_OBJ = Rosetta(npl, ignore_genomic_cache=False)
cls.ROSETTA_OBJ = Rosetta(npl, ignore_genomic_cache=False)
ms1_tol = Rosetta.DEF_MS1_TOL
ms2_tol = Rosetta.DEF_MS2_TOL
score_thresh = Rosetta.DEF_SCORE_THRESH
Expand All @@ -35,9 +39,7 @@ def setup(npl):
npl.config
)

RosettaScoring.ROSETTA_OBJ.run(
npl.spectra, npl.bgcs, ms1_tol, ms2_tol, score_thresh, min_match_peaks
)
cls.ROSETTA_OBJ.run(npl.spectra, npl.bgcs, ms1_tol, ms2_tol, score_thresh, min_match_peaks)
logger.info("RosettaScoring setup completed")

@staticmethod
Expand Down

0 comments on commit 3de97c1

Please sign in to comment.