Skip to content

Commit

Permalink
refactor reading and filtering of user specified strains
Browse files Browse the repository at this point in the history
- remove `_filter_user_strains` method
- remove reading user strains file from `_load_optional` method
- add `_load_user_strains` method
- update `_load_strain_mappings` method to include the reading and filtering of user specified strains after loading strain mappings
  • Loading branch information
CunliangGeng committed Nov 14, 2023
1 parent 5043c1f commit a7668e7
Showing 1 changed file with 26 additions and 97 deletions.
123 changes: 26 additions & 97 deletions src/nplinker/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,6 @@ def load(self):
# TODO add a config file option for this?
self._filter_only_common_strains()

# if the user specified a set of strains to be explicitly included, filter
# out everything except those strains
self._filter_user_strains()

# if we don't have at least *some* strains here it probably means missing mappings
# or a complete failure to parse things, so bail out
if len(self.strains) == 0:
Expand Down Expand Up @@ -382,13 +378,18 @@ def _load_mibig(self):
return True

def _load_strain_mappings(self):
# First load user's strain mappings
# 1. load strain mappings
sc = StrainCollection.read_json(self.strain_mappings_file)
for strain in sc:
self.strains.add(strain)
logger.info("Loaded {} non-MiBIG Strain objects".format(len(self.strains)))

# Then load MiBIG strain mappings
# 2. filter user specificied strains (remove all that are not specified by user)
user_strains = self._load_user_strains()
if user_strains:
self.strains.filter(user_strains)

# 3. load MiBIG strain mappings
if self._mibig_strain_bgc_mapping:
for k, v in self._mibig_strain_bgc_mapping.items():
strain = Strain(k)
Expand All @@ -398,6 +399,25 @@ def _load_strain_mappings(self):

return True

def _load_user_strains(self) -> set[Strain]:
"""Load user-specified strains from a file.
The file must contain one strain name per line.
Returns:
set[Strain]: A set of user specified strains.
"""
strains = set()
if os.path.exists(self.include_strains_file):
logger.debug(f"Loading user specified strains from {self.include_strains_file}.")
with open(self.include_strains_file, "r") as f:
for line in f.readlines():
strains.add(Strain(line.strip()))

if len(strains) != 0:
logger.debug(f"Loaded {len(strains)} user specified strains.")
return strains

# TODO CG: replace deprecated load_dataset with GPNSLoader
def _load_metabolomics(self):
spec_dict, self.spectra, self.molfams, unknown_strains = load_dataset(
Expand Down Expand Up @@ -573,28 +593,6 @@ def _load_optional(self):
self.description_text = open(self.description_file).read()
logger.debug("Parsed description text")

self.include_only_strains = set()
if os.path.exists(self.include_strains_file):
logger.debug("Loading include_strains from {}".format(self.include_strains_file))
strain_list = open(self.include_strains_file).readlines()
self.include_only_strains = StrainCollection()
for line_num, sid in enumerate(strain_list):
sid = sid.strip() # get rid of newline
try:
strain_ref_list = self.strains.lookup(sid)
except KeyError:
logger.warning(
'Line {} of {}: invalid/unknown strain ID "{}"'.format(
line_num + 1, self.include_strains_file, sid
)
)
continue
for strain in strain_ref_list:
self.include_only_strains.add(strain)
logger.debug(
"Found {} strain IDs in include_strains".format(len(self.include_only_strains))
)

def _filter_only_common_strains(self):
"""Filter strain population to only strains present in both genomic and molecular data."""
# TODO: Maybe there should be an option to specify which strains are used, both so we can
Expand Down Expand Up @@ -627,75 +625,6 @@ def _filter_only_common_strains(self):
spec.strains.filter(common_strains)
logger.info("Strains filtered down to total of {}".format(len(self.strains)))

def _filter_user_strains(self):
"""If the user has supplied a list of strains to be explicitly included, go through the
existing sets of objects we have and remove any that only include other strains. This
involves an initial round of removing BGC and Spectrum objects, then a further round
of removing now-empty GCF and MolFam objects.
"""
if len(self.include_only_strains) == 0:
logger.info("No further strain filtering to apply")
return

logger.info(
"Found a list of {} strains to retain, filtering objects".format(
len(self.include_only_strains)
)
)

# filter the main list of strains
self.strains.filter(self.include_only_strains)

if len(self.strains) == 0:
logger.error("Strain list has been filtered down until it is empty! ")
logger.error(
"This probably indicates that you tried to specifically include a set of strains that had no overlap with the set common to metabolomics and genomics data (see the common_strains.csv in the dataset folder for a list of these"
)
raise Exception("No strains left after filtering, cannot continue!")

# get the list of BGCs which have a strain found in the set we were given
bgcs_to_retain = {bgc for bgc in self.bgcs if bgc.strain in self.include_only_strains}
# get the list of spectra which have at least one strain in the set
spectra_to_retain = {
spec
for spec in self.spectra
for sstrain in spec.strains
if sstrain in self.include_only_strains
}

logger.info(
"Current / filtered BGC counts: {} / {}".format(len(self.bgcs), len(bgcs_to_retain))
)
logger.info(
"Current / filtered spectra counts: {} / {}".format(
len(self.spectra), len(spectra_to_retain)
)
)

self.bgcs = list(bgcs_to_retain)

self.spectra = list(spectra_to_retain)
# also need to filter the set of strains attached to each spectrum
for i, spec in enumerate(self.spectra):
spec.strains.filter(self.include_only_strains)
spec.id = i

# now filter GCFs and MolFams based on the filtered BGCs and Spectra
gcfs = {parent for bgc in self.bgcs for parent in bgc.parents}
logger.info("Current / filtered GCF counts: {} / {}".format(len(self.gcfs), len(gcfs)))
self.gcfs = list(gcfs)
# filter each GCF's strain list
for gcf in self.gcfs:
gcf.strains.filter(self.include_only_strains)

molfams = {spec.family for spec in self.spectra}
logger.info(
"Current / filtered MolFam counts: {} / {}".format(len(self.molfams), len(molfams))
)
self.molfams = list(molfams)
for i, molfam in enumerate(self.molfams):
molfam.id = i


def find_via_glob(path, file_type, optional=False):
try:
Expand Down

0 comments on commit a7668e7

Please sign in to comment.