Skip to content

Commit

Permalink
refactor filtering of user specified strains
Browse files Browse the repository at this point in the history
Based on the new loading pipeline, the filtering of user strains is conducted just after loading strain mappings, which will simplify the upcoming loading pipeline. The change of loading process looks like below:
**Before**: load strain mappings --> load BGC, GCF, spectra and MF --> filter user strains
**Now**: load strain mappings --> filter user strains --> load BGC, GCF, spectra and MF 


Major Changes:
- create `strain_loader.py` and add function `load_user_strains`
- add schema for json file of user specified strains (now we require user to provide strains in a JSON file)
- update the use of `load_user_strains` in loader.py
- remove test_loader.py (the whole loading pipeline is ongoing, and tests will be added later)
  • Loading branch information
CunliangGeng authored Dec 14, 2023
1 parent 4016a64 commit bd45807
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 204 deletions.
108 changes: 11 additions & 97 deletions src/nplinker/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nplinker.pairedomics.runbigscape import run_bigscape
from nplinker.pairedomics.strain_mappings_generator import podp_generate_strain_mappings
from nplinker.strain_collection import StrainCollection
from nplinker.strain_loader import load_user_strains
from nplinker.strains import Strain


Expand Down Expand Up @@ -205,10 +206,6 @@ def load(self):
# TODO add a config file option for this?
self._filter_only_common_strains()

# if the user specified a set of strains to be explicitly included, filter
# out everything except those strains
self._filter_user_strains()

# if we don't have at least *some* strains here it probably means missing mappings
# or a complete failure to parse things, so bail out
if len(self.strains) == 0:
Expand Down Expand Up @@ -382,13 +379,21 @@ def _load_mibig(self):
return True

def _load_strain_mappings(self):
# First load user's strain mappings
# 1. load strain mappings
sc = StrainCollection.read_json(self.strain_mappings_file)
for strain in sc:
self.strains.add(strain)
logger.info("Loaded {} non-MiBIG Strain objects".format(len(self.strains)))

# Then load MiBIG strain mappings
        # 2. filter user specified strains (remove all that are not specified by user).
# It's not allowed to specify empty list of strains, otherwise validation will fail.
if os.path.exists(self.include_strains_file):
logger.info(f"Loading user specified strains from file {self.include_strains_file}.")
user_strains = load_user_strains(self.include_strains_file)
logger.info(f"Loaded {len(user_strains)} user specified strains.")
self.strains.filter(user_strains)

# 3. load MiBIG strain mappings
if self._mibig_strain_bgc_mapping:
for k, v in self._mibig_strain_bgc_mapping.items():
strain = Strain(k)
Expand Down Expand Up @@ -573,28 +578,6 @@ def _load_optional(self):
self.description_text = open(self.description_file).read()
logger.debug("Parsed description text")

self.include_only_strains = set()
if os.path.exists(self.include_strains_file):
logger.debug("Loading include_strains from {}".format(self.include_strains_file))
strain_list = open(self.include_strains_file).readlines()
self.include_only_strains = StrainCollection()
for line_num, sid in enumerate(strain_list):
sid = sid.strip() # get rid of newline
try:
strain_ref_list = self.strains.lookup(sid)
except KeyError:
logger.warning(
'Line {} of {}: invalid/unknown strain ID "{}"'.format(
line_num + 1, self.include_strains_file, sid
)
)
continue
for strain in strain_ref_list:
self.include_only_strains.add(strain)
logger.debug(
"Found {} strain IDs in include_strains".format(len(self.include_only_strains))
)

