From a7668e7a4056404db9c1e9cacd6d6ce15d678982 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Tue, 14 Nov 2023 16:11:47 +0100 Subject: [PATCH] refactor reading and filtering of user specified strains - remove `_filter_user_strains` method - remove reading user strains file from `_load_optional` method - add `_load_user_strains` method - update `_load_strain_mappings` method to include the reading and filtering of user specified strains after loading strain mappings --- src/nplinker/loader.py | 123 +++++++++-------------------------------- 1 file changed, 26 insertions(+), 97 deletions(-) diff --git a/src/nplinker/loader.py b/src/nplinker/loader.py index 753eeba6..01137558 100644 --- a/src/nplinker/loader.py +++ b/src/nplinker/loader.py @@ -205,10 +205,6 @@ def load(self): # TODO add a config file option for this? self._filter_only_common_strains() - # if the user specified a set of strains to be explicitly included, filter - # out everything except those strains - self._filter_user_strains() - # if we don't have at least *some* strains here it probably means missing mappings # or a complete failure to parse things, so bail out if len(self.strains) == 0: @@ -382,13 +378,18 @@ def _load_mibig(self): return True def _load_strain_mappings(self): - # First load user's strain mappings + # 1. load strain mappings sc = StrainCollection.read_json(self.strain_mappings_file) for strain in sc: self.strains.add(strain) logger.info("Loaded {} non-MiBIG Strain objects".format(len(self.strains))) - # Then load MiBIG strain mappings + # 2. filter user specified strains (remove all that are not specified by user) + user_strains = self._load_user_strains() + if user_strains: + self.strains.filter(user_strains) + + # 3. 
load MiBIG strain mappings if self._mibig_strain_bgc_mapping: for k, v in self._mibig_strain_bgc_mapping.items(): strain = Strain(k) @@ -398,6 +399,25 @@ def _load_strain_mappings(self): return True + def _load_user_strains(self) -> set[Strain]: + """Load user-specified strains from a file. + + The file must contain one strain name per line. + + Returns: + set[Strain]: A set of user specified strains. + """ + strains = set() + if os.path.exists(self.include_strains_file): + logger.debug(f"Loading user specified strains from {self.include_strains_file}.") + with open(self.include_strains_file, "r") as f: + for line in f.readlines(): + strains.add(Strain(line.strip())) + + if len(strains) != 0: + logger.debug(f"Loaded {len(strains)} user specified strains.") + return strains + # TODO CG: replace deprecated load_dataset with GPNSLoader def _load_metabolomics(self): spec_dict, self.spectra, self.molfams, unknown_strains = load_dataset( @@ -573,28 +593,6 @@ def _load_optional(self): self.description_text = open(self.description_file).read() logger.debug("Parsed description text") - self.include_only_strains = set() - if os.path.exists(self.include_strains_file): - logger.debug("Loading include_strains from {}".format(self.include_strains_file)) - strain_list = open(self.include_strains_file).readlines() - self.include_only_strains = StrainCollection() - for line_num, sid in enumerate(strain_list): - sid = sid.strip() # get rid of newline - try: - strain_ref_list = self.strains.lookup(sid) - except KeyError: - logger.warning( - 'Line {} of {}: invalid/unknown strain ID "{}"'.format( - line_num + 1, self.include_strains_file, sid - ) - ) - continue - for strain in strain_ref_list: - self.include_only_strains.add(strain) - logger.debug( - "Found {} strain IDs in include_strains".format(len(self.include_only_strains)) - ) - def _filter_only_common_strains(self): """Filter strain population to only strains present in both genomic and molecular data.""" # TODO: Maybe there 
should be an option to specify which strains are used, both so we can @@ -627,75 +625,6 @@ def _filter_only_common_strains(self): spec.strains.filter(common_strains) logger.info("Strains filtered down to total of {}".format(len(self.strains))) - def _filter_user_strains(self): - """If the user has supplied a list of strains to be explicitly included, go through the - existing sets of objects we have and remove any that only include other strains. This - involves an initial round of removing BGC and Spectrum objects, then a further round - of removing now-empty GCF and MolFam objects. - """ - if len(self.include_only_strains) == 0: - logger.info("No further strain filtering to apply") - return - - logger.info( - "Found a list of {} strains to retain, filtering objects".format( - len(self.include_only_strains) - ) - ) - - # filter the main list of strains - self.strains.filter(self.include_only_strains) - - if len(self.strains) == 0: - logger.error("Strain list has been filtered down until it is empty! 
") - logger.error( - "This probably indicates that you tried to specifically include a set of strains that had no overlap with the set common to metabolomics and genomics data (see the common_strains.csv in the dataset folder for a list of these" - ) - raise Exception("No strains left after filtering, cannot continue!") - - # get the list of BGCs which have a strain found in the set we were given - bgcs_to_retain = {bgc for bgc in self.bgcs if bgc.strain in self.include_only_strains} - # get the list of spectra which have at least one strain in the set - spectra_to_retain = { - spec - for spec in self.spectra - for sstrain in spec.strains - if sstrain in self.include_only_strains - } - - logger.info( - "Current / filtered BGC counts: {} / {}".format(len(self.bgcs), len(bgcs_to_retain)) - ) - logger.info( - "Current / filtered spectra counts: {} / {}".format( - len(self.spectra), len(spectra_to_retain) - ) - ) - - self.bgcs = list(bgcs_to_retain) - - self.spectra = list(spectra_to_retain) - # also need to filter the set of strains attached to each spectrum - for i, spec in enumerate(self.spectra): - spec.strains.filter(self.include_only_strains) - spec.id = i - - # now filter GCFs and MolFams based on the filtered BGCs and Spectra - gcfs = {parent for bgc in self.bgcs for parent in bgc.parents} - logger.info("Current / filtered GCF counts: {} / {}".format(len(self.gcfs), len(gcfs))) - self.gcfs = list(gcfs) - # filter each GCF's strain list - for gcf in self.gcfs: - gcf.strains.filter(self.include_only_strains) - - molfams = {spec.family for spec in self.spectra} - logger.info( - "Current / filtered MolFam counts: {} / {}".format(len(self.molfams), len(molfams)) - ) - self.molfams = list(molfams) - for i, molfam in enumerate(self.molfams): - molfam.id = i - def find_via_glob(path, file_type, optional=False): try: