From fc9ddec26b0c3d15e5307a8c5f5213ed42508ea7 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 14 Jun 2024 13:34:50 +0200 Subject: [PATCH] refactor NPLinker class (#256) * remove unused property `metadata` * refactor private data containers and relevant methods * remove legacy code for saving data * change the variable of default database folder to package level * update config related code * uniform the names of scoring methods * refactor `NPLinker.get_links` method - remove unused method `scoring_method` - refactor the method `get_links` * Delete test_nplinker_scoring.py * reorder methods and properties of NPLinker class * fix returned value type * update integration tests - simplify the `conftest.py` - remove unused functions - add tests for `get_link` * fix mypy errors --- src/nplinker/__init__.py | 8 +- src/nplinker/loader.py | 9 +- src/nplinker/nplinker.py | 479 +++++++++----------- src/nplinker/scoring/metcalf_scoring.py | 3 +- src/nplinker/scoring/np_class_scoring.py | 3 +- src/nplinker/scoring/rosetta/rosetta.py | 7 +- src/nplinker/scoring/rosetta_scoring.py | 3 +- tests/integration/conftest.py | 55 +-- tests/integration/test_nplinker_local.py | 55 ++- tests/unit/scoring/conftest.py | 6 +- tests/unit/scoring/test_nplinker_scoring.py | 116 ----- 11 files changed, 283 insertions(+), 461 deletions(-) delete mode 100644 tests/unit/scoring/test_nplinker_scoring.py diff --git a/src/nplinker/__init__.py b/src/nplinker/__init__.py index 0b347ec8..43ace5fe 100644 --- a/src/nplinker/__init__.py +++ b/src/nplinker/__init__.py @@ -1,4 +1,5 @@ import logging +from pathlib import Path logging.getLogger(__name__).addHandler(logging.NullHandler()) @@ -8,6 +9,11 @@ __version__ = "2.0.0-alpha.1" +# The path to the NPLinker application database directory +NPLINKER_APP_DATA_DIR = Path(__file__).parent / "data" +del Path + + def setup_logging(level: str = "INFO", file: str = "", use_console: bool = True) -> None: """Setup logging configuration for the ancestor logger "nplinker". @@ -22,7 +28,7 @@ def setup_logging(level: str = "INFO", file: str = "", use_console: bool = True) from rich.console import Console from rich.logging import RichHandler - # Get the acncestor logger "nplinker" + # Get the ancestor logger "nplinker" logger = logging.getLogger(__name__) logger.setLevel(level) diff --git a/src/nplinker/loader.py b/src/nplinker/loader.py index cfa7fe36..ce534f52 100644 --- a/src/nplinker/loader.py +++ b/src/nplinker/loader.py @@ -1,8 +1,8 @@ import logging import os -from importlib.resources import files from deprecated import deprecated from dynaconf import Dynaconf +from nplinker import NPLINKER_APP_DATA_DIR from nplinker import defaults from nplinker.genomics.antismash import AntismashBGCLoader from nplinker.genomics.bigscape import BigscapeGCFLoader @@ -23,8 +23,6 @@ logger = logging.getLogger(__name__) -NPLINKER_APP_DATA_DIR = files("nplinker").joinpath("data") - class DatasetLoader: """Class to load all data. @@ -228,9 +226,10 @@ def _load_class_info(self): True if everything completes """ # load Class_matches with mibig info from data - mibig_class_file = NPLINKER_APP_DATA_DIR.joinpath( - "MIBiG2.0_compounds_with_AS_BGC_CF_NPC_classes.txt" + mibig_class_file = ( + NPLINKER_APP_DATA_DIR / "MIBiG2.0_compounds_with_AS_BGC_CF_NPC_classes.txt" ) + self.class_matches = ClassMatches(mibig_class_file) # noqa # run canopus if canopus_dir does not exist diff --git a/src/nplinker/nplinker.py b/src/nplinker/nplinker.py index 2a1059a8..60e20be7 100644 --- a/src/nplinker/nplinker.py +++ b/src/nplinker/nplinker.py @@ -1,39 +1,75 @@ from __future__ import annotations import logging -import sys -from os import PathLike from pprint import pformat +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import overload from . import setup_logging from .arranger import DatasetArranger from .config import load_config from .defaults import OUTPUT_DIRNAME from .genomics import BGC from .genomics import GCF -from .loader import NPLINKER_APP_DATA_DIR from .loader import DatasetLoader from .metabolomics import MolecularFamily from .metabolomics import Spectrum -from .pickler import save_pickled_data -from .scoring.abc import ScoringBase from .scoring.metcalf_scoring import MetcalfScoring -from .scoring.np_class_scoring import NPClassScoring -from .scoring.rosetta_scoring import RosettaScoring +from .strain import StrainCollection + + +if TYPE_CHECKING: + from os import PathLike + from typing import Sequence + from nplinker.scoring.link_graph import LinkGraph logger = logging.getLogger(__name__) +ObjectType = TypeVar("ObjectType", BGC, GCF, Spectrum, MolecularFamily) -class NPLinker: - """Main class for the NPLinker application.""" - # allowable types for objects to be passed to scoring methods - OBJ_CLASSES = [Spectrum, MolecularFamily, GCF, BGC] - # default set of enabled scoring methods - # TODO: ideally these shouldn't be hardcoded like this - SCORING_METHODS = { +class NPLinker: + """Main class for the NPLinker application. + + Attributes: + config: The configuration object for the current NPLinker application. + root_dir: The path to the root directory of the current NPLinker application. + output_dir: The path to the output directory of the current NPLinker application. + bgcs: A list of all BGC objects. + gcfs: A list of all GCF objects. + spectra: A list of all Spectrum objects. + mfs: A list of all MolecularFamily objects. + mibig_bgcs: A list of all MiBIG BGC objects. + strains: A StrainCollection object containing all Strain objects. + product_types: A list of all BiGSCAPE product types. + scoring_methods: A list of all valid scoring methods. + + + Examples: + To start a NPLinker application: + >>> from nplinker import NPLinker + >>> npl = NPLinker("path/to/config.toml") + + To load all data into memory: + >>> npl.load_data() + + To check the number of GCF objects: + >>> len(npl.gcfs) + + To get the links for all GCF objects using the Metcalf scoring method, the result is a + LinkGraph object: + >>> lg = npl.get_links(npl.gcfs, "metcalf") + + To get the link data between two objects: + >>> link_data = lg.get_link_data(npl.gcfs[0], npl.spectra[0]) + {"metcalf": Score("metcalf", 1.0, {"cutoff": 0, "standardised": False})} + """ + + # Valid scoring methods + _valid_scoring_methods = { MetcalfScoring.name: MetcalfScoring, - RosettaScoring.name: RosettaScoring, - NPClassScoring.name: NPClassScoring, + # RosettaScoring.name: RosettaScoring, # To be refactored + # NPClassScoring.name: NPClassScoring, # To be refactored } def __init__(self, config_file: str | PathLike): @@ -42,8 +78,10 @@ def __init__(self, config_file: str | PathLike): Args: config_file: Path to the configuration file to use. """ + # Load the configuration file self.config = load_config(config_file) + # Setup logging for the application setup_logging( level=self.config.log.level, file=self.config.log.get("file", ""), @@ -53,308 +91,211 @@ def __init__(self, config_file: str | PathLike): "Configuration:\n %s", pformat(self.config.as_dict(), width=20, sort_dicts=False) ) - self.output_dir = self.config.root_dir / OUTPUT_DIRNAME - self.output_dir.mkdir(exist_ok=True) - - self._spectra = [] - self._bgcs = [] - self._gcfs = [] - self._strains = None - self._metadata = {} - self._mfs = [] - self._mibig_bgcs = [] - self._chem_classes = None - self._class_matches = None - - self._bgc_lookup = {} - self._gcf_lookup = {} - self._spec_lookup = {} - self._mf_lookup = {} - - self._scoring_methods = {} - config_methods = self.config.get("scoring_methods", []) - for name, method in NPLinker.SCORING_METHODS.items(): - if len(config_methods) == 0 or name in config_methods: - self._scoring_methods[name] = method - logger.info(f"Enabled scoring method: {name}") - - self._scoring_methods_setup_complete = { - name: False for name in self._scoring_methods.keys() - } - - self._repro_data = {} - repro_file = self.config.get("repro_file") - if repro_file: - self.save_repro_data(repro_file) - - def _collect_repro_data(self): - """Creates a dict containing data to aid reproducible runs of nplinker. - - This method creates a dict containing various bits of information about - the current execution of nplinker. This data will typically be saved to - a file in order to aid reproducibility using :func:`save_repro_data`. - - TODO describe contents + # Setup the output directory + self._output_dir = self.config.root_dir / OUTPUT_DIRNAME + self._output_dir.mkdir(exist_ok=True) - Returns: - A dict containing the information described above - """ - self._repro_data = {} - # TODO best way to embed latest git commit hash? probably with a packaging script... - # TODO versions of all Python dependencies used (just rely on - # Pipfile.lock here?) + # Initialise data containers that will be populated by the `load_data` method + self._bgc_dict: dict[str, BGC] = {} + self._gcf_dict: dict[str, GCF] = {} + self._spec_dict: dict[str, Spectrum] = {} + self._mf_dict: dict[str, MolecularFamily] = {} + self._mibig_bgcs: list[BGC] = [] + self._strains: StrainCollection = StrainCollection() + self._product_types: list = [] + self._chem_classes = None # TODO: to be refactored + self._class_matches = None # TODO: to be refactored - # insert command line arguments - self._repro_data["args"] = {} - for i, arg in enumerate(sys.argv): - self._repro_data["args"][i] = arg + # Flags to keep track of whether the scoring methods have been set up + self._scoring_methods_setup_done = {name: False for name in self._valid_scoring_methods} - # TODO anything else to include here? + @property + def root_dir(self) -> str: + """Get the path to the root directory of the current NPLinker instance.""" + return str(self.config.root_dir) - return self._repro_data + @property + def output_dir(self) -> str: + """Get the path to the output directory of the current NPLinker instance.""" + return str(self._output_dir) - def save_repro_data(self, filename): - self._collect_repro_data() - with open(filename, "wb") as repro_file: - # TODO is pickle the best format to use? - save_pickled_data(self._repro_data, repro_file) - logger.info(f"Saving reproducibility data to {filename}") + @property + def bgcs(self) -> list[BGC]: + """Get all BGC objects.""" + return list(self._bgc_dict.values()) @property - def root_dir(self) -> str: - """Returns path to the current dataset root directory. + def gcfs(self) -> list[GCF]: + """Get all GCF objects.""" + return list(self._gcf_dict.values()) - Returns: - The path to the dataset root directory currently in use - """ - return self.config.root_dir + @property + def spectra(self) -> list[Spectrum]: + """Get all Spectrum objects.""" + return list(self._spec_dict.values()) + + @property + def mfs(self) -> list[MolecularFamily]: + """Get all MolecularFamily objects.""" + return list(self._mf_dict.values()) @property - def data_dir(self): - """Returns path to nplinker/data directory (files packaged with the app itself).""" - return NPLINKER_APP_DATA_DIR + def mibig_bgcs(self) -> list[BGC]: + """Get all MiBIG BGC objects.""" + return self._mibig_bgcs @property - def bigscape_cutoff(self): - """Returns the current BiGSCAPE clustering cutoff value.""" - return self.config.bigscape.cutoff + def strains(self) -> StrainCollection: + """Get all Strain objects.""" + return self._strains + + @property + def product_types(self) -> list[str]: + """Get all BiGSCAPE product types.""" + return self._product_types + + @property + def chem_classes(self): + """Returns loaded ChemClassPredictions with the class predictions.""" + return self._chem_classes + + @property + def class_matches(self): + """ClassMatches with the matched classes and scoring tables from MIBiG.""" + return self._class_matches + + @property + def scoring_methods(self) -> list[str]: + """Get names of all valid scoring methods.""" + return list(self._valid_scoring_methods.keys()) def load_data(self): - """Loads the basic components of a dataset.""" + """Load all data from local files into memory. + + This method is a convenience function that calls the `DatasetArranger` and `DatasetLoader` + classes to load all data from the local filesystem into memory. The loaded data is then + stored in various private data containers for easy access. + """ arranger = DatasetArranger(self.config) arranger.arrange() loader = DatasetLoader(self.config) loader.load() - self._spectra = loader.spectra - self._mfs = loader.mfs - self._bgcs = loader.bgcs - self._gcfs = loader.gcfs + self._bgc_dict = {bgc.id: bgc for bgc in loader.bgcs} + self._gcf_dict = {gcf.id: gcf for gcf in loader.gcfs} + self._spec_dict = {spec.id: spec for spec in loader.spectra} + self._mf_dict = {mf.id: mf for mf in loader.mfs} + self._mibig_bgcs = loader.mibig_bgcs self._strains = loader.strains self._product_types = loader.product_types self._chem_classes = loader.chem_classes self._class_matches = loader.class_matches - # TODO CG: refactor this method and update its unit tests + @overload def get_links( - self, input_objects: list, scoring_methods: list, and_mode: bool = True - ) -> LinkCollection: - """Find links for a set of input objects (BGCs/GCFs/Spectra/mfs). - - The input objects can be any mix of the following NPLinker types: - - - BGC - - GCF - - Spectrum - - MolecularFamily + self, objects: Sequence[BGC], scoring_method: str, **scoring_params + ) -> LinkGraph: ... + @overload + def get_links( + self, objects: Sequence[GCF], scoring_method: str, **scoring_params + ) -> LinkGraph: ... + @overload + def get_links( + self, objects: Sequence[Spectrum], scoring_method: str, **scoring_params + ) -> LinkGraph: ... + @overload + def get_links( + self, objects: Sequence[MolecularFamily], scoring_method: str, **scoring_params + ) -> LinkGraph: ... - TODO longer description here + def get_links(self, objects, scoring_method, **scoring_params): + """Get the links for the given objects using the specified scoring method and parameters. Args: - input_objects: objects to be passed to the scoring method(s). - This may be either a flat list of a uniform type (one of the 4 - types above), or a list of such lists - scoring_methods: a list of one or more scoring methods to use - and_mode: determines how results from multiple methods are combined. - This is ignored if a single method is supplied. If multiple methods - are used and ``and_mode`` is True, the results will only contain - links found by ALL methods. If False, results will contain links - found by ANY method. + objects: A sequence of objects to get the links for. The objects must be of the same + type, i.e. `BGC`, `GCF`, `Spectrum` or `MolecularFamily` type. + For scoring method `metcalf`, the BGC objects are not supported. + scoring_method: The scoring method to use. Must be one of the valid scoring methods + `self.scoring_methods`. + scoring_params: Parameters to pass to the scoring method. If not provided, the default + parameters for the scoring method will be used. Returns: - An instance of ``nplinker.scoring.methods.LinkCollection`` - """ - if isinstance(input_objects, list) and len(input_objects) == 0: - raise Exception("input_objects length must be > 0") - - if isinstance(scoring_methods, list) and len(scoring_methods) == 0: - raise Exception("scoring_methods length must be > 0") - - # for convenience convert a single scoring object into a single entry - # list - if not isinstance(scoring_methods, list): - scoring_methods = [scoring_methods] - - # check if input_objects is a list of lists. if so there should be one - # entry for each supplied method for it to be a valid parameter - if isinstance(input_objects[0], list): - if len(input_objects) != len(scoring_methods): - raise Exception( - "Number of input_objects lists must match number of scoring_methods (found: {}, expected: {})".format( - len(input_objects), len(scoring_methods) - ) - ) - - # TODO check scoring_methods only contains ScoringMethod-derived - # instances - - # want everything to be in lists of lists - if not isinstance(input_objects, list) or ( - isinstance(input_objects, list) and not isinstance(input_objects[0], list) - ): - input_objects = [input_objects] - - logger.debug( - "get_links: {} object sets, {} methods".format(len(input_objects), len(scoring_methods)) - ) + A LinkGraph object containing the links for the given objects. - # copy the object set if required to make up the numbers - if len(input_objects) != len(scoring_methods): - if len(scoring_methods) < len(input_objects): - raise Exception("Number of scoring methods must be >= number of input object sets") - elif (len(scoring_methods) > len(input_objects)) and len(input_objects) != 1: - raise Exception( - "Mismatch between number of scoring methods and input objects ({} vs {})".format( - len(scoring_methods), len(input_objects) - ) - ) - elif len(scoring_methods) > len(input_objects): - # this is a special case for convenience: pass in 1 set of objects and multiple methods, - # result is that set is used for all methods - logger.debug("Duplicating input object set") - while len(input_objects) < len(scoring_methods): - input_objects.append(input_objects[0]) - logger.debug("Duplicating input object set") - - link_collection = LinkCollection(and_mode) - - for i, method in enumerate(scoring_methods): - # do any one-off initialisation required by this method - if not self._scoring_methods_setup_complete[method.name]: - logger.debug(f"Doing one-time setup for {method.name}") - self._scoring_methods[method.name].setup(self) - self._scoring_methods_setup_complete[method.name] = True - - # should construct a dict of {object_with_link: } - # entries - objects_for_method = input_objects[i] - logger.debug( - "Calling scoring method {} on {} objects".format( - method.name, len(objects_for_method) - ) + Raises: + ValueError: If input objects are empty or if the scoring method is invalid. + TypeError: If the input objects are not of the same type or if the object type is invalid. + """ + # Validate objects + if len(objects) == 0: + raise ValueError("No objects provided to get links for") + # check if all objects are of the same type + types = {type(i) for i in objects} + if len(types) > 1: + raise TypeError("Input objects must be of the same type.") + # check if the object type is valid + obj_type = next(iter(types)) + if obj_type not in (BGC, GCF, Spectrum, MolecularFamily): + raise TypeError( + f"Invalid type {obj_type}. Input objects must be BGC, GCF, Spectrum or MolecularFamily objects." ) - link_collection = method.get_links(*objects_for_method, link_collection=link_collection) - if len(link_collection) == 0: - logger.debug("No links found or remaining after merging all method results!") + # Validate scoring method + if scoring_method not in self._valid_scoring_methods: + raise ValueError(f"Invalid scoring method {scoring_method}.") - logger.info("Final size of link collection is {}".format(len(link_collection))) - return link_collection + # Check if the scoring method has been set up + if not self._scoring_methods_setup_done[scoring_method]: + self._valid_scoring_methods[scoring_method].setup(self) + self._scoring_methods_setup_done[scoring_method] = True - def has_bgc(self, id): - """Returns True if BGC ``id`` exists in the dataset.""" - return id in self._bgc_lookup + # Initialise the scoring method + scoring = self._valid_scoring_methods[scoring_method]() - def lookup_bgc(self, id): - """If BGC ``id`` exists, return it. Otherwise return None.""" - return self._bgc_lookup.get(id, None) + return scoring.get_links(*objects, **scoring_params) - def lookup_gcf(self, id): - """If GCF ``id`` exists, return it. Otherwise return None.""" - return self._gcf_lookup.get(id, None) + def lookup_bgc(self, id: str) -> BGC | None: + """Get the BGC object with the given ID. - def lookup_spectrum(self, id): - """If Spectrum ``name`` exists, return it. Otherwise return None.""" - return self._spec_lookup.get(id, None) - - def lookup_mf(self, id): - """If MolecularFamily `id` exists, return it. Otherwise return None.""" - return self._mf_lookup.get(id, None) - - @property - def strains(self): - """Returns a list of all the strains in the dataset.""" - return self._strains - - @property - def bgcs(self): - """Returns a list of all the BGCs in the dataset.""" - return self._bgcs - - @property - def gcfs(self): - """Returns a list of all the GCFs in the dataset.""" - return self._gcfs - - @property - def spectra(self): - """Returns a list of all the Spectra in the dataset.""" - return self._spectra - - @property - def mfs(self): - """Returns a list of all the MolecularFamilies in the dataset.""" - return self._mfs + Args: + id: the ID of the BGC to look up. - @property - def metadata(self): - return self._metadata + Returns: + The BGC object with the given ID, or None if no such object exists. + """ + return self._bgc_dict.get(id, None) - @property - def mibig_bgcs(self): - """Get a list of all the MIBiG BGCs in the dataset.""" - return self._mibig_bgcs + def lookup_gcf(self, id: str) -> GCF | None: + """Get the GCF object with the given ID. - @property - def product_types(self): - """Returns a list of the available BiGSCAPE product types in current dataset.""" - return self._product_types + Args: + id: the ID of the GCF to look up. - @property - def repro_data(self): - """Returns the dict containing reproducibility data.""" - return self._repro_data + Returns: + The GCF object with the given ID, or None if no such object exists. + """ + return self._gcf_dict.get(id, None) - @property - def scoring_methods(self): - """Returns a list of available scoring method names.""" - return list(self._scoring_methods.keys()) + def lookup_spectrum(self, id: str) -> Spectrum | None: + """Get the Spectrum object with the given ID. - @property - def chem_classes(self): - """Returns loaded ChemClassPredictions with the class predictions.""" - return self._chem_classes + Args: + id: the ID of the Spectrum to look up. - @property - def class_matches(self): - """ClassMatches with the matched classes and scoring tables from MIBiG.""" - return self._class_matches + Returns: + The Spectrum object with the given ID, or None if no such object exists. + """ + return self._spec_dict.get(id, None) - def scoring_method(self, name: str) -> ScoringBase | None: - """Return an instance of a scoring method. + def lookup_mf(self, id: str) -> MolecularFamily | None: + """Get the MolecularFamily object with the given ID. Args: - name: the name of the method (see :func:`scoring_methods`) + id: the ID of the MolecularFamily to look up. Returns: - An instance of the named scoring method class, or None if the name is invalid + The MolecularFamily object with the given ID, or None if no such object exists. """ - if name not in self._scoring_methods_setup_complete: - return None - - if not self._scoring_methods_setup_complete[name]: - self._scoring_methods[name].setup(self) - self._scoring_methods_setup_complete[name] = True - - return self._scoring_methods.get(name, None)(self) + return self._mf_dict.get(id, None) diff --git a/src/nplinker/scoring/metcalf_scoring.py b/src/nplinker/scoring/metcalf_scoring.py index 99336271..9c1df113 100644 --- a/src/nplinker/scoring/metcalf_scoring.py +++ b/src/nplinker/scoring/metcalf_scoring.py @@ -12,6 +12,7 @@ from .abc import ScoringBase from .link_graph import LinkGraph from .link_graph import Score +from .scoring_method import ScoringMethod from .utils import get_presence_gcf_strain from .utils import get_presence_mf_strain from .utils import get_presence_spec_strain @@ -71,7 +72,7 @@ class MetcalfScoring(ScoringBase): number of strains. """ - name = "metcalf" + name = ScoringMethod.METCALF.value npl: NPLinker | None = None CACHE: str = "cache_metcalf_scoring.pckl" metcalf_weights: tuple[int, int, int, int] = (10, -10, 0, 1) diff --git a/src/nplinker/scoring/np_class_scoring.py b/src/nplinker/scoring/np_class_scoring.py index 5e270488..55041ece 100644 --- a/src/nplinker/scoring/np_class_scoring.py +++ b/src/nplinker/scoring/np_class_scoring.py @@ -7,13 +7,14 @@ from .abc import ScoringBase from .link_graph import LinkGraph from .score import Score +from .scoring_method import ScoringMethod logger = logging.getLogger(__name__) class NPClassScoring(ScoringBase): - name = "npclassscore" + name = ScoringMethod.NPLCLASS.value def __init__(self, npl): super().__init__(npl) diff --git a/src/nplinker/scoring/rosetta/rosetta.py b/src/nplinker/scoring/rosetta/rosetta.py index c24f962a..524b7a42 100644 --- a/src/nplinker/scoring/rosetta/rosetta.py +++ b/src/nplinker/scoring/rosetta/rosetta.py @@ -21,6 +21,7 @@ from ...pickler import load_pickled_data from ...pickler import save_pickled_data from .spec_lib import SpecLib +from nplinker import NPLINKER_APP_DATA_DIR logger = logging.getLogger(__name__) @@ -38,9 +39,9 @@ def __init__(self, nplinker, ignore_genomic_cache=False): self._nplinker = nplinker self._mgf_data = {} self._csv_data = {} - self._mgf_path = os.path.join(nplinker.data_dir, "matched_mibig_gnps_update.mgf") - self._csv_path = os.path.join(nplinker.data_dir, "matched_mibig_gnps_update.csv") - self._data_path = nplinker.data_dir + self._mgf_path = os.path.join(NPLINKER_APP_DATA_DIR, "matched_mibig_gnps_update.mgf") + self._csv_path = os.path.join(NPLINKER_APP_DATA_DIR, "matched_mibig_gnps_update.csv") + self._data_path = NPLINKER_APP_DATA_DIR self._root_path = nplinker.root_dir self._dataset_id = nplinker.dataset_id self._ignore_genomic_cache = ignore_genomic_cache diff --git a/src/nplinker/scoring/rosetta_scoring.py b/src/nplinker/scoring/rosetta_scoring.py index 057d14d7..605f5859 100644 --- a/src/nplinker/scoring/rosetta_scoring.py +++ b/src/nplinker/scoring/rosetta_scoring.py @@ -7,13 +7,14 @@ from nplinker.scoring.rosetta.rosetta import Rosetta from .link_graph import LinkGraph from .score import Score +from .scoring_method import ScoringMethod logger = logging.getLogger(__name__) class RosettaScoring(ScoringBase): - name = "rosetta" + name = ScoringMethod.ROSETTA.value ROSETTA_OBJ = None def __init__(self, npl): diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index ede9f6cb..d4313afe 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,11 +1,11 @@ from __future__ import annotations import os import shutil -import tempfile import zipfile from os import PathLike from pathlib import Path import httpx +import pytest from rich.progress import Progress from . import DATA_DIR @@ -18,46 +18,27 @@ f"https://zenodo.org/records/{dataset_doi.split('.')[-1]}/files/nplinker_local_mode_example.zip" ) -# The temporary directory for the test session -temp_dir = tempfile.gettempdir() -nplinker_root_dir = os.path.join(temp_dir, "nplinker_local_mode_example") +@pytest.fixture(scope="module") +def root_dir(tmp_path_factory): + """Set up the NPLinker root directory for the local mode example dataset.""" + temp_dir = tmp_path_factory.mktemp("nplinker_integration_test") + nplinker_root_dir = temp_dir / "nplinker_local_mode_example" -def pytest_sessionstart(session): - """Pytest hook to run before the entire test session starts. - - This hook makes sure the temporary directory `nplinker_root_dir` is created before any test - starts. When running tests in parallel, the creation operation is done by the master process, - and worker processes are not allowed to do it. - - For more about this hook, see: - 1. https://docs.pytest.org/en/stable/reference.html#_pytest.hookspec.pytest_sessionstart - 2. https://github.com/pytest-dev/pytest-xdist/issues/271#issuecomment-826396320 - """ - workerinput = getattr(session.config, "workerinput", None) - # It's master process or not running in parallell when `workerinput` is None. - if workerinput is None: - if os.path.exists(nplinker_root_dir): - shutil.rmtree(nplinker_root_dir) - dataset = DATA_DIR / "nplinker_local_mode_example.zip" - if not dataset.exists(): - download_archive(dataset_url, DATA_DIR) - with zipfile.ZipFile(dataset, "r") as zip_ref: - zip_ref.extractall(temp_dir) - # NPLinker setting `root_dir` must be a path that exists, so setting it to a temporary directory. - os.environ["NPLINKER_ROOT_DIR"] = nplinker_root_dir - + # Download the dataset and extract it + if os.path.exists(nplinker_root_dir): + shutil.rmtree(nplinker_root_dir) + dataset = DATA_DIR / "nplinker_local_mode_example.zip" + if not dataset.exists(): + download_archive(dataset_url, DATA_DIR) + # the extracted directory is named "nplinker_local_mode_example" + with zipfile.ZipFile(dataset, "r") as zip_ref: + zip_ref.extractall(temp_dir) -def pytest_sessionfinish(session): - """Pytest hook to run after the entire test session finishes. + # Return the root directory + yield str(nplinker_root_dir) - This hook makes sure that temporary directory `nplinker_root_dir` is only removed after all - tests finish. When running tests in parallel, the deletion operation is done by the master - process, and worker processes are not allowed to do it. - """ - workerinput = getattr(session.config, "workerinput", None) - if workerinput is None: - shutil.rmtree(nplinker_root_dir) + shutil.rmtree(nplinker_root_dir) def download_archive( diff --git a/tests/integration/test_nplinker_local.py b/tests/integration/test_nplinker_local.py index ea19b43a..55ce5a84 100644 --- a/tests/integration/test_nplinker_local.py +++ b/tests/integration/test_nplinker_local.py @@ -1,38 +1,25 @@ -import hashlib -from pathlib import Path +import os import pytest from nplinker.nplinker import NPLinker from . import DATA_DIR -# Only tests related to data arranging and loading should be put here. -# For tests on scoring/links, add them to `scoring/test_nplinker_scoring.py`. - - -def get_file_hash(file_path): - h = hashlib.sha256() - with open(file_path, "rb") as file: - while True: - # Reading is buffered, so we can read smaller chunks. - chunk = file.read(h.block_size) - if not chunk: - break - h.update(chunk) - - return h.hexdigest() - - @pytest.fixture(scope="module") -def npl() -> NPLinker: +def npl(root_dir) -> NPLinker: + os.environ["NPLINKER_ROOT_DIR"] = root_dir npl = NPLinker(DATA_DIR / "nplinker_local_mode.toml") npl.load_data() - # remove cached score results before running tests - root_dir = Path(npl.root_dir) - score_cache = root_dir / "output" / "cache_metcalf_scoring.pckl" - score_cache.unlink(missing_ok=True) return npl +def test_init(npl, root_dir): + assert str(npl.config.root_dir) == root_dir + assert npl.config.mode == "local" + assert npl.config.log.level == "DEBUG" + + assert npl.root_dir == root_dir + + # --------------------------------------------------------------------------------------------------- # After manually checking data files for PODP MSV000079284, we have the following numbers: # 370 BGCs from antismash files @@ -63,3 +50,23 @@ def test_load_data(npl: NPLinker): assert len(npl.spectra) == 24652 assert len(npl.mfs) == 29 assert len(npl.strains) == 46 + + +def test_get_links(npl): + # default scoring parameters are used (cutoff=0, standardised=False), + # so all score values should be >= 0 + scoring_method = "metcalf" + lg = npl.get_links(npl.gcfs[:3], scoring_method) + for _, _, scores in lg.links: + score = scores[scoring_method] + assert score.value >= 0 + + lg = npl.get_links(npl.spectra[:1], scoring_method) + for _, _, scores in lg.links: + score = scores[scoring_method] + assert score.value >= 0 + + lg = npl.get_links(npl.mfs[:1], scoring_method) + for _, _, scores in lg.links: + score = scores[scoring_method] + assert score.value >= 0 diff --git a/tests/unit/scoring/conftest.py b/tests/unit/scoring/conftest.py index 8c7ac1f5..ec2f4ad6 100644 --- a/tests/unit/scoring/conftest.py +++ b/tests/unit/scoring/conftest.py @@ -77,9 +77,9 @@ def npl(gcfs, spectra, mfs, strains, tmp_path) -> NPLinker: npl._spectra = spectra npl._mfs = mfs npl._strains = strains - npl._gcf_lookup = {gcf.id: gcf for gcf in gcfs} - npl._mf_lookup = {mf.id: mf for mf in mfs} - npl._spec_lookup = {spec.id: spec for spec in spectra} + npl._gcf_dict = {gcf.id: gcf for gcf in gcfs} + npl._mf_dict = {mf.id: mf for mf in mfs} + npl._spec_dict = {spec.id: spec for spec in spectra} return npl diff --git a/tests/unit/scoring/test_nplinker_scoring.py b/tests/unit/scoring/test_nplinker_scoring.py deleted file mode 100644 index 3b65f911..00000000 --- a/tests/unit/scoring/test_nplinker_scoring.py +++ /dev/null @@ -1,116 +0,0 @@ -import numpy as np -import pytest - - -pytestmark = pytest.mark.skip(reason="Skip until refactoring relevant code.") - - -def test_get_links_gcf_standardised_false(npl, mc, gcfs, spectra, mfs, strains_list): - """Test `get_links` method when input is GCF objects and `standardised` is False.""" - # test raw scores (no standardisation) - mc.standardised = False - - # when cutoff is negative infinity, i.e. taking all scores - mc.cutoff = np.NINF - links = npl.get_links(list(gcfs), mc, and_mode=True) - assert isinstance(links, LinkCollection) - links = links.links # dict of link values - assert len(links) == 3 - assert {i.id for i in links.keys()} == {"gcf1", "gcf2", "gcf3"} - assert isinstance(links[gcfs[0]][spectra[0]], ObjectLink) - assert links[gcfs[0]][spectra[0]].data(mc) == 12 - assert links[gcfs[1]][spectra[0]].data(mc) == -9 - assert links[gcfs[2]][spectra[0]].data(mc) == 11 - assert links[gcfs[0]][mfs[0]].data(mc) == 12 - assert links[gcfs[1]][mfs[1]].data(mc) == 12 - assert links[gcfs[2]][mfs[2]].data(mc) == 21 - - # when test cutoff is 0, i.e. taking scores >= 0 - mc.cutoff = 0 - links = npl.get_links(list(gcfs), mc, and_mode=True) - assert isinstance(links, LinkCollection) - links = links.links - assert {i.id for i in links.keys()} == {"gcf1", "gcf2", "gcf3"} - assert isinstance(links[gcfs[0]][spectra[0]], ObjectLink) - # test scores - assert links[gcfs[0]][spectra[0]].data(mc) == 12 - assert links[gcfs[1]].get(spectra[0]) is None - assert links[gcfs[2]][spectra[0]].data(mc) == 11 - assert links[gcfs[0]][mfs[0]].data(mc) == 12 - assert links[gcfs[1]][mfs[1]].data(mc) == 12 - assert links[gcfs[2]][mfs[2]].data(mc) == 21 - - -@pytest.mark.skip(reason="To add after refactoring relevant code.") -def test_get_links_gcf_standardised_true(npl, mc, gcfs, spectra, mfs, strains_list): - """Test `get_links` method when input is GCF objects and `standardised` is True.""" - mc.standardised = True - ... - - -def test_get_links_spec_standardised_false(npl, mc, gcfs, spectra, strains_list): - """Test `get_links` method when input is Spectrum objects and `standardised` is False.""" - mc.standardised = False - - mc.cutoff = np.NINF - links = npl.get_links(list(spectra), mc, and_mode=True) - assert isinstance(links, LinkCollection) - links = links.links # dict of link values - assert len(links) == 3 - assert {i.id for i in links.keys()} == {"spectrum1", "spectrum2", "spectrum3"} - assert isinstance(links[spectra[0]][gcfs[0]], ObjectLink) - assert links[spectra[0]][gcfs[0]].data(mc) == 12 - assert links[spectra[0]][gcfs[1]].data(mc) == -9 - assert links[spectra[0]][gcfs[2]].data(mc) == 11 - - mc.cutoff = 0 - links = npl.get_links(list(spectra), mc, and_mode=True) - assert isinstance(links, LinkCollection) - links = links.links # dict of link values - assert len(links) == 3 - assert {i.id for i in links.keys()} == {"spectrum1", "spectrum2", "spectrum3"} - assert isinstance(links[spectra[0]][gcfs[0]], ObjectLink) - assert links[spectra[0]][gcfs[0]].data(mc) == 12 - assert links[spectra[0]].get(gcfs[1]) is None - assert links[spectra[0]][gcfs[2]].data(mc) == 11 - - -@pytest.mark.skip(reason="To add after refactoring relevant code.") -def test_get_links_spec_standardised_true(npl, mc, gcfs, spectra, strains_list): - """Test `get_links` method when input is Spectrum objects and `standardised` is True.""" - mc.standardised = True - ... - - -def test_get_links_mf_standardised_false(npl, mc, gcfs, mfs, strains_list): - """Test `get_links` method when input is MolecularFamily objects and `standardised` is False.""" - mc.standardised = False - - mc.cutoff = np.NINF - links = npl.get_links(list(mfs), mc, and_mode=True) - assert isinstance(links, LinkCollection) - links = links.links - assert len(links) == 3 - assert {i.id for i in links.keys()} == {"mf1", "mf2", "mf3"} - assert isinstance(links[mfs[0]][gcfs[0]], ObjectLink) - assert links[mfs[0]][gcfs[0]].data(mc) == 12 - assert links[mfs[0]][gcfs[1]].data(mc) == -9 - assert links[mfs[0]][gcfs[2]].data(mc) == 11 - - mc.cutoff = 0 - links = npl.get_links(list(mfs), mc, and_mode=True) - assert isinstance(links, LinkCollection) - links = links.links - assert len(links) == 3 - assert {i.id for i in links.keys()} == {"mf1", "mf2", "mf3"} - assert isinstance(links[mfs[0]][gcfs[0]], ObjectLink) - assert links[mfs[0]][gcfs[0]].data(mc) == 12 - assert links[mfs[0]].get(gcfs[1]) is None - assert links[mfs[0]][gcfs[2]].data(mc) == 11 - - -@pytest.mark.skip(reason="To add after refactoring relevant code.") -def test_get_links_mf_standardised_true(npl, mc, gcfs, mfs, strains_list): - """Test `get_links` method when input is MolecularFamily objects and `standardised` is True.""" - mc.standardised = True - ...