def _filter_only_common_strains(self):
"""Filter strain population to only strains present in both genomic and molecular data."""
# TODO: Maybe there should be an option to specify which strains are used, both so we can
Expand Down Expand Up @@ -627,75 +610,6 @@ def _filter_only_common_strains(self):
spec.strains.filter(common_strains)
logger.info("Strains filtered down to total of {}".format(len(self.strains)))

    def _filter_user_strains(self):
        """If the user has supplied a list of strains to be explicitly included, go through the
        existing sets of objects we have and remove any that only include other strains. This
        involves an initial round of removing BGC and Spectrum objects, then a further round
        of removing now-empty GCF and MolFam objects.

        Reads `self.include_only_strains` (a StrainCollection built in `_load_optional`)
        and mutates `self.strains`, `self.bgcs`, `self.spectra`, `self.gcfs` and
        `self.molfams` in place.

        Raises:
            Exception: If filtering removes every strain from `self.strains`.
        """
        # nothing to do if the user did not supply an include-list
        if len(self.include_only_strains) == 0:
            logger.info("No further strain filtering to apply")
            return

        logger.info(
            "Found a list of {} strains to retain, filtering objects".format(
                len(self.include_only_strains)
            )
        )

        # filter the main list of strains
        self.strains.filter(self.include_only_strains)

        # an empty result means the user's include-list had no overlap with the
        # strains common to the genomic and metabolomic data, so abort loading
        if len(self.strains) == 0:
            logger.error("Strain list has been filtered down until it is empty! ")
            logger.error(
                "This probably indicates that you tried to specifically include a set of strains that had no overlap with the set common to metabolomics and genomics data (see the common_strains.csv in the dataset folder for a list of these"
            )
            raise Exception("No strains left after filtering, cannot continue!")

        # get the list of BGCs which have a strain found in the set we were given
        bgcs_to_retain = {bgc for bgc in self.bgcs if bgc.strain in self.include_only_strains}
        # get the list of spectra which have at least one strain in the set
        # (the double loop may visit a spectrum more than once; the set dedupes)
        spectra_to_retain = {
            spec
            for spec in self.spectra
            for sstrain in spec.strains
            if sstrain in self.include_only_strains
        }

        logger.info(
            "Current / filtered BGC counts: {} / {}".format(len(self.bgcs), len(bgcs_to_retain))
        )
        logger.info(
            "Current / filtered spectra counts: {} / {}".format(
                len(self.spectra), len(spectra_to_retain)
            )
        )

        self.bgcs = list(bgcs_to_retain)

        self.spectra = list(spectra_to_retain)
        # also need to filter the set of strains attached to each spectrum
        # NOTE: spectrum ids are renumbered to stay contiguous after filtering
        for i, spec in enumerate(self.spectra):
            spec.strains.filter(self.include_only_strains)
            spec.id = i

        # now filter GCFs and MolFams based on the filtered BGCs and Spectra
        # (a GCF survives only if at least one of its member BGCs survived)
        gcfs = {parent for bgc in self.bgcs for parent in bgc.parents}
        logger.info("Current / filtered GCF counts: {} / {}".format(len(self.gcfs), len(gcfs)))
        self.gcfs = list(gcfs)
        # filter each GCF's strain list
        for gcf in self.gcfs:
            gcf.strains.filter(self.include_only_strains)

        # a MolFam survives only if one of its spectra survived; ids are
        # renumbered to stay contiguous, mirroring the spectrum handling above
        molfams = {spec.family for spec in self.spectra}
        logger.info(
            "Current / filtered MolFam counts: {} / {}".format(len(self.molfams), len(molfams))
        )
        self.molfams = list(molfams)
        for i, molfam in enumerate(self.molfams):
            molfam.id = i


def find_via_glob(path, file_type, optional=False):
try:
Expand Down
4 changes: 4 additions & 0 deletions src/nplinker/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"GENOME_BGC_MAPPINGS_SCHEMA",
"STRAIN_MAPPINGS_SCHEMA",
"PODP_ADAPTED_SCHEMA",
"USER_STRAINS_SCHEMA",
"validate_podp_json",
]

Expand All @@ -24,3 +25,6 @@

with open(SCHEMA_DIR / "strain_mappings_schema.json", "r") as f:
STRAIN_MAPPINGS_SCHEMA = json.load(f)

with open(SCHEMA_DIR / "user_strains.json", "r") as f:
USER_STRAINS_SCHEMA = json.load(f)
30 changes: 30 additions & 0 deletions src/nplinker/schemas/user_strains.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"$schema": "https://json-schema.org/draft/2020-12/schema",
"$id": "https://raw.githubusercontent.com/NPLinker/nplinker/main/src/nplinker/schemas/user_strains.json",
"title": "User specified strains",
"description": "A list of strain IDs specified by user",
"type": "object",
"required": [
"strain_ids"
],
"properties": {
"strain_ids": {
"type": "array",
"title": "Strain IDs",
"description": "A list of strain IDs specified by user. The strain IDs must be the same as the ones in the strain mappings file.",
"items": {
"type": "string",
"minLength": 1
},
"minItems": 1,
"uniqueItems": true
},
"version": {
"type": "string",
"enum": [
"1.0"
]
}
},
"additionalProperties": false
}
6 changes: 5 additions & 1 deletion src/nplinker/strain_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,11 @@ def remove(self, strain: Strain):
raise ValueError(f"Strain {strain} not found in strain collection.")

