diff --git a/malariagen_data/af1.py b/malariagen_data/af1.py index 546e45fb..bc0a6141 100644 --- a/malariagen_data/af1.py +++ b/malariagen_data/af1.py @@ -123,6 +123,7 @@ def __init__( taxon_colors=TAXON_COLORS, virtual_contigs=None, gene_names=None, + inversion_tag_path=None, ) def __repr__(self): diff --git a/malariagen_data/ag3.py b/malariagen_data/ag3.py index 316008a7..443c594b 100644 --- a/malariagen_data/ag3.py +++ b/malariagen_data/ag3.py @@ -2,18 +2,11 @@ import dask import pandas as pd # type: ignore -from pandas import CategoricalDtype -import numpy as np # type: ignore -import allel # type: ignore import plotly.express as px # type: ignore import malariagen_data from .anopheles import AnophelesDataResource -from numpydoc_decorator import doc -from .util import check_types, _karyotype_tags_n_alt -from .anoph import base_params -from typing import Optional, Literal, Annotated, TypeAlias # silence dask performance warnings dask.config.set(**{"array.slicing.split_native_chunks": False}) # type: ignore @@ -35,6 +28,7 @@ GENE_NAMES = { "AGAP004707": "Vgsc/para", } +INVERSION_TAG_PATH = "karyotype_tag_snps.csv" def _setup_aim_palettes(): @@ -83,12 +77,6 @@ def _setup_aim_palettes(): } -inversion_param: TypeAlias = Annotated[ - Literal["2La", "2Rb", "2Rc_gam", "2Rc_col", "2Rd", "2Rj"], - "Name of inversion to infer karyotype for.", -] - - class Ag3(AnophelesDataResource): """Provides access to data from Ag3.x releases. @@ -203,6 +191,7 @@ def __init__( taxon_colors=TAXON_COLORS, virtual_contigs=VIRTUAL_CONTIGS, gene_names=GENE_NAMES, + inversion_tag_path=INVERSION_TAG_PATH, ) # set up caches @@ -355,82 +344,3 @@ def _results_cache_add_analysis_params(self, params): super()._results_cache_add_analysis_params(params) # override parent class to add AIM analysis params["aim_analysis"] = self._aim_analysis - - @check_types - @doc( - summary="Load tag SNPs for a given inversion in Ag.", - ) - def load_inversion_tags(self, inversion: inversion_param) -> pd.DataFrame: - # needs to be modified depending on where we are hosting - import importlib.resources - from . import resources - - with importlib.resources.path(resources, "karyotype_tag_snps.csv") as path: - df_tag_snps = pd.read_csv(path, sep=",") - return df_tag_snps.query(f"inversion == '{inversion}'").reset_index() - - @check_types - @doc( - summary="Infer karyotype from tag SNPs for a given inversion in Ag.", - ) - def karyotype( - self, - inversion: inversion_param, - sample_sets: Optional[base_params.sample_sets] = None, - sample_query: Optional[base_params.sample_query] = None, - sample_query_options: Optional[base_params.sample_query_options] = None, - ) -> pd.DataFrame: - # load tag snp data - df_tagsnps = self.load_inversion_tags(inversion=inversion) - inversion_pos = df_tagsnps["position"] - inversion_alts = df_tagsnps["alt_allele"] - contig = inversion[0:2] - - # get snp calls for inversion region - start, end = np.min(inversion_pos), np.max(inversion_pos) - region = f"{contig}:{start}-{end}" - - ds_snps = self.snp_calls( - region=region, - sample_sets=sample_sets, - sample_query=sample_query, - sample_query_options=sample_query_options, - ) - - with self._spinner("Inferring karyotype from tag SNPs"): - # access variables we need - geno = allel.GenotypeDaskArray(ds_snps["call_genotype"].data) - pos = allel.SortedIndex(ds_snps["variant_position"].values) - samples = ds_snps["sample_id"].values - alts = ds_snps["variant_allele"].values.astype(str) - - # subset to position of inversion tags - mask = pos.locate_intersection(inversion_pos)[0] - alts = alts[mask] - geno = geno.compress(mask, axis=0).compute() - - # infer karyotype - gn_alt = _karyotype_tags_n_alt( - gt=geno, alts=alts, inversion_alts=inversion_alts - ) - is_called = geno.is_called() - - # calculate mean genotype for each sample whilst masking missing calls - av_gts = np.mean(np.ma.MaskedArray(gn_alt, mask=~is_called), axis=0) - total_sites = np.sum(is_called, axis=0) - - df = pd.DataFrame( - { - "sample_id": samples, - "inversion": inversion, - f"karyotype_{inversion}_mean": av_gts, - # round the genotypes then convert to int - f"karyotype_{inversion}": av_gts.round().astype(int), - "total_tag_snps": total_sites, - }, - ) - # Allow filling missing values with "" visible placeholder. - kt_dtype = CategoricalDtype(categories=[0, 1, 2, ""], ordered=True) - df[f"karyotype_{inversion}"] = df[f"karyotype_{inversion}"].astype(kt_dtype) - - return df diff --git a/malariagen_data/anoph/karyotype.py b/malariagen_data/anoph/karyotype.py new file mode 100644 index 00000000..2921d19a --- /dev/null +++ b/malariagen_data/anoph/karyotype.py @@ -0,0 +1,131 @@ +import pandas as pd # type: ignore +from pandas import CategoricalDtype +import numpy as np # type: ignore +import allel # type: ignore + +from numpydoc_decorator import doc +from ..util import check_types +from . import base_params +from typing import Optional + +from .snp_data import AnophelesSnpData +from .karyotype_params import inversion_param + + +def _karyotype_tags_n_alt(gt, alts, inversion_alts): + # could be Numba'd for speed but was already quick (not many inversion tag snps) + n_sites = gt.shape[0] + n_samples = gt.shape[1] + + # create empty array + inv_n_alt = np.empty((n_sites, n_samples), dtype=np.int8) + + # for every site + for i in range(n_sites): + # find the index of the correct tag snp allele + tagsnp_index = np.where(alts[i] == inversion_alts[i])[0] + + for j in range(n_samples): + # count alleles which == tag snp allele and store + n_tag_alleles = np.sum(gt[i, j] == tagsnp_index[0]) + inv_n_alt[i, j] = n_tag_alleles + + return inv_n_alt + + +class AnophelesKaryotypeAnalysis(AnophelesSnpData): + def __init__( + self, + inversion_tag_path: Optional[str] = None, + **kwargs, + ): + # N.B., this class is designed to work cooperatively, and + # so it's important that any remaining parameters are passed + # to the superclass constructor. + super().__init__(**kwargs) + + self._inversion_tag_path = inversion_tag_path + + @check_types + @doc( + summary="Load tag SNPs for a given inversion.", + ) + def load_inversion_tags(self, inversion: inversion_param) -> pd.DataFrame: + # needs to be modified depending on where we are hosting + import importlib.resources + from .. import resources + + if not self._inversion_tag_path: + raise FileNotFoundError( + "The file containing the inversion tags is missing." + ) + else: + with importlib.resources.path(resources, self._inversion_tag_path) as path: + df_tag_snps = pd.read_csv(path, sep=",") + return df_tag_snps.query(f"inversion == '{inversion}'").reset_index() + + @check_types + @doc( + summary="Infer karyotype from tag SNPs for a given inversion in Ag.", + ) + def karyotype( + self, + inversion: inversion_param, + sample_sets: Optional[base_params.sample_sets] = None, + sample_query: Optional[base_params.sample_query] = None, + sample_query_options: Optional[base_params.sample_query_options] = None, + ) -> pd.DataFrame: + # load tag snp data + df_tagsnps = self.load_inversion_tags(inversion=inversion) + inversion_pos = df_tagsnps["position"] + inversion_alts = df_tagsnps["alt_allele"] + contig = inversion[0:2] + + # get snp calls for inversion region + start, end = np.min(inversion_pos), np.max(inversion_pos) + region = f"{contig}:{start}-{end}" + + ds_snps = self.snp_calls( + region=region, + sample_sets=sample_sets, + sample_query=sample_query, + sample_query_options=sample_query_options, + ) + + with self._spinner("Inferring karyotype from tag SNPs"): + # access variables we need + geno = allel.GenotypeDaskArray(ds_snps["call_genotype"].data) + pos = allel.SortedIndex(ds_snps["variant_position"].values) + samples = ds_snps["sample_id"].values + alts = ds_snps["variant_allele"].values.astype(str) + + # subset to position of inversion tags + mask = pos.locate_intersection(inversion_pos)[0] + alts = alts[mask] + geno = geno.compress(mask, axis=0).compute() + + # infer karyotype + gn_alt = _karyotype_tags_n_alt( + gt=geno, alts=alts, inversion_alts=inversion_alts + ) + is_called = geno.is_called() + + # calculate mean genotype for each sample whilst masking missing calls + av_gts = np.mean(np.ma.MaskedArray(gn_alt, mask=~is_called), axis=0) + total_sites = np.sum(is_called, axis=0) + + df = pd.DataFrame( + { + "sample_id": samples, + "inversion": inversion, + f"karyotype_{inversion}_mean": av_gts, + # round the genotypes then convert to int + f"karyotype_{inversion}": av_gts.round().astype(int), + "total_tag_snps": total_sites, + }, + ) + # Allow filling missing values with "" visible placeholder. + kt_dtype = CategoricalDtype(categories=[0, 1, 2, ""], ordered=True) + df[f"karyotype_{inversion}"] = df[f"karyotype_{inversion}"].astype(kt_dtype) + + return df diff --git a/malariagen_data/anoph/karyotype_params.py b/malariagen_data/anoph/karyotype_params.py new file mode 100644 index 00000000..e13eaffc --- /dev/null +++ b/malariagen_data/anoph/karyotype_params.py @@ -0,0 +1,9 @@ +"""Parameter definitions for karyotype analysis functions.""" + + +from typing_extensions import Annotated, TypeAlias + +inversion_param: TypeAlias = Annotated[ + str, + "Name of inversion to infer karyotype for.", +] diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py index d932b39c..4a2f63df 100644 --- a/malariagen_data/anopheles.py +++ b/malariagen_data/anopheles.py @@ -29,6 +29,7 @@ plotly_params, xpehh_params, ) +from .anoph.karyotype import AnophelesKaryotypeAnalysis from .anoph.aim_data import AnophelesAimData from .anoph.base import AnophelesBase from .anoph.cnv_data import AnophelesCnvData @@ -94,6 +95,7 @@ class AnophelesDataResource( AnophelesPca, PlinkConverter, AnophelesIgv, + AnophelesKaryotypeAnalysis, AnophelesAimData, AnophelesHapData, AnophelesSnpData, @@ -138,6 +140,7 @@ def __init__( taxon_colors: Optional[Mapping[str, str]], virtual_contigs: Optional[Mapping[str, Sequence[str]]], gene_names: Optional[Mapping[str, str]], + inversion_tag_path: Optional[str], ): super().__init__( url=url, @@ -171,6 +174,7 @@ def __init__( taxon_colors=taxon_colors, virtual_contigs=virtual_contigs, gene_names=gene_names, + inversion_tag_path=inversion_tag_path, ) @property diff --git a/malariagen_data/util.py b/malariagen_data/util.py index 81ed4733..09e8fbe9 100644 --- a/malariagen_data/util.py +++ b/malariagen_data/util.py @@ -1591,27 +1591,6 @@ def distributed_client(): return client -def _karyotype_tags_n_alt(gt, alts, inversion_alts): - # could be Numba'd for speed but was already quick (not many inversion tag snps) - n_sites = gt.shape[0] - n_samples = gt.shape[1] - - # create empty array - inv_n_alt = np.empty((n_sites, n_samples), dtype=np.int8) - - # for every site - for i in range(n_sites): - # find the index of the correct tag snp allele - tagsnp_index = np.where(alts[i] == inversion_alts[i])[0] - - for j in range(n_samples): - # count alleles which == tag snp allele and store - n_tag_alleles = np.sum(gt[i, j] == tagsnp_index[0]) - inv_n_alt[i, j] = n_tag_alleles - - return inv_n_alt - - def prep_samples_for_cohort_grouping(*, df_samples, area_by, period_by): # Take a copy, as we will modify the dataframe. df_samples = df_samples.copy() diff --git a/notebooks/karyotype.ipynb b/notebooks/karyotype.ipynb index d487c947..bd26b8a8 100644 --- a/notebooks/karyotype.ipynb +++ b/notebooks/karyotype.ipynb @@ -359,14 +359,6 @@ "source": [ "ag3.plot_pca_coords(pca_df_2rc_col, color=\"karyotype_2Rc_col\", symbol=\"country\", width=600, height=500)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d6fb7237-bb4e-490e-9ae0-33a52d4fa650", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -391,7 +383,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/tests/integration/test_af1.py b/tests/integration/test_af1.py index 869ba0f0..fcad34a5 100644 --- a/tests/integration/test_af1.py +++ b/tests/integration/test_af1.py @@ -75,3 +75,18 @@ def test_locate_region(region_raw): assert region == Region("2RL", 48714463, 48715355) if region_raw == "2RL:24,630,355-24,633,221": assert region == Region("2RL", 24630355, 24633221) + + +@pytest.mark.parametrize( + "inversion", + ["2La", "2Rb", "2Rc_col", "X_x"], +) +def test_karyotyping(inversion): + af1 = setup_af1() + + with pytest.raises(FileNotFoundError): + af1.karyotype( + inversion=inversion, + sample_sets="1229-VO-GH-DADZIE-VMF00095", + sample_query=None, + ) diff --git a/tests/integration/test_ag3.py b/tests/integration/test_ag3.py index 8536a56f..5ee539d2 100644 --- a/tests/integration/test_ag3.py +++ b/tests/integration/test_ag3.py @@ -159,19 +159,30 @@ def test_xpehh_gwss(): assert_allclose(xpehh[:, 2][100], 0.4817561326426265) -def test_karyotyping(): +@pytest.mark.parametrize( + "inversion", + ["2La", "2Rb", "2Rc_col", "X_x"], +) +def test_karyotyping(inversion): ag3 = setup_ag3(cohorts_analysis="20230516") - df = ag3.karyotype(inversion="2La", sample_sets="AG1000G-GH", sample_query=None) - - assert isinstance(df, pd.DataFrame) - expected_cols = [ - "sample_id", - "inversion", - "karyotype_2La_mean", - "karyotype_2La", - "total_tag_snps", - ] - assert set(df.columns) == set(expected_cols) - assert all(df["karyotype_2La"].isin([0, 1, 2])) - assert all(df["karyotype_2La_mean"].between(0, 2)) + if inversion == "X_x": + with pytest.raises(ValueError): + ag3.karyotype( + inversion=inversion, sample_sets="AG1000G-GH", sample_query=None + ) + else: + df = ag3.karyotype( + inversion=inversion, sample_sets="AG1000G-GH", sample_query=None + ) + assert isinstance(df, pd.DataFrame) + expected_cols = [ + "sample_id", + "inversion", + f"karyotype_{inversion}_mean", + f"karyotype_{inversion}", + "total_tag_snps", + ] + assert set(df.columns) == set(expected_cols) + assert all(df[f"karyotype_{inversion}"].isin([0, 1, 2])) + assert all(df[f"karyotype_{inversion}_mean"].between(0, 2))