diff --git a/src/nplinker/loader.py b/src/nplinker/loader.py index 753eeba6..65af03c3 100644 --- a/src/nplinker/loader.py +++ b/src/nplinker/loader.py @@ -25,6 +25,7 @@ from nplinker.pairedomics.runbigscape import run_bigscape from nplinker.pairedomics.strain_mappings_generator import podp_generate_strain_mappings from nplinker.strain_collection import StrainCollection +from nplinker.strain_loader import load_user_strains from nplinker.strains import Strain @@ -205,10 +206,6 @@ def load(self): # TODO add a config file option for this? self._filter_only_common_strains() - # if the user specified a set of strains to be explicitly included, filter - # out everything except those strains - self._filter_user_strains() - # if we don't have at least *some* strains here it probably means missing mappings # or a complete failure to parse things, so bail out if len(self.strains) == 0: @@ -382,13 +379,21 @@ def _load_mibig(self): return True def _load_strain_mappings(self): - # First load user's strain mappings + # 1. load strain mappings sc = StrainCollection.read_json(self.strain_mappings_file) for strain in sc: self.strains.add(strain) logger.info("Loaded {} non-MiBIG Strain objects".format(len(self.strains))) - # Then load MiBIG strain mappings + # 2. filter user specified strains (remove all that are not specified by user). + # It's not allowed to specify an empty list of strains, otherwise validation will fail. + if os.path.exists(self.include_strains_file): + logger.info(f"Loading user specified strains from file {self.include_strains_file}.") + user_strains = load_user_strains(self.include_strains_file) + logger.info(f"Loaded {len(user_strains)} user specified strains.") + self.strains.filter(user_strains) + + # 3. 
load MiBIG strain mappings if self._mibig_strain_bgc_mapping: for k, v in self._mibig_strain_bgc_mapping.items(): strain = Strain(k) @@ -573,28 +578,6 @@ def _load_optional(self): self.description_text = open(self.description_file).read() logger.debug("Parsed description text") - self.include_only_strains = set() - if os.path.exists(self.include_strains_file): - logger.debug("Loading include_strains from {}".format(self.include_strains_file)) - strain_list = open(self.include_strains_file).readlines() - self.include_only_strains = StrainCollection() - for line_num, sid in enumerate(strain_list): - sid = sid.strip() # get rid of newline - try: - strain_ref_list = self.strains.lookup(sid) - except KeyError: - logger.warning( - 'Line {} of {}: invalid/unknown strain ID "{}"'.format( - line_num + 1, self.include_strains_file, sid - ) - ) - continue - for strain in strain_ref_list: - self.include_only_strains.add(strain) - logger.debug( - "Found {} strain IDs in include_strains".format(len(self.include_only_strains)) - ) - def _filter_only_common_strains(self): """Filter strain population to only strains present in both genomic and molecular data.""" # TODO: Maybe there should be an option to specify which strains are used, both so we can @@ -627,75 +610,6 @@ def _filter_only_common_strains(self): spec.strains.filter(common_strains) logger.info("Strains filtered down to total of {}".format(len(self.strains))) - def _filter_user_strains(self): - """If the user has supplied a list of strains to be explicitly included, go through the - existing sets of objects we have and remove any that only include other strains. This - involves an initial round of removing BGC and Spectrum objects, then a further round - of removing now-empty GCF and MolFam objects. 
- """ - if len(self.include_only_strains) == 0: - logger.info("No further strain filtering to apply") - return - - logger.info( - "Found a list of {} strains to retain, filtering objects".format( - len(self.include_only_strains) - ) - ) - - # filter the main list of strains - self.strains.filter(self.include_only_strains) - - if len(self.strains) == 0: - logger.error("Strain list has been filtered down until it is empty! ") - logger.error( - "This probably indicates that you tried to specifically include a set of strains that had no overlap with the set common to metabolomics and genomics data (see the common_strains.csv in the dataset folder for a list of these" - ) - raise Exception("No strains left after filtering, cannot continue!") - - # get the list of BGCs which have a strain found in the set we were given - bgcs_to_retain = {bgc for bgc in self.bgcs if bgc.strain in self.include_only_strains} - # get the list of spectra which have at least one strain in the set - spectra_to_retain = { - spec - for spec in self.spectra - for sstrain in spec.strains - if sstrain in self.include_only_strains - } - - logger.info( - "Current / filtered BGC counts: {} / {}".format(len(self.bgcs), len(bgcs_to_retain)) - ) - logger.info( - "Current / filtered spectra counts: {} / {}".format( - len(self.spectra), len(spectra_to_retain) - ) - ) - - self.bgcs = list(bgcs_to_retain) - - self.spectra = list(spectra_to_retain) - # also need to filter the set of strains attached to each spectrum - for i, spec in enumerate(self.spectra): - spec.strains.filter(self.include_only_strains) - spec.id = i - - # now filter GCFs and MolFams based on the filtered BGCs and Spectra - gcfs = {parent for bgc in self.bgcs for parent in bgc.parents} - logger.info("Current / filtered GCF counts: {} / {}".format(len(self.gcfs), len(gcfs))) - self.gcfs = list(gcfs) - # filter each GCF's strain list - for gcf in self.gcfs: - gcf.strains.filter(self.include_only_strains) - - molfams = {spec.family for spec in 
self.spectra} - logger.info( - "Current / filtered MolFam counts: {} / {}".format(len(self.molfams), len(molfams)) - ) - self.molfams = list(molfams) - for i, molfam in enumerate(self.molfams): - molfam.id = i - def find_via_glob(path, file_type, optional=False): try: diff --git a/src/nplinker/schemas/__init__.py b/src/nplinker/schemas/__init__.py index 477a87ed..31737ed3 100644 --- a/src/nplinker/schemas/__init__.py +++ b/src/nplinker/schemas/__init__.py @@ -12,6 +12,7 @@ "GENOME_BGC_MAPPINGS_SCHEMA", "STRAIN_MAPPINGS_SCHEMA", "PODP_ADAPTED_SCHEMA", + "USER_STRAINS_SCHEMA", "validate_podp_json", ] @@ -24,3 +25,6 @@ with open(SCHEMA_DIR / "strain_mappings_schema.json", "r") as f: STRAIN_MAPPINGS_SCHEMA = json.load(f) + +with open(SCHEMA_DIR / "user_strains.json", "r") as f: + USER_STRAINS_SCHEMA = json.load(f) diff --git a/src/nplinker/schemas/user_strains.json b/src/nplinker/schemas/user_strains.json new file mode 100644 index 00000000..64949566 --- /dev/null +++ b/src/nplinker/schemas/user_strains.json @@ -0,0 +1,30 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://raw.githubusercontent.com/NPLinker/nplinker/main/src/nplinker/schemas/user_strains.json", + "title": "User specified strains", + "description": "A list of strain IDs specified by user", + "type": "object", + "required": [ + "strain_ids" + ], + "properties": { + "strain_ids": { + "type": "array", + "title": "Strain IDs", + "description": "A list of strain IDs specified by user. 
The strain IDs must be the same as the ones in the strain mappings file.", + "items": { + "type": "string", + "minLength": 1 + }, + "minItems": 1, + "uniqueItems": true + }, + "version": { + "type": "string", + "enum": [ + "1.0" + ] + } + }, + "additionalProperties": false +} diff --git a/src/nplinker/strain_collection.py b/src/nplinker/strain_collection.py index 90b32250..cff5a9de 100644 --- a/src/nplinker/strain_collection.py +++ b/src/nplinker/strain_collection.py @@ -110,7 +110,11 @@ def remove(self, strain: Strain): raise ValueError(f"Strain {strain} not found in strain collection.") def filter(self, strain_set: set[Strain]): - """Remove all strains that are not in strain_set from the strain collection.""" + """Remove all strains that are not in strain_set from the strain collection. + + Args: + strain_set(set[Strain]): Set of strains to keep. + """ # note that we need to copy the list of strains, as we are modifying it for strain in self._strains.copy(): if strain not in strain_set: diff --git a/src/nplinker/strain_loader.py b/src/nplinker/strain_loader.py new file mode 100644 index 00000000..6204e9c7 --- /dev/null +++ b/src/nplinker/strain_loader.py @@ -0,0 +1,35 @@ +import json +from os import PathLike +from jsonschema import validate +from nplinker.logconfig import LogConfig +from nplinker.schemas import USER_STRAINS_SCHEMA +from .strains import Strain + + +logger = LogConfig.getLogger(__name__) + + +def load_user_strains(json_file: str | PathLike) -> set[Strain]: + """Load user specified strains from a JSON file. + + The JSON file must follow the schema defined in "nplinker/schemas/user_strains.json". + An example content of the JSON file: + {"strain_ids": ["strain1", "strain2"]} + + Args: + json_file(str | PathLike): Path to the JSON file containing user specified strains. + + Returns: + set[Strain]: A set of user specified strains. 
""" + with open(json_file, "r") as f: + json_data = json.load(f) + + # validate json data + validate(instance=json_data, schema=USER_STRAINS_SCHEMA) + + strains = set() + for strain_id in json_data["strain_ids"]: + strains.add(Strain(strain_id)) + + return strains diff --git a/tests/schemas/test_user_strains_schema.py b/tests/schemas/test_user_strains_schema.py new file mode 100644 index 00000000..d14ea300 --- /dev/null +++ b/tests/schemas/test_user_strains_schema.py @@ -0,0 +1,49 @@ +import pytest +from jsonschema import validate +from jsonschema.exceptions import ValidationError +from nplinker.schemas import USER_STRAINS_SCHEMA + + +# Test schema against invalid data +data_no_strain_ids = {"version": "1.0"} +data_empty_strain_ids = {"strain_ids": [], "version": "1.0"} +data_invalid_strain_ids = {"strain_ids": [1, 2, 3], "version": "1.0"} +data_empty_version = {"strain_ids": ["strain1", "strain2"], "version": ""} +data_invalid_version = {"strain_ids": ["strain1", "strain2"], "version": "1.0.0"} + + +@pytest.mark.parametrize( + "data, expected", + [ + [data_no_strain_ids, "'strain_ids' is a required property"], + [data_empty_strain_ids, "[] is too short"], + [data_invalid_strain_ids, "1 is not of type 'string'"], + [data_empty_version, "'' is not one of ['1.0']"], + [data_invalid_version, "'1.0.0' is not one of ['1.0']"], + ], +) +def test_invalid_data(data, expected): + """Test user strains schema against invalid data.""" + with pytest.raises(ValidationError) as e: + validate(data, USER_STRAINS_SCHEMA) + assert e.value.message == expected + + +# Test schema against valid data +data = {"strain_ids": ["strain1", "strain2"], "version": "1.0"} +data_no_version = {"strain_ids": ["strain1", "strain2"]} + + +@pytest.mark.parametrize( + "data", + [ + data, + data_no_version, + ], +) +def test_valid_data(data): + """Test user strains schema against valid data.""" + try: + validate(data, USER_STRAINS_SCHEMA) + except ValidationError: + pytest.fail("Unexpected 
ValidationError") diff --git a/tests/test_loader.py b/tests/test_loader.py deleted file mode 100644 index a279f16e..00000000 --- a/tests/test_loader.py +++ /dev/null @@ -1,106 +0,0 @@ -import shutil -import pytest -from nplinker.globals import STRAIN_MAPPINGS_FILENAME -from nplinker.loader import DatasetLoader -from nplinker.metabolomics.gnps import GNPSExtractor -from nplinker.metabolomics.gnps import GNPSSpectrumLoader -from nplinker.strain_collection import StrainCollection -from . import DATA_DIR - - -@pytest.fixture -def config(): - return { - "dataset": { - "root": DATA_DIR / "ProteoSAFe-METABOLOMICS-SNETS-c22f44b1-download_clustered_spectra", - "platform_id": "", - "overrides": {"strain_mappings_file": str(DATA_DIR / STRAIN_MAPPINGS_FILENAME)}, - } - } - - -@pytest.fixture -def config_with_new_gnps_extractor(): - GNPSExtractor( - DATA_DIR / "ProteoSAFe-METABOLOMICS-SNETS-c22f44b1-download_clustered_spectra.zip", - DATA_DIR / "extracted", - ) - yield { - "dataset": { - "root": DATA_DIR / "extracted", - "platform_id": "", - "overrides": {"strain_mappings_file": str(DATA_DIR / STRAIN_MAPPINGS_FILENAME)}, - } - } - shutil.rmtree(DATA_DIR / "extracted") - - -def test_default(config): - sut = DatasetLoader(config) - assert sut._platform_id == config["dataset"]["platform_id"] - - -def test_has_metabolomics_paths(config): - sut = DatasetLoader(config) - sut._init_metabolomics_paths() - assert sut.mgf_file == str( - config["dataset"]["root"] - / "METABOLOMICS-SNETS-c22f44b1-download_clustered_spectra-main.mgf" - ) - assert sut.edges_file == str( - config["dataset"]["root"] - / "networkedges_selfloop" - / "6da5be36f5b14e878860167fa07004d6.pairsinfo" - ) - assert sut.nodes_file == str( - config["dataset"]["root"] - / "clusterinfosummarygroup_attributes_withIDs_withcomponentID" - / "d69356c8e5044c2a9fef3dd2a2f991e1.tsv" - ) - assert sut.annotations_dir == str(config["dataset"]["root"] / "result_specnets_DB") - - -def 
test_has_metabolomics_paths_new_gnps(config_with_new_gnps_extractor): - sut = DatasetLoader(config_with_new_gnps_extractor) - sut._init_metabolomics_paths() - assert sut.mgf_file == str(config_with_new_gnps_extractor["dataset"]["root"] / "spectra.mgf") - assert sut.edges_file == str( - config_with_new_gnps_extractor["dataset"]["root"] / "molecular_families.tsv" - ) - assert sut.nodes_file == str( - config_with_new_gnps_extractor["dataset"]["root"] / "file_mappings.tsv" - ) - assert sut.annotations_dir == str(config_with_new_gnps_extractor["dataset"]["root"]) - assert sut.annotations_config_file == str( - config_with_new_gnps_extractor["dataset"]["root"] / "annotations.tsv" - ) - - -def test_load_metabolomics(config): - sut = DatasetLoader(config) - sut._init_paths() - sut._load_strain_mappings() - sut._load_metabolomics() - - expected_spectra = GNPSSpectrumLoader(sut.mgf_file).spectra() - - # HH TODO: switch to different comparison as soon as strains are implemented - assert len(sut.spectra) == len(expected_spectra) - assert len(sut.molfams) == 429 - - -def test_has_strain_mappings(config): - sut = DatasetLoader(config) - sut._init_paths() - assert sut.strain_mappings_file == str(DATA_DIR / STRAIN_MAPPINGS_FILENAME) - - -def test_load_strain_mappings(config): - sut = DatasetLoader(config) - sut._init_paths() - sut._load_strain_mappings() - - actual = sut.strains - expected = StrainCollection.read_json(sut.strain_mappings_file) - - assert actual == expected diff --git a/tests/test_strain_loader.py b/tests/test_strain_loader.py new file mode 100644 index 00000000..8543a4c0 --- /dev/null +++ b/tests/test_strain_loader.py @@ -0,0 +1,23 @@ +import json +import pytest +from nplinker.strain_loader import load_user_strains +from nplinker.strains import Strain + + +@pytest.fixture +def user_strains_file(tmp_path): + """Create a JSON file containing user specified strains.""" + data = { + "strain_ids": ["strain1", "strain2", "strain3"], + } + file_path = tmp_path / 
"user_strains.json" + with open(file_path, "w") as f: + json.dump(data, f) + return file_path + + +def test_load_user_strains(user_strains_file): + """Test load_user_strains function.""" + actual = load_user_strains(user_strains_file) + expected = {Strain("strain1"), Strain("strain2"), Strain("strain3")} + assert actual == expected