def filter(self, strain_set: set[Strain]):
"""Remove all strains that are not in strain_set from the strain collection."""
"""Remove all strains that are not in strain_set from the strain collection.
Args:
strain_set(set[Strain]): Set of strains to keep.
"""
# note that we need to copy the list of strains, as we are modifying it
for strain in self._strains.copy():
if strain not in strain_set:
Expand Down
35 changes: 35 additions & 0 deletions src/nplinker/strain_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import json
from os import PathLike
from jsonschema import validate
from nplinker.logconfig import LogConfig
from nplinker.schemas import USER_STRAINS_SCHEMA
from .strains import Strain


logger = LogConfig.getLogger(__name__)


def load_user_strains(json_file: str | PathLike) -> set[Strain]:
    """Load user specified strains from a JSON file.

    The JSON file must follow the schema defined in
    "nplinker/schemas/user_strains.json". An example content of the JSON file:
    {"strain_ids": ["strain1", "strain2"]}

    Args:
        json_file(str | PathLike): Path to the JSON file containing user specified strains.

    Returns:
        set[Strain]: A set of user specified strains.

    Raises:
        jsonschema.exceptions.ValidationError: If the JSON data does not conform to
            the schema, e.g. "strain_ids" is missing, empty or contains non-strings.
    """
    with open(json_file, "r") as f:
        json_data = json.load(f)

    # validate json data before constructing Strain objects; an empty or
    # missing "strain_ids" list fails here rather than silently yielding set()
    validate(instance=json_data, schema=USER_STRAINS_SCHEMA)

    # build the set with a comprehension instead of a manual append loop
    return {Strain(strain_id) for strain_id in json_data["strain_ids"]}
49 changes: 49 additions & 0 deletions tests/schemas/test_user_strains_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from jsonschema import validate
from jsonschema.exceptions import ValidationError
from nplinker.schemas import USER_STRAINS_SCHEMA


# Test schema against invalid data
data_no_strain_ids = {"version": "1.0"}  # missing the required "strain_ids" key
data_empty_strain_ids = {"strain_ids": [], "version": "1.0"}  # violates minItems: 1
data_invalid_strain_ids = {"strain_ids": [1, 2, 3], "version": "1.0"}  # items must be strings
data_empty_version = {"strain_ids": ["strain1", "strain2"], "version": ""}  # not in version enum
data_invalid_version = {"strain_ids": ["strain1", "strain2"], "version": "1.0.0"}  # not in version enum


@pytest.mark.parametrize(
    "data, expected",
    [
        (data_no_strain_ids, "'strain_ids' is a required property"),
        (data_empty_strain_ids, "[] is too short"),
        (data_invalid_strain_ids, "1 is not of type 'string'"),
        (data_empty_version, "'' is not one of ['1.0']"),
        (data_invalid_version, "'1.0.0' is not one of ['1.0']"),
    ],
)
def test_invalid_data(data, expected):
    """Each invalid payload must be rejected with the exact schema error message."""
    with pytest.raises(ValidationError) as exc_info:
        validate(data, USER_STRAINS_SCHEMA)
    assert exc_info.value.message == expected


# Test schema against valid data
data = {"strain_ids": ["strain1", "strain2"], "version": "1.0"}
data_no_version = {"strain_ids": ["strain1", "strain2"]}  # "version" is an optional property


@pytest.mark.parametrize("data", [data, data_no_version])
def test_valid_data(data):
    """Valid payloads (with and without the optional version) must pass validation."""
    try:
        validate(data, USER_STRAINS_SCHEMA)
    except ValidationError:
        pytest.fail("Unexpected ValidationError")
106 changes: 0 additions & 106 deletions tests/test_loader.py

This file was deleted.

Loading

0 comments on commit bd45807

Please sign in to comment.