Skip to content

Commit

Permalink
update the use of the new ScoringBase class
Browse files Browse the repository at this point in the history
  • Loading branch information
CunliangGeng committed May 31, 2024
1 parent 93fa861 commit 3f8bbdd
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 18 deletions.
14 changes: 7 additions & 7 deletions src/nplinker/nplinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from .metabolomics import MolecularFamily
from .metabolomics import Spectrum
from .pickler import save_pickled_data
from .scoring.abc import ScoringBase
from .scoring.link_collection import LinkCollection
from .scoring.metcalf_scoring import MetcalfScoring
from .scoring.methods import ScoringMethod
from .scoring.np_class_scoring import NPClassScoring
from .scoring.rosetta_scoring import RosettaScoring

Expand All @@ -37,9 +37,9 @@ class NPLinker:
# default set of enabled scoring methods
# TODO: ideally these shouldn't be hardcoded like this
SCORING_METHODS = {
MetcalfScoring.NAME: MetcalfScoring,
RosettaScoring.NAME: RosettaScoring,
NPClassScoring.NAME: NPClassScoring,
MetcalfScoring.name: MetcalfScoring,
RosettaScoring.name: RosettaScoring,
NPClassScoring.name: NPClassScoring,
}

def __init__(self, config_file: str | PathLike):
Expand Down Expand Up @@ -266,7 +266,7 @@ def get_links(

if not self._datalinks:
logger.debug("Creating internal datalinks object")
self._datalinks = self.scoring_method(MetcalfScoring.NAME).datalinks
self._datalinks = self.scoring_method(MetcalfScoring.name).datalinks
logger.debug("Created internal datalinks object")

if len(link_collection) == 0:
Expand Down Expand Up @@ -318,7 +318,7 @@ def get_common_strains(
and values are a list of shared Strain objects.
"""
if not self._datalinks:
self._datalinks = self.scoring_method(MetcalfScoring.NAME).datalinks
self._datalinks = self.scoring_method(MetcalfScoring.name).datalinks
common_strains = self._datalinks.get_common_strains(met, gcfs, filter_no_shared)
return common_strains

Expand Down Expand Up @@ -401,7 +401,7 @@ def class_matches(self):
"""ClassMatches with the matched classes and scoring tables from MIBiG."""
return self._class_matches

def scoring_method(self, name: str) -> ScoringMethod | None:
def scoring_method(self, name: str) -> ScoringBase | None:
"""Return an instance of a scoring method.
Args:
Expand Down
8 changes: 4 additions & 4 deletions src/nplinker/scoring/metcalf_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from nplinker.metabolomics import Spectrum
from nplinker.pickler import load_pickled_data
from nplinker.pickler import save_pickled_data
from .abc import ScoringBase
from .linking import LINK_TYPES
from .linking import DataLinks
from .linking import LinkFinder
from .linking import isinstance_all
from .methods import ScoringMethod
from .object_link import ObjectLink


Expand All @@ -24,19 +24,19 @@
logger = logging.getLogger(__name__)


class MetcalfScoring(ScoringMethod):
class MetcalfScoring(ScoringBase):
"""Metcalf scoring method.
Attributes:
name: The name of this scoring method, set to a fixed value `metcalf`.
DATALINKS: The DataLinks object to use for scoring.
LINKFINDER: The LinkFinder object to use for scoring.
NAME: The name of the scoring method. This is set to 'metcalf'.
CACHE: The name of the cache file to use for storing the MetcalfScoring.
"""

name = "metcalf"
DATALINKS = None
LINKFINDER = None
NAME = "metcalf"
CACHE = "cache_metcalf_scoring.pckl"

def __init__(self, npl: NPLinker) -> None:
Expand Down
8 changes: 4 additions & 4 deletions src/nplinker/scoring/np_class_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from nplinker.genomics import BGC
from nplinker.genomics import GCF
from nplinker.metabolomics import Spectrum
from nplinker.scoring.abc import ScoringBase
from nplinker.scoring.metcalf_scoring import MetcalfScoring
from nplinker.scoring.methods import ScoringMethod
from nplinker.scoring.object_link import ObjectLink


logger = logging.getLogger(__name__)


class NPClassScoring(ScoringMethod):
NAME = "npclassscore"
class NPClassScoring(ScoringBase):
name = "npclassscore"

def __init__(self, npl):
super().__init__(npl)
Expand Down Expand Up @@ -347,7 +347,7 @@ def get_links(self, objects, link_collection):
logger.info("Using Metcalf scoring to get shared strains")
# get mapping of shared strains
if not self.npl._datalinks:
self.npl._datalinks = self.npl.scoring_method(MetcalfScoring.NAME).datalinks
self.npl._datalinks = self.npl.scoring_method(MetcalfScoring.name).datalinks
if obj_is_gen:
common_strains = self.npl.get_common_strains(targets, objects)
else:
Expand Down
6 changes: 3 additions & 3 deletions src/nplinker/scoring/rosetta_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from nplinker.genomics.bgc import BGC
from nplinker.genomics.gcf import GCF
from nplinker.metabolomics import MolecularFamily
from nplinker.scoring.methods import ScoringMethod
from nplinker.scoring.abc import ScoringBase
from nplinker.scoring.object_link import ObjectLink
from nplinker.scoring.rosetta.rosetta import Rosetta


logger = logging.getLogger(__name__)


class RosettaScoring(ScoringMethod):
NAME = "rosetta"
class RosettaScoring(ScoringBase):
name = "rosetta"
ROSETTA_OBJ = None

def __init__(self, npl):
Expand Down

0 comments on commit 3f8bbdd

Please sign in to comment.