From 10e9d873a02c3617a89000298cf2b89e5c0bc310 Mon Sep 17 00:00:00 2001 From: Cunliang Geng Date: Fri, 14 Jun 2024 13:28:44 +0200 Subject: [PATCH] refactor MetcalfScoring methods (#254) * update abc.py to change `npl` to class attribute It might lead to inconsistent `npl` if both `__init__` and `setup` take `npl` as parameter. * remove `calc_score` method This method is just a wrapper on `_calc_raw_score` and `_calc_mean_std`. Better to move its code to `setup` method. * update `setup` to use `cls.npl` - use `cls.npl` - remove setup code from cache * rename objtype to type * change the format of raw score data frame * Update `get_links` to use new format of dataframe * Update metcalf_scoring.py * change LINK_TYPES to enum LinkType Update metcalf_scoring.py * update `_get_links` * add type variable ObjectType * change `obj_type` from a string value to actual data type * remove unused function * clean `get_links` * update docstrings for class attributes * use hashable object itself as index or column of dataframe E.g. use `GCF` instead of `GCF.id` as the index or column of a dataframe. This will skip the steps to look up actual object with a given id. * update the code for calculating standardised score * update unit tests * Apply suggestions from code review Co-authored-by: Giulia Crocioni <55382553+gcroci2@users.noreply.github.com> --------- Co-authored-by: Giulia Crocioni <55382553+gcroci2@users.noreply.github.com> --- src/nplinker/scoring/abc.py | 9 +- src/nplinker/scoring/metcalf_scoring.py | 519 ++++++++------------- src/nplinker/scoring/utils.py | 45 +- tests/unit/scoring/conftest.py | 2 +- tests/unit/scoring/test_metcalf_scoring.py | 395 +++------------- tests/unit/scoring/test_utils.py | 23 +- 6 files changed, 268 insertions(+), 725 deletions(-) diff --git a/src/nplinker/scoring/abc.py b/src/nplinker/scoring/abc.py index a96e858f..fa287190 100644 --- a/src/nplinker/scoring/abc.py +++ b/src/nplinker/scoring/abc.py @@ -21,14 +21,7 @@ class ScoringBase(ABC): """ name: str = "ScoringBase" - - def __init__(self, npl: NPLinker): - """Initialize the scoring method. - - Args: - npl: The NPLinker object. - """ - self.npl = npl + npl: NPLinker | None = None @classmethod @abstractmethod diff --git a/src/nplinker/scoring/metcalf_scoring.py b/src/nplinker/scoring/metcalf_scoring.py index e868e1f7..e1a09016 100644 --- a/src/nplinker/scoring/metcalf_scoring.py +++ b/src/nplinker/scoring/metcalf_scoring.py @@ -1,22 +1,20 @@ from __future__ import annotations import logging -import os +from enum import Enum from typing import TYPE_CHECKING +from typing import TypeVar import numpy as np import pandas as pd from scipy.stats import hypergeom from nplinker.genomics import GCF from nplinker.metabolomics import MolecularFamily from nplinker.metabolomics import Spectrum -from nplinker.pickler import load_pickled_data -from nplinker.pickler import save_pickled_data from .abc import ScoringBase from .link_graph import LinkGraph from .link_graph import Score from .utils import get_presence_gcf_strain from .utils import get_presence_mf_strain from .utils import get_presence_spec_strain -from .utils import isinstance_all if TYPE_CHECKING: @@ -25,7 +23,15 @@ logger = logging.getLogger(__name__) -LINK_TYPES = ["spec-gcf", "mf-gcf"] + +class LinkType(Enum): + """Enum class for link types.""" + + SPEC_GCF = "spec-gcf" + MF_GCF = "mf-gcf" + + +ObjectType = TypeVar("ObjectType", GCF, Spectrum, MolecularFamily) class MetcalfScoring(ScoringBase): @@ -33,173 +39,130 @@ class MetcalfScoring(ScoringBase): Attributes: name: The name of this scoring method, set to a fixed value `metcalf`. + npl: The NPLinker object. CACHE: The name of the cache file to use for storing the MetcalfScoring. + presence_gcf_strain: A DataFrame to store presence of gcfs with respect to strains. + The index of the DataFrame are the GCF objects and the columns are Strain objects. + The values are 1 where the gcf occurs in the strain, 0 otherwise. presence_spec_strain: A DataFrame to store presence of spectra with respect to strains. + The index of the DataFrame are the Spectrum objects and the columns are Strain objects. + The values are 1 where the spectrum occurs in the strain, 0 otherwise. presence_mf_strain: A DataFrame to store presence of molecular families with respect to strains. - raw_score_spec_gcf: The raw Metcalf scores for spectrum-GCF links. - raw_score_mf_gcf: The raw Metcalf scores for molecular family-GCF links. - metcalf_mean: The mean value used for standardising Metcalf scores. - metcalf_std: The standard deviation value used for standardising Metcalf scores. + The index of the DataFrame are the MolecularFamily objects and the columns are Strain objects. + The values are 1 where the molecular family occurs in the strain, 0 otherwise. + + raw_score_spec_gcf: A DataFrame to store the raw Metcalf scores for spectrum-gcf links. + The columns are "spec", "gcf" and "score": + + - The "spec" and "gcf" columns contain the Spectrum and GCF objects respectively, + - The "score" column contains the raw Metcalf scores. + + raw_score_mf_gcf: A DataFrame to store the raw Metcalf scores for molecular family-gcf links. + The columns are "mf", "gcf" and "score": + + - The "mf" and "gcf" columns contain the MolecularFamily and GCF objects respectively, + - the "score" column contains the raw Metcalf scores. + + metcalf_mean: A numpy array to store the mean value used for standardising Metcalf scores. + The array has shape (n_strains+1, n_strains+1), where n_strains is the number of strains. + metcalf_std: A numpy array to store the standard deviation value used for standardising + Metcalf scores. The array has shape (n_strains+1, n_strains+1), where n_strains is the + number of strains. """ name = "metcalf" + npl: NPLinker | None = None CACHE: str = "cache_metcalf_scoring.pckl" + metcalf_weights: tuple[int, int, int, int] = (10, -10, 0, 1) - # DataFrame to store presence of gcfs/spectra/mfs with respect to strains - # values = 1 where gcf/spec/fam occur in strain, 0 otherwise + # index: gcf/spec/mf ids, columns: strain ids, value: 0/1 presence_gcf_strain: pd.DataFrame = pd.DataFrame() presence_spec_strain: pd.DataFrame = pd.DataFrame() presence_mf_strain: pd.DataFrame = pd.DataFrame() - raw_score_spec_gcf: pd.DataFrame = pd.DataFrame() - raw_score_mf_gcf: pd.DataFrame = pd.DataFrame() + raw_score_spec_gcf: pd.DataFrame = pd.DataFrame(columns=["spec", "gcf", "score"]) + raw_score_mf_gcf: pd.DataFrame = pd.DataFrame(columns=["mf", "gcf", "score"]) + metcalf_mean: np.ndarray | None = None metcalf_std: np.ndarray | None = None - def __init__(self, npl: NPLinker) -> None: - """Create a MetcalfScoring object. - - Args: - npl: The NPLinker object to use for scoring. - - Attributes: - cutoff: The cutoff value to use for scoring. Scores below - this value will be discarded. Defaults to 1.0. - standardised: Whether to use standardised scores. Defaults - to True. - """ - super().__init__(npl) - - # TODO CG: refactor this method and extract code for cache file to a separate method @classmethod def setup(cls, npl: NPLinker): """Setup the MetcalfScoring object. This method is only called once to setup the MetcalfScoring object. + + Args: + npl: The NPLinker object. """ + if cls.npl is not None: + logger.info("MetcalfScoring.setup already called, skipping.") + return + logger.info( - "MetcalfScoring.setup (bgcs={}, gcfs={}, spectra={}, molfams={}, strains={})".format( - len(npl.bgcs), len(npl.gcfs), len(npl.spectra), len(npl.molfams), len(npl.strains) - ) + f"MetcalfScoring.setup starts: #bgcs={len(npl.bgcs)}, #gcfs={len(npl.gcfs)}, " + f"#spectra={len(npl.spectra)}, #molfams={len(npl.molfams)}, #strains={npl.strains}" ) + cls.npl = npl - cache_file = npl.output_dir / cls.CACHE - - # the metcalf preprocessing can take a long time for large datasets, so it's - # better to cache as the data won't change unless the number of objects does - dataset_counts = [ - len(npl.bgcs), - len(npl.gcfs), - len(npl.spectra), - len(npl.molfams), - len(npl.strains), - ] - if os.path.exists(cache_file): - logger.info("MetcalfScoring.setup loading cached data") - cache_data = load_pickled_data(npl, cache_file) - cache_ok = True - # TODO: wrap it as a validation method - if cache_data is not None: - (counts, metcalf_mean) = cache_data - # need to invalidate this if dataset appears to have changed - for i in range(len(counts)): - if counts[i] != dataset_counts[i]: - logger.info("MetcalfScoring.setup invalidating cached data!") - cache_ok = False - break - - if cache_ok: - cls.metcalf_mean = metcalf_mean - - if cls.metcalf_mean is None: - logger.info("MetcalfScoring.setup preprocessing dataset (this may take some time)") - cls.presence_gcf_strain = get_presence_gcf_strain(npl.gcfs, npl.strains) - cls.presence_spec_strain = get_presence_spec_strain(npl.spectra, npl.strains) - cls.presence_mf_strain = get_presence_mf_strain(npl.molfams, npl.strains) - cls.calc_score(link_type=LINK_TYPES[0]) - cls.calc_score(link_type=LINK_TYPES[1]) - logger.info("MetcalfScoring.setup caching results") - save_pickled_data((dataset_counts, cls.metcalf_mean), cache_file) - - logger.info("MetcalfScoring.setup completed") - - @classmethod - def calc_score( - cls, - link_type: str = "spec-gcf", - scoring_weights: tuple[int, int, int, int] = (10, -10, 0, 1), - ) -> None: - """Calculate Metcalf scores. - - This method calculates the `raw_score_spec_gcf`, `raw_score_mf_gcf`, `metcalf_mean`, and - `metcalf_std` attributes. + # calculate presence of gcfs/spectra/mfs with respect to strains + cls.presence_gcf_strain = get_presence_gcf_strain(npl.gcfs, npl.strains) + cls.presence_spec_strain = get_presence_spec_strain(npl.spectra, npl.strains) + cls.presence_mf_strain = get_presence_mf_strain(npl.molfams, npl.strains) - Args: - link_type: The type of link to score. Must be 'spec-gcf' or - 'mf-gcf'. Defaults to 'spec-gcf'. - scoring_weights: The weights to - use for Metcalf scoring. The weights are applied to - '(met_gcf, met_not_gcf, gcf_not_met, not_met_not_gcf)'. - Defaults to (10, -10, 0, 1). + # calculate raw Metcalf scores for spec-gcf links + raw_score_spec_gcf = cls._calc_raw_score( + cls.presence_spec_strain, cls.presence_gcf_strain, cls.metcalf_weights + ) + cls.raw_score_spec_gcf = raw_score_spec_gcf.reset_index().melt(id_vars="index") + cls.raw_score_spec_gcf.columns = ["spec", "gcf", "score"] - Raises: - ValueError: If an invalid link type is provided. - """ - if link_type not in LINK_TYPES: - raise ValueError(f"Invalid link type: {link_type}. Must be one of {LINK_TYPES}") + # calculate raw Metcalf scores for spec-gcf links + raw_score_mf_gcf = cls._calc_raw_score( + cls.presence_mf_strain, cls.presence_gcf_strain, cls.metcalf_weights + ) + cls.raw_score_mf_gcf = raw_score_mf_gcf.reset_index().melt(id_vars="index") + cls.raw_score_mf_gcf.columns = ["mf", "gcf", "score"] - if link_type == "spec-gcf": - logger.info("Create correlation matrices: spectra<->gcfs.") - cls.raw_score_spec_gcf = cls._calc_raw_score( - cls.presence_spec_strain, cls.presence_gcf_strain, scoring_weights - ) - if link_type == "mf-gcf": - logger.info("Create correlation matrices: mol-families<->gcfs.") - cls.raw_score_mf_gcf = cls._calc_raw_score( - cls.presence_mf_strain, cls.presence_gcf_strain, scoring_weights - ) + # calculate mean and std for standardising Metcalf scores + cls.metcalf_mean, cls.metcalf_std = cls._calc_mean_std( + len(npl.strains), cls.metcalf_weights + ) - if cls.metcalf_mean is None or cls.metcalf_std is None: - n_strains = cls.presence_gcf_strain.shape[1] - cls.metcalf_mean, cls.metcalf_std = cls._calc_mean_std(n_strains, scoring_weights) + logger.info("MetcalfScoring.setup completed") - def get_links(self, *objects: GCF | Spectrum | MolecularFamily, **parameters) -> LinkGraph: + def get_links(self, *objects: ObjectType, **parameters) -> LinkGraph: """Get links for the given objects. - The given objects are treated as input or source objects, which must be GCF, Spectrum or - MolecularFamily objects. - Args: - objects: The objects to get links for. Must be GCF, Spectrum or MolecularFamily objects. + objects: The objects to get links for. All objects must be of the same type, i.e. `GCF`, + `Spectrum` or `MolecularFamily` type. + If no objects are provided, all detected objects (`npl.gcfs`) will be used. parameters: The scoring parameters to use for the links. The parameters are: - cutoff: The minimum score to consider a link (≥cutoff). Default is 0. - standardised: Whether to use standardised scores. Default is False. Returns: - The LinkGraph object containing the links involving the input objects. + The `LinkGraph` object containing the links involving the input objects with the Metcalf + scores. Raises: - ValueError: If the input objects are empty. - TypeError: If the input objects are not of the correct type. + TypeError: If the input objects are not of the same type or the object type is invalid. """ # validate input objects - # if the input objects are empty, use all objects if len(objects) == 0: objects = self.npl.gcfs - - # TODO: allow mixed input types? - if isinstance_all(*objects, objtype=GCF): - obj_type = "gcf" - elif isinstance_all(*objects, objtype=Spectrum): - obj_type = "spec" - elif isinstance_all(*objects, objtype=MolecularFamily): - obj_type = "mf" - else: - types = [type(i) for i in objects] + # check if all objects are of the same type + types = {type(i) for i in objects} + if len(types) > 1: + raise TypeError("Input objects must be of the same type.") + # check if the object type is valid + obj_type = next(iter(types)) + if obj_type not in (GCF, Spectrum, MolecularFamily): raise TypeError( - f"Invalid type {set(types)}. Input objects must be GCF, Spectrum or MolecularFamily objects." + f"Invalid type {obj_type}. Input objects must be GCF, Spectrum or MolecularFamily objects." ) # validate scoring parameters @@ -207,68 +170,32 @@ def get_links(self, *objects: GCF | Spectrum | MolecularFamily, **parameters) -> self._standardised: bool = parameters.get("standardised", False) parameters.update({"cutoff": self._cutoff, "standardised": self._standardised}) - logger.info(f"MetcalfScoring: standardised = {self._standardised}") + logger.info( + f"MetcalfScoring: #objects={len(objects)}, type={obj_type}, cutoff={self._cutoff}, " + f"standardised={self._standardised}" + ) if not self._standardised: - scores_list = self._get_links(*objects, score_cutoff=self._cutoff) - # TODO CG: verify the logics of standardised score and add unit tests + scores_list = self._get_links(*objects, obj_type=obj_type, score_cutoff=self._cutoff) else: + if self.metcalf_mean is None or self.metcalf_std is None: + raise ValueError( + "MetcalfScoring.metcalf_mean and metcalf_std are not set. Run MetcalfScoring.setup first." + ) # use negative infinity as the score cutoff to ensure we get all links - # the self.cutoff will be applied later in the postprocessing step - scores_list = self._get_links(*objects, score_cutoff=np.NINF) - if obj_type == "gcf": - scores_list = self._calc_standardised_score_gen(scores_list) - else: - scores_list = self._calc_standardised_score_met(scores_list) + scores_list = self._get_links(*objects, obj_type=obj_type, score_cutoff=np.NINF) + scores_list = self._calc_standardised_score(scores_list) links = LinkGraph() - if obj_type == "gcf": - logger.info( - f"MetcalfScoring: input_type=GCF, result_type=Spec/MolFam, " - f"#inputs={len(objects)}." - ) - # scores is the DataFrame with index "source", "target", "score" - for scores in scores_list: - # when no links found - if scores.shape[1] == 0: - logger.info(f'MetcalfScoring: found no "{scores.name}" links') - else: - # when links found - for col_index in range(scores.shape[1]): - gcf = self.npl.lookup_gcf(scores.loc["source", col_index]) - if scores.name == LINK_TYPES[0]: - met = self.npl.lookup_spectrum(scores.loc["target", col_index]) - else: - met = self.npl.lookup_mf(scores.loc["target", col_index]) - links.add_link( - gcf, - met, - metcalf=Score(self.name, scores.loc["score", col_index], parameters), - ) - logger.info(f"MetcalfScoring: found {len(links)} {scores.name} links.") - else: - logger.info( - f"MetcalfScoring: input_type=Spec/MolFam, result_type=GCF, " - f"#inputs={len(objects)}." - ) - scores = scores_list[0] - # when no links found - if scores.shape[1] == 0: - logger.info(f'MetcalfScoring: found no links "{scores.name}" for input objects') - else: - for col_index in range(scores.shape[1]): - gcf = self.npl.lookup_gcf(scores.loc["target", col_index]) - if scores.name == LINK_TYPES[0]: - met = self.npl.lookup_spectrum(scores.loc["source", col_index]) - else: - met = self.npl.lookup_mf(scores.loc["source", col_index]) - links.add_link( - met, - gcf, - metcalf=Score(self.name, scores.loc["score", col_index], parameters), - ) - logger.info(f"MetcalfScoring: found {len(links)} {scores.name} links.") - - logger.info("MetcalfScoring: completed") + for score_df in scores_list: + for row in score_df.itertuples(index=False): # row has attributes: spec/mf, gcf, score + met = row.spec if score_df.name == LinkType.SPEC_GCF else row.mf + links.add_link( + row.gcf, + met, + metcalf=Score(self.name, row.score, parameters), + ) + + logger.info(f"MetcalfScoring: completed! Found {len(links.links)} links in total.") return links # TODO CG: refactor this method @@ -318,8 +245,18 @@ def _calc_raw_score( @staticmethod def _calc_mean_std( - n_strains: int, scoring_weights: tuple[int, int, int, int] + n_strains: int, weights: tuple[int, int, int, int] ) -> tuple[np.ndarray, np.ndarray]: + """Calculate the mean and standard deviation for Metcalf scoring. + + Args: + n_strains: The number of strains. + weights: The weights to use for Metcalf scoring. + + Returns: + Two numpy arrays containing the mean and standard deviation values for Metcalf scoring. + The arrays have shape (n_strains+1, n_strains+1). + """ sz = (n_strains + 1, n_strains + 1) mean = np.zeros(sz) variance = np.zeros(sz) @@ -332,10 +269,10 @@ def _calc_mean_std( for o in range(min_overlap, max_overlap + 1): o_prob = hypergeom.pmf(o, n_strains, n, m) # compute metcalf for n strains in type 1 and m in gcf - score = o * scoring_weights[0] - score += scoring_weights[1] * (n - o) - score += scoring_weights[2] * (m - o) - score += scoring_weights[3] * (n_strains - (n + m - o)) + score = o * weights[0] + score += weights[1] * (n - o) + score += weights[2] * (m - o) + score += weights[3] * (n_strains - (n + m - o)) expected_value += o_prob * score expected_sq += o_prob * (score**2) mean[n, m] = expected_value @@ -347,165 +284,89 @@ def _calc_mean_std( def _get_links( self, - *objects: tuple[GCF, ...] | tuple[Spectrum, ...] | tuple[MolecularFamily, ...], + *objects: ObjectType, + obj_type: GCF | Spectrum | MolecularFamily, score_cutoff: float = 0, ) -> list[pd.DataFrame]: - """Get links and scores for given objects. + """Get links and scores for the given objects. Args: - objects: A list of GCF, Spectrum or MolecularFamily objects - and all objects must be of the same type. + objects: A list of GCF, Spectrum or MolecularFamily objects and all objects must be of + the same type. + obj_type: The type of the objects. score_cutoff: Minimum score to consider a link (≥score_cutoff). Default is 0. Returns: - List of data frames containing the ids of the linked objects - and the score. The data frame has index names of - 'source', 'target' and 'score': + List of data frames containing the ids of the linked objects and the score. - - the 'source' row contains the ids of the input/source objects, - - the 'target' row contains the ids of the target objects, - - the 'score' row contains the scores. + The data frame is named by link types, see `LinkType`. It has column names of + ['spec', 'gcf', 'score'] or ['mf', 'gcf', 'score'] depending on the link type: - Raises: - ValueError: If input objects are empty. - TypeError: If input objects are not GCF, Spectrum or MolecularFamily objects. + - the 'spec', 'mf' or 'gcf' column contains the Spectrum, MolecularFamily or GCF objects, + - the 'score' column contains the scores. """ - if len(objects) == 0: - raise ValueError("Empty input objects.") - - if isinstance_all(*objects, objtype=GCF): - obj_type = "gcf" - elif isinstance_all(*objects, objtype=Spectrum): - obj_type = "spec" - elif isinstance_all(*objects, objtype=MolecularFamily): - obj_type = "mf" - else: - types = [type(i) for i in objects] - raise TypeError( - f"Invalid type {set(types)}. Input objects must be GCF, Spectrum or MolecularFamily objects." - ) - links = [] - if obj_type == "gcf": - obj_ids = [gcf.id for gcf in objects] - # spec-gcf - scores = self.raw_score_spec_gcf.loc[:, obj_ids] - df = self._get_scores_source_gcf(scores, score_cutoff) - df.name = LINK_TYPES[0] - links.append(df) - # mf-gcf - scores = self.raw_score_mf_gcf.loc[:, obj_ids] - df = self._get_scores_source_gcf(scores, score_cutoff) - df.name = LINK_TYPES[1] - links.append(df) + col_name = ( + "spec" if obj_type == Spectrum else "mf" if obj_type == MolecularFamily else "gcf" + ) - if obj_type == "spec": - obj_ids = [spec.id for spec in objects] - scores = self.raw_score_spec_gcf.loc[obj_ids, :] - df = self._get_scores_source_met(scores, score_cutoff) - df.name = LINK_TYPES[0] + # spec-gcf link + if obj_type in (GCF, Spectrum): + df = self.raw_score_spec_gcf[ + self.raw_score_spec_gcf[col_name].isin(objects) + & (self.raw_score_spec_gcf["score"] >= score_cutoff) + ] + df.name = LinkType.SPEC_GCF links.append(df) - if obj_type == "mf": - obj_ids = [mf.id for mf in objects] - scores = self.raw_score_mf_gcf.loc[obj_ids, :] - df = self._get_scores_source_met(scores, score_cutoff) - df.name = LINK_TYPES[1] + # mf-gcf link + if obj_type in (GCF, MolecularFamily): + df = self.raw_score_mf_gcf[ + self.raw_score_mf_gcf[col_name].isin(objects) + & (self.raw_score_mf_gcf["score"] >= score_cutoff) + ] + df.name = LinkType.MF_GCF links.append(df) - return links - @staticmethod - def _get_scores_source_gcf(scores: pd.DataFrame, score_cutoff: float) -> pd.DataFrame: - row_indexes, col_indexes = np.where(scores >= score_cutoff) - src_obj_ids = scores.columns[col_indexes].to_list() - target_obj_ids = scores.index[row_indexes].to_list() - scores_candidate = scores.values[row_indexes, col_indexes].tolist() - return pd.DataFrame( - [src_obj_ids, target_obj_ids, scores_candidate], index=["source", "target", "score"] - ) + return links - @staticmethod - def _get_scores_source_met(scores: pd.DataFrame, score_cutoff: float) -> pd.DataFrame: - row_indexes, col_indexes = np.where(scores >= score_cutoff) - src_obj_ids = scores.index[row_indexes].to_list() - target_obj_ids = scores.columns[col_indexes].to_list() - scores_candidate = scores.values[row_indexes, col_indexes].tolist() - return pd.DataFrame( - [src_obj_ids, target_obj_ids, scores_candidate], index=["source", "target", "score"] - ) + def _calc_standardised_score(self, raw_scores: list[pd.DataFrame]) -> list[pd.DataFrame]: + """Calculate standardised Metcalf scores. - def _calc_standardised_score_met(self, results: list) -> list[pd.DataFrame]: - if self.metcalf_mean is None or self.metcalf_std is None: - raise ValueError( - "Metcalf mean and std not found. Have you called `MetcalfScoring.setup(npl)`?" - ) - logger.info("Calculating standardised Metcalf scores (met input)") - raw_score = results[0] - z_scores = [] - for col_index in range(raw_score.shape[1]): - gcf = self.npl.lookup_gcf(raw_score.loc["target", col_index]) - if raw_score.name == LINK_TYPES[0]: - met = self.npl.lookup_spectrum(raw_score.at["source", col_index]) - else: - met = self.npl.lookup_mf(raw_score.at["source", col_index]) - - num_gcf_strains = len(gcf.strains) - num_met_strains = len(met.strains) - mean = self.metcalf_mean[num_met_strains][num_gcf_strains] - sqrt = self.metcalf_std[num_met_strains][num_gcf_strains] - z_score = (raw_score.at["score", col_index] - mean) / sqrt - z_scores.append(z_score) - - z_scores = np.array(z_scores) - mask = z_scores >= self._cutoff - - scores_df = pd.DataFrame( - [ - raw_score.loc["source"].values[mask], - raw_score.loc["target"].values[mask], - z_scores[mask], - ], - index=raw_score.index, - ) - scores_df.name = raw_score.name - - return [scores_df] + Args: + raw_scores: A list of DataFrames containing the raw Metcalf scores. - def _calc_standardised_score_gen(self, results: list) -> list[pd.DataFrame]: - if self.metcalf_mean is None or self.metcalf_std is None: - raise ValueError( - "Metcalf mean and std not found. Have you called `MetcalfScoring.setup(npl)`?" - ) - logger.info("Calculating standardised Metcalf scores (gen input)") - postprocessed_scores = [] - for raw_score in results: - z_scores = [] - for col_index in range(raw_score.shape[1]): - gcf = self.npl.lookup_gcf(raw_score.loc["source", col_index]) - if raw_score.name == LINK_TYPES[0]: - met = self.npl.lookup_spectrum(raw_score.at["target", col_index]) + Returns: + A list of DataFrames containing the standardised Metcalf scores. + """ + standardised_scores = [] + for raw_score_df in raw_scores: + # create a new DataFrame to store the standardised scores, with the same columns + # and name as the raw score DataFrame + standardised_score_df = pd.DataFrame(columns=raw_score_df.columns) + + for row in raw_score_df.itertuples(index=False): + met = row.spec if raw_score_df.name == LinkType.SPEC_GCF else row.mf + n_gcf_strains = len(row.gcf.strains) + n_met_strains = len(met.strains) + + mean = self.metcalf_mean[n_met_strains][n_gcf_strains] + sqrt = self.metcalf_std[n_met_strains][n_gcf_strains] + + z_score = (row.score - mean) / sqrt + + if z_score >= self._cutoff: + # add the row to the standardised score DataFrame with the z-score as score value + data = {col_name: getattr(row, col_name) for col_name in raw_score_df.columns} + data["score"] = z_score + new_row = pd.DataFrame(data, index=[0]) + standardised_score_df = pd.concat( + (standardised_score_df, new_row), ignore_index=True + ) else: - met = self.npl.lookup_mf(raw_score.at["target", col_index]) - - num_gcf_strains = len(gcf.strains) - num_met_strains = len(met.strains) - mean = self.metcalf_mean[num_met_strains][num_gcf_strains] - sqrt = self.metcalf_std[num_met_strains][num_gcf_strains] - z_score = (raw_score.at["score", col_index] - mean) / sqrt - z_scores.append(z_score) - - z_scores = np.array(z_scores) - mask = z_scores >= self._cutoff - - scores_df = pd.DataFrame( - [ - raw_score.loc["source"].values[mask], - raw_score.loc["target"].values[mask], - z_scores[mask], - ], - index=raw_score.index, - ) - scores_df.name = raw_score.name - postprocessed_scores.append(scores_df) + continue + + standardised_score_df.name = raw_score_df.name + standardised_scores.append(standardised_score_df) - return postprocessed_scores + return standardised_scores diff --git a/src/nplinker/scoring/utils.py b/src/nplinker/scoring/utils.py index a93ed24f..418fc1ec 100644 --- a/src/nplinker/scoring/utils.py +++ b/src/nplinker/scoring/utils.py @@ -12,71 +12,62 @@ from nplinker.strain import StrainCollection -def isinstance_all(*objects, objtype) -> bool: - """Check if all objects are of the given type.""" - return all(isinstance(x, objtype) for x in objects) - - def get_presence_gcf_strain(gcfs: Sequence[GCF], strains: StrainCollection) -> pd.DataFrame: - """Get the occurence of strains in gcfs. + """Get the occurrence of strains in gcfs. - The occurence is a DataFrame with gcfs as rows and strains as columns, - where index is `gcf.id` and column name is `strain.id`. The values - are 1 if the gcf contains the strain and 0 otherwise. + The occurrence is a DataFrame with GCF objects as index and Strain objects as columns, and the + values are 1 if the gcf occurs in the strain, 0 otherwise. """ df_gcf_strain = pd.DataFrame( np.zeros((len(gcfs), len(strains))), - index=[gcf.id for gcf in gcfs], - columns=[strain.id for strain in strains], + index=gcfs, + columns=list(strains), dtype=int, ) for gcf in gcfs: for strain in strains: if gcf.has_strain(strain): - df_gcf_strain.loc[gcf.id, strain.id] = 1 + df_gcf_strain.loc[gcf, strain] = 1 return df_gcf_strain def get_presence_spec_strain( spectra: Sequence[Spectrum], strains: StrainCollection ) -> pd.DataFrame: - """Get the occurence of strains in spectra. + """Get the occurrence of strains in spectra. - The occurence is a DataFrame with spectra as rows and strains as columns, - where index is `spectrum.id` and column name is `strain.id`. - The values are 1 if the spectrum contains the strain and 0 otherwise. + The occurrence is a DataFrame with Spectrum objects as index and Strain objects as columns, and + the values are 1 if the spectrum occurs in the strain, 0 otherwise. """ df_spec_strain = pd.DataFrame( np.zeros((len(spectra), len(strains))), - index=[spectrum.id for spectrum in spectra], - columns=[strain.id for strain in strains], + index=spectra, + columns=list(strains), dtype=int, ) for spectrum in spectra: for strain in strains: if spectrum.has_strain(strain): - df_spec_strain.loc[spectrum.id, strain.id] = 1 + df_spec_strain.loc[spectrum, strain] = 1 return df_spec_strain def get_presence_mf_strain( mfs: Sequence[MolecularFamily], strains: StrainCollection ) -> pd.DataFrame: - """Get the occurence of strains in molecular families. + """Get the occurrence of strains in molecular families. - The occurence is a DataFrame with molecular families as rows and - strains as columns, where index is `mf.id` and column name is - `strain.id`. The values are 1 if the molecular family contains the - strain and 0 otherwise. + The occurrence is a DataFrame with MolecularFamily objects as index and Strain objects as + columns, and the values are 1 if the molecular family occurs in the strain, 0 otherwise. """ df_mf_strain = pd.DataFrame( np.zeros((len(mfs), len(strains))), - index=[mf.id for mf in mfs], - columns=[strain.id for strain in strains], + index=mfs, + columns=list(strains), dtype=int, ) for mf in mfs: for strain in strains: if mf.has_strain(strain): - df_mf_strain.loc[mf.id, strain.id] = 1 + df_mf_strain.loc[mf, strain] = 1 return df_mf_strain diff --git a/tests/unit/scoring/conftest.py b/tests/unit/scoring/conftest.py index cb750407..a39381ee 100644 --- a/tests/unit/scoring/conftest.py +++ b/tests/unit/scoring/conftest.py @@ -86,6 +86,6 @@ def npl(gcfs, spectra, mfs, strains, tmp_path) -> NPLinker: @fixture(scope="function") def mc(npl) -> MetcalfScoring: """MetcalfScoring object.""" - mc = MetcalfScoring(npl) + mc = MetcalfScoring() mc.setup(npl) return mc diff --git a/tests/unit/scoring/test_metcalf_scoring.py b/tests/unit/scoring/test_metcalf_scoring.py index 4e6e5651..c9089fcc 100644 --- a/tests/unit/scoring/test_metcalf_scoring.py +++ b/tests/unit/scoring/test_metcalf_scoring.py @@ -6,14 +6,15 @@ def test_init(npl): - mc = MetcalfScoring(npl) - assert mc.npl == npl + mc = MetcalfScoring() assert mc.name == "metcalf" + assert mc.npl is None + assert mc.metcalf_weights == (10, -10, 0, 1) assert_frame_equal(mc.presence_gcf_strain, pd.DataFrame()) assert_frame_equal(mc.presence_spec_strain, pd.DataFrame()) assert_frame_equal(mc.presence_mf_strain, pd.DataFrame()) - assert_frame_equal(mc.raw_score_spec_gcf, pd.DataFrame()) - assert_frame_equal(mc.raw_score_mf_gcf, pd.DataFrame()) + assert_frame_equal(mc.raw_score_spec_gcf, pd.DataFrame(columns=["spec", "gcf", "score"])) + assert_frame_equal(mc.raw_score_mf_gcf, pd.DataFrame(columns=["mf", "gcf", "score"])) assert mc.metcalf_mean is None assert mc.metcalf_std is None @@ -23,140 +24,43 @@ def test_init(npl): # -def test_setup(mc): +def test_setup(mc, gcfs, spectra, mfs, strains): """Test `setup` method when cache file does not exist.""" - col_names = ["strain1", "strain2", "strain3"] assert_frame_equal( mc.presence_gcf_strain, - pd.DataFrame( - [[1, 0, 0], [0, 1, 0], [1, 1, 0]], index=["gcf1", "gcf2", "gcf3"], columns=col_names - ), + pd.DataFrame([[1, 0, 0], [0, 1, 0], [1, 1, 0]], index=gcfs, columns=list(strains)), ) assert_frame_equal( mc.presence_spec_strain, pd.DataFrame( [[1, 0, 0], [0, 1, 0], [1, 1, 0]], - index=["spectrum1", "spectrum2", "spectrum3"], - columns=col_names, + index=spectra, + columns=list(strains), ), ) assert_frame_equal( mc.presence_mf_strain, - pd.DataFrame( - [[1, 0, 0], [0, 1, 0], [1, 1, 0]], index=["mf1", "mf2", "mf3"], columns=col_names - ), - ) - - assert_frame_equal( - mc.raw_score_spec_gcf, - pd.DataFrame( - [[12, -9, 11], [-9, 12, 11], [1, 1, 21]], - index=["spectrum1", "spectrum2", "spectrum3"], - columns=["gcf1", "gcf2", "gcf3"], - ), - ) - assert_frame_equal( - mc.raw_score_mf_gcf, - pd.DataFrame( - [[12, -9, 11], [-9, 12, 11], [1, 1, 21]], - index=["mf1", "mf2", "mf3"], - columns=["gcf1", "gcf2", "gcf3"], - ), - ) - - assert isinstance(mc.metcalf_mean, np.ndarray) - assert isinstance(mc.metcalf_std, np.ndarray) - assert mc.metcalf_mean.shape == (4, 4) # (n_strains+1 , n_strains+1) - assert mc.metcalf_std.shape == (4, 4) - - -def test_setup_load_cache(mc, npl): - """Test `setup` method when cache file exists.""" - mc.setup(npl) - - col_names = ["strain1", "strain2", "strain3"] - assert_frame_equal( - mc.presence_gcf_strain, - pd.DataFrame( - [[1, 0, 0], [0, 1, 0], [1, 1, 0]], index=["gcf1", "gcf2", "gcf3"], columns=col_names - ), - ) - assert_frame_equal( - mc.presence_spec_strain, pd.DataFrame( [[1, 0, 0], [0, 1, 0], [1, 1, 0]], - index=["spectrum1", "spectrum2", "spectrum3"], - columns=col_names, - ), - ) - assert_frame_equal( - mc.presence_mf_strain, - pd.DataFrame( - [[1, 0, 0], [0, 1, 0], [1, 1, 0]], index=["mf1", "mf2", "mf3"], columns=col_names - ), - ) - - assert_frame_equal( - mc.raw_score_spec_gcf, - pd.DataFrame( - [[12, -9, 11], [-9, 12, 11], [1, 1, 21]], - index=["spectrum1", "spectrum2", "spectrum3"], - columns=["gcf1", "gcf2", "gcf3"], - ), - ) - assert_frame_equal( - mc.raw_score_mf_gcf, - pd.DataFrame( - [[12, -9, 11], [-9, 12, 11], [1, 1, 21]], - index=["mf1", "mf2", "mf3"], - columns=["gcf1", "gcf2", "gcf3"], + index=mfs, + columns=list(strains), ), ) - assert isinstance(mc.metcalf_mean, np.ndarray) - assert isinstance(mc.metcalf_std, np.ndarray) - assert mc.metcalf_mean.shape == (4, 4) # (n_strains+1 , n_strains+1) - assert mc.metcalf_std.shape == (4, 4) + df = pd.DataFrame([[12, -9, 11], [-9, 12, 11], [1, 1, 21]], index=spectra, columns=gcfs) + df_melted = df.reset_index().melt(id_vars="index") + df_melted.columns = ["spec", "gcf", "score"] + assert_frame_equal(mc.raw_score_spec_gcf, df_melted) + df = pd.DataFrame([[12, -9, 11], [-9, 12, 11], [1, 1, 21]], index=mfs, columns=gcfs) + df_melted = df.reset_index().melt(id_vars="index") + df_melted.columns = ["mf", "gcf", "score"] + assert_frame_equal(mc.raw_score_mf_gcf, df_melted) -# -# Test the `calc_score` method -# - - -def test_calc_score_raw_score(mc): - """Test `calc_score` method for `raw_score_spec_gcf` and `raw_score_mf_gcf`.""" - # link type = 'spec-gcf' - mc.calc_score(link_type="spec-gcf") - assert_frame_equal( - mc.raw_score_spec_gcf, - pd.DataFrame( - [[12, -9, 11], [-9, 12, 11], [1, 1, 21]], - index=["spectrum1", "spectrum2", "spectrum3"], - columns=["gcf1", "gcf2", "gcf3"], - ), - ) - # link type = 'mf-gcf' - mc.calc_score(link_type="mf-gcf") - assert_frame_equal( - mc.raw_score_mf_gcf, - pd.DataFrame( - [[12, -9, 11], [-9, 12, 11], [1, 1, 21]], - index=["mf1", "mf2", "mf3"], - columns=["gcf1", "gcf2", "gcf3"], - ), - ) - - -def test_calc_score_mean_std(mc): - """Test `calc_score` method for `metcalf_mean` and `metcalf_std`.""" - mc.calc_score(link_type="spec-gcf") assert isinstance(mc.metcalf_mean, np.ndarray) assert isinstance(mc.metcalf_std, np.ndarray) assert mc.metcalf_mean.shape == (4, 4) # (n_strains+1 , n_strains+1) assert mc.metcalf_std.shape == (4, 4) - # TODO CG: add tests for values after refactoring _calc_mean_std method - # assert mc.metcalf_mean == expected_array # @@ -165,6 +69,7 @@ def test_calc_score_mean_std(mc): def test_get_links_default(mc, gcfs, spectra, mfs): + # same as cutoff=0, standardised=False lg = mc.get_links() assert lg[gcfs[0]][spectra[0]][mc.name].value == 12 assert lg[gcfs[1]].get(spectra[0]) is None @@ -174,6 +79,25 @@ def test_get_links_default(mc, gcfs, spectra, mfs): assert lg[gcfs[2]][mfs[2]][mc.name].value == 21 +@pytest.mark.parametrize( + "objects, expected", + [ + ([1], "Invalid type . .*"), + ([1, 2], "Invalid type . .*"), + ("12", "Invalid type . .*"), + ], +) +def test_get_links_invalid_input_type(mc, objects, expected): + with pytest.raises(TypeError, match=expected): + mc.get_links(*objects) + + +def test_get_links_invalid_mixed_types(mc, spectra, mfs): + objects = (*spectra, *mfs) + with pytest.raises(TypeError, match="Input objects must be of the same type."): + mc.get_links(*objects) + + def test_get_links_gcf_standardised_false(mc, gcfs, spectra, mfs): """Test `get_links` method when input is GCF objects and `standardised` is False.""" # when cutoff is negative infinity, i.e. taking all scores @@ -195,10 +119,13 @@ def test_get_links_gcf_standardised_false(mc, gcfs, spectra, mfs): assert lg[gcfs[2]][mfs[2]][mc.name].value == 21 -@pytest.mark.skip(reason="To add after refactoring relevant code.") -def test_get_links_gcf_standardised_true(mc, gcfs, spectra, mfs): +def test_get_links_gcf_standardised_true(mc, gcfs): """Test `get_links` method when input is GCF objects and `standardised` is True.""" - ... + lg = mc.get_links(*gcfs, cutoff=np.NINF, standardised=True) + assert len(lg.links) == 18 + + lg = mc.get_links(*gcfs, cutoff=0, standardised=True) + assert len(lg.links) == 14 def test_get_links_spec_standardised_false(mc, gcfs, spectra): @@ -214,10 +141,13 @@ def test_get_links_spec_standardised_false(mc, gcfs, spectra): assert lg[spectra[0]][gcfs[2]][mc.name].value == 11 -@pytest.mark.skip(reason="To add after refactoring relevant code.") def test_get_links_spec_standardised_true(mc, gcfs, spectra): """Test `get_links` method when input is Spectrum objects and `standardised` is True.""" - ... + lg = mc.get_links(*spectra, cutoff=np.NINF, standardised=True) + assert len(lg.links) == 9 + + lg = mc.get_links(*spectra, cutoff=0, standardised=True) + assert len(lg.links) == 7 def test_get_links_mf_standardised_false(mc, gcfs, mfs): @@ -233,227 +163,10 @@ def test_get_links_mf_standardised_false(mc, gcfs, mfs): assert lg[mfs[0]][gcfs[2]][mc.name].value == 11 -@pytest.mark.skip(reason="To add after refactoring relevant code.") def test_get_links_mf_standardised_true(mc, gcfs, mfs): """Test `get_links` method when input is MolecularFamily objects and `standardised` is True.""" - ... - - -@pytest.mark.parametrize( - "objects, expected", - [ - ([1], "Invalid type {}"), - ([1, 2], "Invalid type {}"), - ("12", "Invalid type {}"), - ], -) -def test_get_links_invalid_input_type(mc, objects, expected): - with pytest.raises(TypeError) as e: - mc.get_links(*objects) - assert expected in str(e.value) - - -def test_get_links_invalid_mixed_types(mc, spectra, mfs): - objects = (*spectra, *mfs) - with pytest.raises(TypeError) as e: - mc.get_links(*objects) - assert "Invalid type" in str(e.value) - assert ".MolecularFamily" in str(e.value) - assert ".Spectrum" in str(e.value) - - -# -# Test the `_get_links` method -# - - -def test__get_links_gcf(mc, gcfs): - """Test `get_links` method for input GCF objects.""" - mc.calc_score(link_type="spec-gcf") - mc.calc_score(link_type="mf-gcf") - index_names = ["source", "target", "score"] - - # cutoff = negative infinity (float) - links = mc._get_links(*gcfs, score_cutoff=np.NINF) - assert len(links) == 2 - # expected values got from `test_calc_score_raw_score` - assert_frame_equal( - links[0], - pd.DataFrame( - [ - ["gcf1", "gcf2", "gcf3"] * 3, - [ - *["spectrum1"] * 3, - *["spectrum2"] * 3, - *["spectrum3"] * 3, - ], - [12, -9, 11, -9, 12, 11, 1, 1, 21], - ], - index=index_names, - ), - ) - assert_frame_equal( - links[1], - pd.DataFrame( - [ - ["gcf1", "gcf2", "gcf3"] * 3, - [ - *["mf1"] * 3, - *["mf2"] * 3, - *["mf3"] * 3, - ], - [12, -9, 11, -9, 12, 11, 1, 1, 21], - ], - index=index_names, - ), - ) - - # cutoff = 0 - links = mc._get_links(*gcfs, score_cutoff=0) - assert len(links) == 2 - assert_frame_equal( - links[0], - pd.DataFrame( - [ - ["gcf1", "gcf3", "gcf2", "gcf3", "gcf1", "gcf2", "gcf3"], - [ - *["spectrum1"] * 2, - *["spectrum2"] * 2, - *["spectrum3"] * 3, - ], - [12, 11, 12, 11, 1, 1, 21], - ], - index=index_names, - ), - ) - assert_frame_equal( - links[1], - pd.DataFrame( - [ - ["gcf1", "gcf3", "gcf2", "gcf3", "gcf1", "gcf2", "gcf3"], - [ - *["mf1"] * 2, - *["mf2"] * 2, - *["mf3"] * 3, - ], - [12, 11, 12, 11, 1, 1, 21], - ], - index=index_names, - ), - ) - + lg = mc.get_links(*mfs, cutoff=np.NINF, standardised=True) + assert len(lg.links) == 9 -def test__get_links_spec(mc, spectra): - """Test `get_links` method for input Spectrum objects.""" - mc.calc_score(link_type="spec-gcf") - mc.calc_score(link_type="mf-gcf") - index_names = ["source", "target", "score"] - # cutoff = negative infinity (float) - links = mc._get_links(*spectra, score_cutoff=np.NINF) - assert len(links) == 1 - assert_frame_equal( - links[0], - pd.DataFrame( - [ - [ - *["spectrum1"] * 3, - *["spectrum2"] * 3, - *["spectrum3"] * 3, - ], - ["gcf1", "gcf2", "gcf3"] * 3, - [12, -9, 11, -9, 12, 11, 1, 1, 21], - ], - index=index_names, - ), - ) - # cutoff = 0 - links = mc._get_links(*spectra, score_cutoff=0) - assert_frame_equal( - links[0], - pd.DataFrame( - [ - [ - *["spectrum1"] * 2, - *["spectrum2"] * 2, - *["spectrum3"] * 3, - ], - ["gcf1", "gcf3", "gcf2", "gcf3", "gcf1", "gcf2", "gcf3"], - [12, 11, 12, 11, 1, 1, 21], - ], - index=index_names, - ), - ) - - -def test__get_links_mf(mc, mfs): - """Test `get_links` method for input MolecularFamily objects.""" - mc.calc_score(link_type="spec-gcf") - mc.calc_score(link_type="mf-gcf") - index_names = ["source", "target", "score"] - # cutoff = negative infinity (float) - links = mc._get_links(*mfs, score_cutoff=np.NINF) - assert len(links) == 1 - assert_frame_equal( - links[0], - pd.DataFrame( - [ - [ - *["mf1"] * 3, - *["mf2"] * 3, - *["mf3"] * 3, - ], - ["gcf1", "gcf2", "gcf3"] * 3, - [12, -9, 11, -9, 12, 11, 1, 1, 21], - ], - index=index_names, - ), - ) - # cutoff = 0 - links = mc._get_links(*mfs, score_cutoff=0) - assert_frame_equal( - links[0], - pd.DataFrame( - [ - [ - *["mf1"] * 2, - *["mf2"] * 2, - *["mf3"] * 3, - ], - ["gcf1", "gcf3", "gcf2", "gcf3", "gcf1", "gcf2", "gcf3"], - [12, 11, 12, 11, 1, 1, 21], - ], - index=index_names, - ), - ) - - -@pytest.mark.parametrize( - "objects, expected", [([], "Empty input objects"), ("", "Empty input objects")] -) -def test_get_links_invalid_value(mc, objects, expected): - with pytest.raises(ValueError) as e: - mc._get_links(*objects) - assert expected in str(e.value) - - -@pytest.mark.parametrize( - "objects, expected", - [ - ([1], "Invalid type {}"), - ([1, 2], "Invalid type {}"), - ("12", "Invalid type {}"), - ], -) -def test__get_links_invalid_type(mc, objects, expected): - with pytest.raises(TypeError) as e: - mc._get_links(*objects) - assert expected in str(e.value) - - -def test__get_links_invalid_mixed_types(mc, spectra, mfs): - objects = (*spectra, *mfs) - with pytest.raises(TypeError) as e: - mc._get_links(*objects) - assert "Invalid type" in str(e.value) - assert ".MolecularFamily" in str(e.value) - assert ".Spectrum" in str(e.value) + lg = mc.get_links(*mfs, cutoff=0, standardised=True) + assert len(lg.links) == 7 diff --git a/tests/unit/scoring/test_utils.py b/tests/unit/scoring/test_utils.py index 40300a10..810c38d8 100644 --- a/tests/unit/scoring/test_utils.py +++ b/tests/unit/scoring/test_utils.py @@ -3,13 +3,6 @@ from nplinker.scoring.utils import get_presence_gcf_strain from nplinker.scoring.utils import get_presence_mf_strain from nplinker.scoring.utils import get_presence_spec_strain -from nplinker.scoring.utils import isinstance_all - - -def test_isinstance_all(): - assert isinstance_all(1, 2, 3, objtype=int) - assert not isinstance_all(1, 2, 3, objtype=str) - assert not isinstance_all(1, 2, "3", objtype=int) # @@ -23,8 +16,8 @@ def test_get_presence_gcf_strain(gcfs, strains): presence_gcf_strain, pd.DataFrame( [[1, 0, 0], [0, 1, 0], [1, 1, 0]], - index=["gcf1", "gcf2", "gcf3"], - columns=["strain1", "strain2", "strain3"], + index=gcfs, + columns=list(strains), ), ) @@ -33,11 +26,7 @@ def test_get_presence_spec_strain(spectra, strains): presence_spec_strain = get_presence_spec_strain(spectra, strains) assert_frame_equal( presence_spec_strain, - pd.DataFrame( - [[1, 0, 0], [0, 1, 0], [1, 1, 0]], - index=["spectrum1", "spectrum2", "spectrum3"], - columns=["strain1", "strain2", "strain3"], - ), + pd.DataFrame([[1, 0, 0], [0, 1, 0], [1, 1, 0]], index=spectra, columns=list(strains)), ) @@ -45,9 +34,5 @@ def test_get_presence_mf_strain(mfs, strains): presence_mf_strain = get_presence_mf_strain(mfs, strains) assert_frame_equal( presence_mf_strain, - pd.DataFrame( - [[1, 0, 0], [0, 1, 0], [1, 1, 0]], - index=["mf1", "mf2", "mf3"], - columns=["strain1", "strain2", "strain3"], - ), + pd.DataFrame([[1, 0, 0], [0, 1, 0], [1, 1, 0]], index=mfs, columns=list(strains)), )