diff --git a/malariagen_data/anoph/cnv_frq.py b/malariagen_data/anoph/cnv_frq.py index 29e83361..fda01aa6 100644 --- a/malariagen_data/anoph/cnv_frq.py +++ b/malariagen_data/anoph/cnv_frq.py @@ -10,6 +10,11 @@ from numpydoc_decorator import doc # type: ignore from . import base_params, cnv_params, frq_params +from .frq_funcs import ( + prep_samples_for_cohort_grouping, + build_cohorts_from_sample_grouping, + add_frequency_ci, +) from ..util import ( check_types, pandas_apply, @@ -17,17 +22,13 @@ parse_multi_region, region_str, simple_xarray_concat, - prep_samples_for_cohort_grouping, - build_cohorts_from_sample_grouping, - add_frequency_ci, ) from .cnv_data import AnophelesCnvData +from .frq_funcs import AnophelesFrequency from .sample_metadata import locate_cohorts -class AnophelesCnvFrequencyAnalysis( - AnophelesCnvData, -): +class AnophelesCnvFrequencyAnalysis(AnophelesCnvData, AnophelesFrequency): def __init__( self, **kwargs, diff --git a/malariagen_data/anoph/frq_funcs.py b/malariagen_data/anoph/frq_funcs.py new file mode 100644 index 00000000..425911a6 --- /dev/null +++ b/malariagen_data/anoph/frq_funcs.py @@ -0,0 +1,584 @@ +import numpy as np +import pandas as pd +import xarray as xr +import plotly.express as px +from textwrap import dedent +from typing import Optional, Union, List +from numpydoc_decorator import doc # type: ignore +from . import ( + plotly_params, + frq_params, + map_params, +) +from ..util import check_types +from .base import AnophelesBase + + +def prep_samples_for_cohort_grouping(*, df_samples, area_by, period_by): + # Take a copy, as we will modify the dataframe. + df_samples = df_samples.copy() + + # Fix "intermediate" or "unassigned" taxon values - we only want to build + # cohorts with clean taxon calls, so we set other values to None. + loc_intermediate_taxon = ( + df_samples["taxon"].str.startswith("intermediate").fillna(False) + ) + df_samples.loc[loc_intermediate_taxon, "taxon"] = None + loc_unassigned_taxon = ( + df_samples["taxon"].str.startswith("unassigned").fillna(False) + ) + df_samples.loc[loc_unassigned_taxon, "taxon"] = None + + # Add period column. + if period_by == "year": + make_period = _make_sample_period_year + elif period_by == "quarter": + make_period = _make_sample_period_quarter + elif period_by == "month": + make_period = _make_sample_period_month + else: # pragma: no cover + raise ValueError( + f"Value for period_by parameter must be one of 'year', 'quarter', 'month'; found {period_by!r}." + ) + sample_period = df_samples.apply(make_period, axis="columns") + df_samples["period"] = sample_period + + # Add area column for consistent output. + df_samples["area"] = df_samples[area_by] + + return df_samples + + +def build_cohorts_from_sample_grouping(*, group_samples_by_cohort, min_cohort_size): + # Build cohorts dataframe. + df_cohorts = group_samples_by_cohort.agg( + size=("sample_id", len), + lat_mean=("latitude", "mean"), + lat_max=("latitude", "max"), + lat_min=("latitude", "min"), + lon_mean=("longitude", "mean"), + lon_max=("longitude", "max"), + lon_min=("longitude", "min"), + ) + # Reset index so that the index fields are included as columns. + df_cohorts = df_cohorts.reset_index() + + # Add cohort helper variables. + cohort_period_start = df_cohorts["period"].apply(lambda v: v.start_time) + cohort_period_end = df_cohorts["period"].apply(lambda v: v.end_time) + df_cohorts["period_start"] = cohort_period_start + df_cohorts["period_end"] = cohort_period_end + # Create a label that is similar to the cohort metadata, + # although this won't be perfect. + df_cohorts["label"] = df_cohorts.apply( + lambda v: f"{v.area}_{v.taxon[:4]}_{v.period}", axis="columns" + ) + + # Apply minimum cohort size. + df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(drop=True) + + # Early check for no cohorts. + if len(df_cohorts) == 0: + raise ValueError( + "No cohorts available for the given sample selection parameters and minimum cohort size." + ) + + return df_cohorts + + +def add_frequency_ci(*, ds, ci_method): + from statsmodels.stats.proportion import proportion_confint # type: ignore + + if ci_method is not None: + count = ds["event_count"].values + nobs = ds["event_nobs"].values + with np.errstate(divide="ignore", invalid="ignore"): + frq_ci_low, frq_ci_upp = proportion_confint( + count=count, nobs=nobs, method=ci_method + ) + ds["event_frequency_ci_low"] = ("variants", "cohorts"), frq_ci_low + ds["event_frequency_ci_upp"] = ("variants", "cohorts"), frq_ci_upp + + +def _make_sample_period_month(row): + year = row.year + month = row.month + if year > 0 and month > 0: + return pd.Period(freq="M", year=year, month=month) + else: + return pd.NaT + + +def _make_sample_period_quarter(row): + year = row.year + month = row.month + if year > 0 and month > 0: + return pd.Period(freq="Q", year=year, month=month) + else: + return pd.NaT + + +def _make_sample_period_year(row): + year = row.year + if year > 0: + return pd.Period(freq="Y", year=year) + else: + return pd.NaT + + +class AnophelesFrequency(AnophelesBase): + def __init__( + self, + **kwargs, + ): + # N.B., this class is designed to work cooperatively, and + # so it's important that any remaining parameters are passed + # to the superclass constructor. + super().__init__(**kwargs) + + @check_types + @doc( + summary=""" + Plot a heatmap from a pandas DataFrame of frequencies, e.g., output + from `snp_allele_frequencies()` or `gene_cnv_frequencies()`. + """, + parameters=dict( + df=""" + A DataFrame of frequencies, e.g., output from + `snp_allele_frequencies()` or `gene_cnv_frequencies()`. + """, + index=""" + One or more column headers that are present in the input dataframe. + This becomes the heatmap y-axis row labels. The column/s must + produce a unique index. + """, + max_len=""" + Displaying large styled dataframes may cause ipython notebooks to + crash. If the input dataframe is larger than this value, an error + will be raised. + """, + col_width=""" + Plot width per column in pixels (px). + """, + row_height=""" + Plot height per row in pixels (px). + """, + kwargs=""" + Passed through to `px.imshow()`. + """, + ), + notes=""" + It's recommended to filter the input DataFrame to just rows of interest, + i.e., fewer rows than `max_len`. + """, + ) + def plot_frequencies_heatmap( + self, + df: pd.DataFrame, + index: Optional[Union[str, List[str]]] = "label", + max_len: Optional[int] = 100, + col_width: int = 40, + row_height: int = 20, + x_label: plotly_params.x_label = "Cohorts", + y_label: plotly_params.y_label = "Variants", + colorbar: plotly_params.colorbar = True, + width: plotly_params.fig_width = None, + height: plotly_params.fig_height = None, + text_auto: plotly_params.text_auto = ".0%", + aspect: plotly_params.aspect = "auto", + color_continuous_scale: plotly_params.color_continuous_scale = "Reds", + title: plotly_params.title = True, + show: plotly_params.show = True, + renderer: plotly_params.renderer = None, + **kwargs, + ) -> plotly_params.figure: + # Check len of input. + if max_len and len(df) > max_len: + raise ValueError( + dedent( + f""" + Input DataFrame is longer than max_len parameter value {max_len}, which means + that the plot is likely to be very large. If you really want to go ahead, + please rerun the function with max_len=None. + """ + ) + ) + + # Handle title. + if title is True: + title = df.attrs.get("title", None) + + # Indexing. + if index is None: + index = list(df.index.names) + df = df.reset_index().copy() + if isinstance(index, list): + index_col = ( + df[index] + .astype(str) + .apply( + lambda row: ", ".join([o for o in row if o is not None]), + axis="columns", + ) + ) + else: + assert isinstance(index, str) + index_col = df[index].astype(str) + + # Check that index is unique. + if not index_col.is_unique: + raise ValueError(f"{index} does not produce a unique index") + + # Drop and re-order columns. + frq_cols = [col for col in df.columns if col.startswith("frq_")] + + # Keep only freq cols. + heatmap_df = df[frq_cols].copy() + + # Set index. + heatmap_df.set_index(index_col, inplace=True) + + # Clean column names. + heatmap_df.columns = heatmap_df.columns.str.lstrip("frq_") + + # Deal with width and height. + if width is None: + width = 400 + col_width * len(heatmap_df.columns) + if colorbar: + width += 40 + if height is None: + height = 200 + row_height * len(heatmap_df) + if title is not None: + height += 40 + + # Plotly heatmap styling. + fig = px.imshow( + img=heatmap_df, + zmin=0, + zmax=1, + width=width, + height=height, + text_auto=text_auto, + aspect=aspect, + color_continuous_scale=color_continuous_scale, + title=title, + **kwargs, + ) + + fig.update_xaxes(side="bottom", tickangle=30) + if x_label is not None: + fig.update_xaxes(title=x_label) + if y_label is not None: + fig.update_yaxes(title=y_label) + fig.update_layout( + coloraxis_colorbar=dict( + title="Frequency", + tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1.0], + ticktext=["0%", "20%", "40%", "60%", "80%", "100%"], + ) + ) + if not colorbar: + fig.update(layout_coloraxis_showscale=False) + + if show: # pragma: no cover + fig.show(renderer=renderer) + return None + else: + return fig + + @check_types + @doc( + summary="Create a time series plot of variant frequencies using plotly.", + parameters=dict( + ds=""" + A dataset of variant frequencies, such as returned by + `snp_allele_frequencies_advanced()`, + `aa_allele_frequencies_advanced()` or + `gene_cnv_frequencies_advanced()`. + """, + kwargs="Passed through to `px.line()`.", + ), + returns=""" + A plotly figure containing line graphs. The resulting figure will + have one panel per cohort, grouped into columns by taxon, and + grouped into rows by area. Markers and lines show frequencies of + variants. + """, + ) + def plot_frequencies_time_series( + self, + ds: xr.Dataset, + height: plotly_params.fig_height = None, + width: plotly_params.fig_width = None, + title: plotly_params.title = True, + legend_sizing: plotly_params.legend_sizing = "constant", + show: plotly_params.show = True, + renderer: plotly_params.renderer = None, + taxa: frq_params.taxa = None, + areas: frq_params.areas = None, + **kwargs, + ) -> plotly_params.figure: + # Handle title. + if title is True: + title = ds.attrs.get("title", None) + + # Extract cohorts into a dataframe. + cohort_vars = [v for v in ds if str(v).startswith("cohort_")] + df_cohorts = ds[cohort_vars].to_dataframe() + df_cohorts.columns = [c.split("cohort_")[1] for c in df_cohorts.columns] # type: ignore + + # If specified, restrict the dataframe by taxa. + if isinstance(taxa, str): + df_cohorts = df_cohorts[df_cohorts["taxon"] == taxa] + elif isinstance(taxa, (list, tuple)): + df_cohorts = df_cohorts[df_cohorts["taxon"].isin(taxa)] + + # If specified, restrict the dataframe by areas. + if isinstance(areas, str): + df_cohorts = df_cohorts[df_cohorts["area"] == areas] + elif isinstance(areas, (list, tuple)): + df_cohorts = df_cohorts[df_cohorts["area"].isin(areas)] + + # Extract variant labels. + variant_labels = ds["variant_label"].values + + # Build a long-form dataframe from the dataset. + dfs = [] + for cohort_index, cohort in enumerate(df_cohorts.itertuples()): + ds_cohort = ds.isel(cohorts=cohort_index) + df = pd.DataFrame( + { + "taxon": cohort.taxon, + "area": cohort.area, + "date": cohort.period_start, + "period": str( + cohort.period + ), # use string representation for hover label + "sample_size": cohort.size, + "variant": variant_labels, + "count": ds_cohort["event_count"].values, + "nobs": ds_cohort["event_nobs"].values, + "frequency": ds_cohort["event_frequency"].values, + "frequency_ci_low": ds_cohort["event_frequency_ci_low"].values, + "frequency_ci_upp": ds_cohort["event_frequency_ci_upp"].values, + } + ) + dfs.append(df) + df_events = pd.concat(dfs, axis=0).reset_index(drop=True) + + # Remove events with no observations. + df_events = df_events.query("nobs > 0").copy() + + # Calculate error bars. + frq = df_events["frequency"] + frq_ci_low = df_events["frequency_ci_low"] + frq_ci_upp = df_events["frequency_ci_upp"] + df_events["frequency_error"] = frq_ci_upp - frq + df_events["frequency_error_minus"] = frq - frq_ci_low + + # Make a plot. + fig = px.line( + df_events, + facet_col="taxon", + facet_row="area", + x="date", + y="frequency", + error_y="frequency_error", + error_y_minus="frequency_error_minus", + color="variant", + markers=True, + hover_name="variant", + hover_data={ + "frequency": ":.0%", + "period": True, + "area": True, + "taxon": True, + "sample_size": True, + "date": False, + "variant": False, + }, + height=height, + width=width, + title=title, + labels={ + "date": "Date", + "frequency": "Frequency", + "variant": "Variant", + "taxon": "Taxon", + "area": "Area", + "period": "Period", + "sample_size": "Sample size", + }, + **kwargs, + ) + + # Tidy plot. + fig.update_layout( + yaxis_range=[-0.05, 1.05], + legend=dict(itemsizing=legend_sizing, tracegroupgap=0), + ) + + if show: # pragma: no cover + fig.show(renderer=renderer) + return None + else: + return fig + + @check_types + @doc( + summary=""" + Plot markers on a map showing variant frequencies for cohorts grouped + by area (space), period (time) and taxon. + """, + parameters=dict( + m="The map on which to add the markers.", + variant="Index or label of variant to plot.", + taxon="Taxon to show markers for.", + period="Time period to show markers for.", + clear=""" + If True, clear all layers (except the base layer) from the map + before adding new markers. + """, + ), + ) + def plot_frequencies_map_markers( + self, + m, + ds: frq_params.ds_frequencies_advanced, + variant: Union[int, str], + taxon: str, + period: pd.Period, + clear: bool = True, + ): + # Only import here because of some problems importing globally. + import ipyleaflet # type: ignore + import ipywidgets # type: ignore + + # Slice dataset to variant of interest. + if isinstance(variant, int): + ds_variant = ds.isel(variants=variant) + variant_label = ds["variant_label"].values[variant] + else: + assert isinstance(variant, str) + ds_variant = ds.set_index(variants="variant_label").sel(variants=variant) + variant_label = variant + + # Convert to a dataframe for convenience. + df_markers = ds_variant[ + [ + "cohort_taxon", + "cohort_area", + "cohort_period", + "cohort_lat_mean", + "cohort_lon_mean", + "cohort_size", + "event_frequency", + "event_frequency_ci_low", + "event_frequency_ci_upp", + ] + ].to_dataframe() + + # Select data matching taxon and period parameters. + df_markers = df_markers.loc[ + ( + (df_markers["cohort_taxon"] == taxon) + & (df_markers["cohort_period"] == period) + ) + ] + + # Clear existing layers in the map. + if clear: + for layer in m.layers[1:]: + m.remove_layer(layer) + + # Add markers. + for x in df_markers.itertuples(): + marker = ipyleaflet.CircleMarker() + marker.location = (x.cohort_lat_mean, x.cohort_lon_mean) + marker.radius = 20 + marker.color = "black" + marker.weight = 1 + marker.fill_color = "red" + marker.fill_opacity = x.event_frequency + popup_html = f""" + {variant_label}
+ Taxon: {x.cohort_taxon}
+ Area: {x.cohort_area}
+ Period: {x.cohort_period}
+ Sample size: {x.cohort_size}
+ Frequency: {x.event_frequency:.0%} + (95% CI: {x.event_frequency_ci_low:.0%} - {x.event_frequency_ci_upp:.0%}) + """ + marker.popup = ipyleaflet.Popup( + child=ipywidgets.HTML(popup_html), + auto_pan=False, + ) + m.add(marker) + + @check_types + @doc( + summary=""" + Create an interactive map with markers showing variant frequencies or + cohorts grouped by area (space), period (time) and taxon. + """, + parameters=dict( + title=""" + If True, attempt to use metadata from input dataset as a plot + title. Otherwise, use supplied value as a title. + """, + epilogue="Additional text to display below the map.", + ), + returns=""" + An interactive map with widgets for selecting which variant, taxon + and time period to display. + """, + ) + def plot_frequencies_interactive_map( + self, + ds: frq_params.ds_frequencies_advanced, + center: map_params.center = map_params.center_default, + zoom: map_params.zoom = map_params.zoom_default, + title: Optional[Union[bool, str]] = True, + epilogue: Union[bool, str] = True, + ): + import ipyleaflet + import ipywidgets + + # Handle title. + if title is True: + title = ds.attrs.get("title", None) + + # Create a map. + freq_map = ipyleaflet.Map(center=center, zoom=zoom) + + # Set up interactive controls. + variants = ds["variant_label"].values + taxa = ds["cohort_taxon"].to_pandas().dropna().unique() # type: ignore + periods = ds["cohort_period"].to_pandas().dropna().unique() # type: ignore + controls = ipywidgets.interactive( + self.plot_frequencies_map_markers, + m=ipywidgets.fixed(freq_map), + ds=ipywidgets.fixed(ds), + variant=ipywidgets.Dropdown(options=variants, description="Variant: "), + taxon=ipywidgets.Dropdown(options=taxa, description="Taxon: "), + period=ipywidgets.Dropdown(options=periods, description="Period: "), + clear=ipywidgets.fixed(True), + ) + + # Lay out widgets. + components = [] + if title is not None: + components.append(ipywidgets.HTML(value=f"

{title}

")) + components.append(controls) + components.append(freq_map) + if epilogue is True: + epilogue = """ + Variant frequencies are shown as coloured markers. Opacity of color + denotes frequency. Click on a marker for more information. + """ + if epilogue: + components.append(ipywidgets.HTML(value=f"{epilogue}")) + + out = ipywidgets.VBox(components) + + return out diff --git a/malariagen_data/anoph/hap_frq.py b/malariagen_data/anoph/hap_frq.py index 5dcf6ddb..b30e5318 100644 --- a/malariagen_data/anoph/hap_frq.py +++ b/malariagen_data/anoph/hap_frq.py @@ -9,18 +9,19 @@ from ..util import ( check_types, haplotype_frequencies, +) +from .hap_data import AnophelesHapData +from .frq_funcs import ( prep_samples_for_cohort_grouping, build_cohorts_from_sample_grouping, add_frequency_ci, ) -from .hap_data import AnophelesHapData from .sample_metadata import locate_cohorts +from .frq_funcs import AnophelesFrequency from . import base_params, frq_params -class AnophelesHapFrequencyAnalysis( - AnophelesHapData, -): +class AnophelesHapFrequencyAnalysis(AnophelesHapData, AnophelesFrequency): def __init__( self, **kwargs, @@ -119,7 +120,7 @@ def haplotypes_frequencies( df_haps_sorted["label"] = ["H" + str(i) for i in range(len(df_haps_sorted))] # Reset index after filtering. - df_haps_sorted.set_index(keys="label", drop=True) + df_haps_sorted.set_index(keys="label", drop=True, inplace=True) return df_haps_sorted diff --git a/malariagen_data/anoph/snp_frq.py b/malariagen_data/anoph/snp_frq.py index cd5304d7..0c1f05d3 100644 --- a/malariagen_data/anoph/snp_frq.py +++ b/malariagen_data/anoph/snp_frq.py @@ -1,6 +1,5 @@ from typing import Optional, Dict, Union, Callable, List import warnings -from textwrap import dedent import allel # type: ignore import numpy as np @@ -8,19 +7,21 @@ from numpydoc_decorator import doc # type: ignore import xarray as xr import numba # type: ignore -import plotly.express as px # type: ignore from .. import veff from ..util import ( check_types, pandas_apply, +) +from .snp_data import AnophelesSnpData +from .frq_funcs import ( prep_samples_for_cohort_grouping, build_cohorts_from_sample_grouping, add_frequency_ci, ) -from .snp_data import AnophelesSnpData from .sample_metadata import locate_cohorts -from . import base_params, frq_params, map_params, plotly_params +from .frq_funcs import AnophelesFrequency +from . import base_params, frq_params AA_CHANGE_QUERY = ( @@ -28,9 +29,7 @@ ) -class AnophelesSnpFrequencyAnalysis( - AnophelesSnpData, -): +class AnophelesSnpFrequencyAnalysis(AnophelesSnpData, AnophelesFrequency): def __init__( self, **kwargs, @@ -775,453 +774,6 @@ def aa_allele_frequencies_advanced( return ds_aa_frq - @check_types - @doc( - summary=""" - Plot a heatmap from a pandas DataFrame of frequencies, e.g., output - from `snp_allele_frequencies()` or `gene_cnv_frequencies()`. - """, - parameters=dict( - df=""" - A DataFrame of frequencies, e.g., output from - `snp_allele_frequencies()` or `gene_cnv_frequencies()`. - """, - index=""" - One or more column headers that are present in the input dataframe. - This becomes the heatmap y-axis row labels. The column/s must - produce a unique index. - """, - max_len=""" - Displaying large styled dataframes may cause ipython notebooks to - crash. If the input dataframe is larger than this value, an error - will be raised. - """, - col_width=""" - Plot width per column in pixels (px). - """, - row_height=""" - Plot height per row in pixels (px). - """, - kwargs=""" - Passed through to `px.imshow()`. - """, - ), - notes=""" - It's recommended to filter the input DataFrame to just rows of interest, - i.e., fewer rows than `max_len`. - """, - ) - def plot_frequencies_heatmap( - self, - df: pd.DataFrame, - index: Optional[Union[str, List[str]]] = "label", - max_len: Optional[int] = 100, - col_width: int = 40, - row_height: int = 20, - x_label: plotly_params.x_label = "Cohorts", - y_label: plotly_params.y_label = "Variants", - colorbar: plotly_params.colorbar = True, - width: plotly_params.fig_width = None, - height: plotly_params.fig_height = None, - text_auto: plotly_params.text_auto = ".0%", - aspect: plotly_params.aspect = "auto", - color_continuous_scale: plotly_params.color_continuous_scale = "Reds", - title: plotly_params.title = True, - show: plotly_params.show = True, - renderer: plotly_params.renderer = None, - **kwargs, - ) -> plotly_params.figure: - # Check len of input. - if max_len and len(df) > max_len: - raise ValueError( - dedent( - f""" - Input DataFrame is longer than max_len parameter value {max_len}, which means - that the plot is likely to be very large. If you really want to go ahead, - please rerun the function with max_len=None. - """ - ) - ) - - # Handle title. - if title is True: - title = df.attrs.get("title", None) - - # Indexing. - if index is None: - index = list(df.index.names) - df = df.reset_index().copy() - if isinstance(index, list): - index_col = ( - df[index] - .astype(str) - .apply( - lambda row: ", ".join([o for o in row if o is not None]), - axis="columns", - ) - ) - else: - assert isinstance(index, str) - index_col = df[index].astype(str) - - # Check that index is unique. - if not index_col.is_unique: - raise ValueError(f"{index} does not produce a unique index") - - # Drop and re-order columns. - frq_cols = [col for col in df.columns if col.startswith("frq_")] - - # Keep only freq cols. - heatmap_df = df[frq_cols].copy() - - # Set index. - heatmap_df.set_index(index_col, inplace=True) - - # Clean column names. - heatmap_df.columns = heatmap_df.columns.str.lstrip("frq_") - - # Deal with width and height. - if width is None: - width = 400 + col_width * len(heatmap_df.columns) - if colorbar: - width += 40 - if height is None: - height = 200 + row_height * len(heatmap_df) - if title is not None: - height += 40 - - # Plotly heatmap styling. - fig = px.imshow( - img=heatmap_df, - zmin=0, - zmax=1, - width=width, - height=height, - text_auto=text_auto, - aspect=aspect, - color_continuous_scale=color_continuous_scale, - title=title, - **kwargs, - ) - - fig.update_xaxes(side="bottom", tickangle=30) - if x_label is not None: - fig.update_xaxes(title=x_label) - if y_label is not None: - fig.update_yaxes(title=y_label) - fig.update_layout( - coloraxis_colorbar=dict( - title="Frequency", - tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1.0], - ticktext=["0%", "20%", "40%", "60%", "80%", "100%"], - ) - ) - if not colorbar: - fig.update(layout_coloraxis_showscale=False) - - if show: # pragma: no cover - fig.show(renderer=renderer) - return None - else: - return fig - - @check_types - @doc( - summary="Create a time series plot of variant frequencies using plotly.", - parameters=dict( - ds=""" - A dataset of variant frequencies, such as returned by - `snp_allele_frequencies_advanced()`, - `aa_allele_frequencies_advanced()` or - `gene_cnv_frequencies_advanced()`. - """, - kwargs="Passed through to `px.line()`.", - ), - returns=""" - A plotly figure containing line graphs. The resulting figure will - have one panel per cohort, grouped into columns by taxon, and - grouped into rows by area. Markers and lines show frequencies of - variants. - """, - ) - def plot_frequencies_time_series( - self, - ds: xr.Dataset, - height: plotly_params.fig_height = None, - width: plotly_params.fig_width = None, - title: plotly_params.title = True, - legend_sizing: plotly_params.legend_sizing = "constant", - show: plotly_params.show = True, - renderer: plotly_params.renderer = None, - taxa: frq_params.taxa = None, - areas: frq_params.areas = None, - **kwargs, - ) -> plotly_params.figure: - # Handle title. - if title is True: - title = ds.attrs.get("title", None) - - # Extract cohorts into a dataframe. - cohort_vars = [v for v in ds if str(v).startswith("cohort_")] - df_cohorts = ds[cohort_vars].to_dataframe() - df_cohorts.columns = [c.split("cohort_")[1] for c in df_cohorts.columns] # type: ignore - - # If specified, restrict the dataframe by taxa. - if isinstance(taxa, str): - df_cohorts = df_cohorts[df_cohorts["taxon"] == taxa] - elif isinstance(taxa, (list, tuple)): - df_cohorts = df_cohorts[df_cohorts["taxon"].isin(taxa)] - - # If specified, restrict the dataframe by areas. - if isinstance(areas, str): - df_cohorts = df_cohorts[df_cohorts["area"] == areas] - elif isinstance(areas, (list, tuple)): - df_cohorts = df_cohorts[df_cohorts["area"].isin(areas)] - - # Extract variant labels. - variant_labels = ds["variant_label"].values - - # Build a long-form dataframe from the dataset. - dfs = [] - for cohort_index, cohort in enumerate(df_cohorts.itertuples()): - ds_cohort = ds.isel(cohorts=cohort_index) - df = pd.DataFrame( - { - "taxon": cohort.taxon, - "area": cohort.area, - "date": cohort.period_start, - "period": str( - cohort.period - ), # use string representation for hover label - "sample_size": cohort.size, - "variant": variant_labels, - "count": ds_cohort["event_count"].values, - "nobs": ds_cohort["event_nobs"].values, - "frequency": ds_cohort["event_frequency"].values, - "frequency_ci_low": ds_cohort["event_frequency_ci_low"].values, - "frequency_ci_upp": ds_cohort["event_frequency_ci_upp"].values, - } - ) - dfs.append(df) - df_events = pd.concat(dfs, axis=0).reset_index(drop=True) - - # Remove events with no observations. - df_events = df_events.query("nobs > 0").copy() - - # Calculate error bars. - frq = df_events["frequency"] - frq_ci_low = df_events["frequency_ci_low"] - frq_ci_upp = df_events["frequency_ci_upp"] - df_events["frequency_error"] = frq_ci_upp - frq - df_events["frequency_error_minus"] = frq - frq_ci_low - - # Make a plot. - fig = px.line( - df_events, - facet_col="taxon", - facet_row="area", - x="date", - y="frequency", - error_y="frequency_error", - error_y_minus="frequency_error_minus", - color="variant", - markers=True, - hover_name="variant", - hover_data={ - "frequency": ":.0%", - "period": True, - "area": True, - "taxon": True, - "sample_size": True, - "date": False, - "variant": False, - }, - height=height, - width=width, - title=title, - labels={ - "date": "Date", - "frequency": "Frequency", - "variant": "Variant", - "taxon": "Taxon", - "area": "Area", - "period": "Period", - "sample_size": "Sample size", - }, - **kwargs, - ) - - # Tidy plot. - fig.update_layout( - yaxis_range=[-0.05, 1.05], - legend=dict(itemsizing=legend_sizing, tracegroupgap=0), - ) - - if show: # pragma: no cover - fig.show(renderer=renderer) - return None - else: - return fig - - @check_types - @doc( - summary=""" - Plot markers on a map showing variant frequencies for cohorts grouped - by area (space), period (time) and taxon. - """, - parameters=dict( - m="The map on which to add the markers.", - variant="Index or label of variant to plot.", - taxon="Taxon to show markers for.", - period="Time period to show markers for.", - clear=""" - If True, clear all layers (except the base layer) from the map - before adding new markers. - """, - ), - ) - def plot_frequencies_map_markers( - self, - m, - ds: frq_params.ds_frequencies_advanced, - variant: Union[int, str], - taxon: str, - period: pd.Period, - clear: bool = True, - ): - # Only import here because of some problems importing globally. - import ipyleaflet # type: ignore - import ipywidgets # type: ignore - - # Slice dataset to variant of interest. - if isinstance(variant, int): - ds_variant = ds.isel(variants=variant) - variant_label = ds["variant_label"].values[variant] - else: - assert isinstance(variant, str) - ds_variant = ds.set_index(variants="variant_label").sel(variants=variant) - variant_label = variant - - # Convert to a dataframe for convenience. - df_markers = ds_variant[ - [ - "cohort_taxon", - "cohort_area", - "cohort_period", - "cohort_lat_mean", - "cohort_lon_mean", - "cohort_size", - "event_frequency", - "event_frequency_ci_low", - "event_frequency_ci_upp", - ] - ].to_dataframe() - - # Select data matching taxon and period parameters. - df_markers = df_markers.loc[ - ( - (df_markers["cohort_taxon"] == taxon) - & (df_markers["cohort_period"] == period) - ) - ] - - # Clear existing layers in the map. - if clear: - for layer in m.layers[1:]: - m.remove_layer(layer) - - # Add markers. - for x in df_markers.itertuples(): - marker = ipyleaflet.CircleMarker() - marker.location = (x.cohort_lat_mean, x.cohort_lon_mean) - marker.radius = 20 - marker.color = "black" - marker.weight = 1 - marker.fill_color = "red" - marker.fill_opacity = x.event_frequency - popup_html = f""" - {variant_label}
- Taxon: {x.cohort_taxon}
- Area: {x.cohort_area}
- Period: {x.cohort_period}
- Sample size: {x.cohort_size}
- Frequency: {x.event_frequency:.0%} - (95% CI: {x.event_frequency_ci_low:.0%} - {x.event_frequency_ci_upp:.0%}) - """ - marker.popup = ipyleaflet.Popup( - child=ipywidgets.HTML(popup_html), - auto_pan=False, - ) - m.add(marker) - - @check_types - @doc( - summary=""" - Create an interactive map with markers showing variant frequencies or - cohorts grouped by area (space), period (time) and taxon. - """, - parameters=dict( - title=""" - If True, attempt to use metadata from input dataset as a plot - title. Otherwise, use supplied value as a title. - """, - epilogue="Additional text to display below the map.", - ), - returns=""" - An interactive map with widgets for selecting which variant, taxon - and time period to display. - """, - ) - def plot_frequencies_interactive_map( - self, - ds: frq_params.ds_frequencies_advanced, - center: map_params.center = map_params.center_default, - zoom: map_params.zoom = map_params.zoom_default, - title: Union[bool, str] = True, - epilogue: Union[bool, str] = True, - ): - import ipyleaflet - import ipywidgets - - # Handle title. - if title is True: - title = ds.attrs.get("title", None) - - # Create a map. - freq_map = ipyleaflet.Map(center=center, zoom=zoom) - - # Set up interactive controls. - variants = ds["variant_label"].values - taxa = ds["cohort_taxon"].to_pandas().dropna().unique() # type: ignore - periods = ds["cohort_period"].to_pandas().dropna().unique() # type: ignore - controls = ipywidgets.interactive( - self.plot_frequencies_map_markers, - m=ipywidgets.fixed(freq_map), - ds=ipywidgets.fixed(ds), - variant=ipywidgets.Dropdown(options=variants, description="Variant: "), - taxon=ipywidgets.Dropdown(options=taxa, description="Taxon: "), - period=ipywidgets.Dropdown(options=periods, description="Period: "), - clear=ipywidgets.fixed(True), - ) - - # Lay out widgets. - components = [] - if title is not None: - components.append(ipywidgets.HTML(value=f"

{title}

")) - components.append(controls) - components.append(freq_map) - if epilogue is True: - epilogue = """ - Variant frequencies are shown as coloured markers. Opacity of color - denotes frequency. Click on a marker for more information. - """ - if epilogue: - components.append(ipywidgets.HTML(value=f"{epilogue}")) - - out = ipywidgets.VBox(components) - - return out - def snp_genotype_allele_counts( self, transcript: base_params.transcript, diff --git a/tests/anoph/conftest.py b/tests/anoph/conftest.py index 1564b029..5927debb 100644 --- a/tests/anoph/conftest.py +++ b/tests/anoph/conftest.py @@ -1013,6 +1013,14 @@ def contigs(self) -> Tuple[str, ...]: def random_contig(self): return choice(self.contigs) + def random_transcript_id(self): + df_transcripts = self.genome_features.query("type == 'mRNA'") + transcript_ids = [ + t.split(";")[0].split("=")[1] for t in df_transcripts.loc[:, "attributes"] + ] + transcript_id = choice(transcript_ids) + return transcript_id + def random_region_str(self, region_size=None): contig = self.random_contig() contig_size = self.contig_sizes[contig] diff --git a/tests/anoph/test_cnv_frq.py b/tests/anoph/test_cnv_frq.py index ef96a400..a96cc0bf 100644 --- a/tests/anoph/test_cnv_frq.py +++ b/tests/anoph/test_cnv_frq.py @@ -11,6 +11,13 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.cnv_frq import AnophelesCnvFrequencyAnalysis from malariagen_data.util import compare_series_like +from .test_frq import ( + test_plot_frequencies_heatmap, + test_plot_frequencies_time_series, + test_plot_frequencies_time_series_with_taxa, + test_plot_frequencies_time_series_with_areas, + test_plot_frequencies_interactive_map, +) @pytest.fixture @@ -109,6 +116,8 @@ def test_gene_cnv_frequencies_with_str_cohorts( # Run the function under test. df_cnv = api.gene_cnv_frequencies(**params) + test_plot_frequencies_heatmap(api, df_cnv) + # Figure out expected cohort labels. df_samples = api.sample_metadata(sample_sets=sample_sets) if "cohort_" + cohorts in df_samples: @@ -166,12 +175,14 @@ def test_gene_cnv_frequencies_with_min_cohort_size( return # Run the function under test. - df_snp = api.gene_cnv_frequencies(**params) + df_cnv = api.gene_cnv_frequencies(**params) + + test_plot_frequencies_heatmap(api, df_cnv) # Standard checks. check_gene_cnv_frequencies( api=api, - df=df_snp, + df=df_cnv, cohort_labels=cohort_labels, region=region, ) @@ -212,12 +223,14 @@ def test_gene_cnv_frequencies_with_str_cohorts_and_sample_query( ) # Run the function under test. - df_snp = api.gene_cnv_frequencies(**params) + df_cnv = api.gene_cnv_frequencies(**params) + + test_plot_frequencies_heatmap(api, df_cnv) # Standard checks. check_gene_cnv_frequencies( api=api, - df=df_snp, + df=df_cnv, cohort_labels=cohort_labels, region=region, ) @@ -268,12 +281,14 @@ def test_gene_cnv_frequencies_with_str_cohorts_and_sample_query_options( ) # Run the function under test. - df_snp = api.gene_cnv_frequencies(**params) + df_cnv = api.gene_cnv_frequencies(**params) + + test_plot_frequencies_heatmap(api, df_cnv) # Standard checks. check_gene_cnv_frequencies( api=api, - df=df_snp, + df=df_cnv, cohort_labels=cohort_labels, region=region, ) @@ -305,12 +320,14 @@ def test_gene_cnv_frequencies_with_dict_cohorts( ) # Run the function under test. - df_snp = api.gene_cnv_frequencies(**params) + df_cnv = api.gene_cnv_frequencies(**params) + + test_plot_frequencies_heatmap(api, df_cnv) # Standard checks. check_gene_cnv_frequencies( api=api, - df=df_snp, + df=df_cnv, cohort_labels=cohort_labels, region=region, ) @@ -350,6 +367,9 @@ def test_gene_cnv_frequencies_without_drop_invariant( df_cnv_a = api.gene_cnv_frequencies(drop_invariant=True, **params) df_cnv_b = api.gene_cnv_frequencies(drop_invariant=False, **params) + test_plot_frequencies_heatmap(api, df_cnv_a) + test_plot_frequencies_heatmap(api, df_cnv_b) + # Standard checks. check_gene_cnv_frequencies( api=api, @@ -418,6 +438,8 @@ def test_gene_cnv_frequencies_with_max_coverage_variance( # checks. df_cnv = api.gene_cnv_frequencies(**params) + test_plot_frequencies_heatmap(api, df_cnv) + # Figure out expected cohort labels. df_samples = api.sample_metadata(sample_sets=sample_sets) if "cohort_" + cohorts in df_samples: @@ -711,6 +733,10 @@ def check_gene_cnv_frequencies_advanced( # Check the result. assert isinstance(ds, xr.Dataset) + test_plot_frequencies_time_series(api, ds) + test_plot_frequencies_time_series_with_taxa(api, ds) + test_plot_frequencies_time_series_with_areas(api, ds) + test_plot_frequencies_interactive_map(api, ds) assert set(ds.dims) == {"cohorts", "variants"} # Check variant variables. diff --git a/tests/anoph/test_frq.py b/tests/anoph/test_frq.py new file mode 100644 index 00000000..dfba620c --- /dev/null +++ b/tests/anoph/test_frq.py @@ -0,0 +1,98 @@ +import pytest +import plotly.graph_objects as go # type: ignore + +import random + + +@pytest.mark.skip +def test_plot_frequencies_heatmap(api, frq_df): + fig = api.plot_frequencies_heatmap(frq_df, show=False, max_len=None) + assert isinstance(fig, go.Figure) + + # Test max_len behaviour. + with pytest.raises(ValueError): + api.plot_frequencies_heatmap(frq_df, show=False, max_len=len(frq_df) - 1) + + # Test index parameter - if None, should use dataframe index. + fig = api.plot_frequencies_heatmap(frq_df, show=False, index=None, max_len=None) + + if "contig" in list(frq_df.columns): + # Not unique. + with pytest.raises(ValueError): + api.plot_frequencies_heatmap( + frq_df, show=False, index="contig", max_len=None + ) + + +@pytest.mark.skip +def test_plot_frequencies_time_series(api, ds): + # Trim things down a bit for speed. + ds = ds.isel(variants=slice(0, 100)) + + # Plot. + fig = api.plot_frequencies_time_series(ds, show=False) + + # Test. + assert isinstance(fig, go.Figure) + + +@pytest.mark.skip +def test_plot_frequencies_time_series_with_taxa(api, ds): + # Trim things down a bit for speed. + ds = ds.isel(variants=slice(0, 100)) + + taxa = list(ds.cohort_taxon.to_dataframe()["cohort_taxon"].unique()) + taxon = random.choice(taxa) + + # Plot with taxon. + fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxon) + + # Test taxon plot. + assert isinstance(fig, go.Figure) + + # Plot with taxa. + fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxa) + + # Test taxa plot. + assert isinstance(fig, go.Figure) + + +@pytest.mark.skip +def test_plot_frequencies_time_series_with_areas(api, ds): + # Trim things down a bit for speed. + ds = ds.isel(variants=slice(0, 100)) + + # Extract cohorts into a DataFrame. + cohort_vars = [v for v in ds if str(v).startswith("cohort_")] + df_cohorts = ds[cohort_vars].to_dataframe() + + # Pick a random area and areas from valid areas. + cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist() + area = random.choice(cohorts_areas) + areas = random.sample(cohorts_areas, random.randint(1, len(cohorts_areas))) + + # Plot with area. + fig = api.plot_frequencies_time_series(ds, show=False, areas=area) + + # Test areas plot. + assert isinstance(fig, go.Figure) + + # Plot with areas. + fig = api.plot_frequencies_time_series(ds, show=False, areas=areas) + + # Test area plot. + assert isinstance(fig, go.Figure) + + +@pytest.mark.skip +def test_plot_frequencies_interactive_map(api, ds): + import ipywidgets # type: ignore + + # Trim things down a bit for speed. + ds = ds.isel(variants=slice(0, 100)) + + # Plot. + fig = api.plot_frequencies_interactive_map(ds) + + # Test. + assert isinstance(fig, ipywidgets.Widget) diff --git a/tests/anoph/test_hap_frq.py b/tests/anoph/test_hap_frq.py index 7dc54346..a2077d6c 100644 --- a/tests/anoph/test_hap_frq.py +++ b/tests/anoph/test_hap_frq.py @@ -7,7 +7,15 @@ from pytest_cases import parametrize_with_cases from malariagen_data import ag3 as _ag3 +from malariagen_data import af1 as _af1 from malariagen_data.anoph.hap_frq import AnophelesHapFrequencyAnalysis +from .test_frq import ( + test_plot_frequencies_heatmap, + test_plot_frequencies_time_series, + test_plot_frequencies_time_series_with_taxa, + test_plot_frequencies_time_series_with_areas, + test_plot_frequencies_interactive_map, +) @pytest.fixture @@ -36,6 +44,23 @@ def ag3_sim_api(ag3_sim_fixture): ) +@pytest.fixture +def af1_sim_api(af1_sim_fixture): + return AnophelesHapFrequencyAnalysis( + url=af1_sim_fixture.url, + config_path=_af1.CONFIG_PATH, + major_version_number=_af1.MAJOR_VERSION_NUMBER, + major_version_path=_af1.MAJOR_VERSION_PATH, + pre=False, + gff_gene_type="protein_coding_gene", + gff_gene_name_attribute="Note", + gff_default_attributes=("ID", "Parent", "Note", "description"), + results_cache=af1_sim_fixture.results_cache_path.as_posix(), + taxon_colors=_af1.TAXON_COLORS, + default_phasing_analysis="funestus", + ) + + # N.B., here we use pytest_cases to parametrize tests. Each # function whose name begins with "case_" defines a set of # inputs to the test functions. See the documentation for @@ -52,6 +77,10 @@ def case_ag3_sim(ag3_sim_fixture, ag3_sim_api): return ag3_sim_fixture, ag3_sim_api +def case_af1_sim(af1_sim_fixture, af1_sim_api): + return af1_sim_fixture, af1_sim_api + + def check_frequency(x): loc_nan = np.isnan(x) assert np.all(x[~loc_nan] >= 0) @@ -70,9 +99,8 @@ def check_hap_frequencies(*, api, df, sample_sets, cohorts, min_cohort_size): cohort_counts = df_samples[cohort_column].value_counts() cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list() - universal_fields = ["label"] frq_fields = ["frq_" + s for s in cohort_labels] + ["max_af"] - expected_fields = universal_fields + frq_fields + expected_fields = frq_fields assert sorted(df.columns.tolist()) == sorted(expected_fields) @@ -82,6 +110,10 @@ def check_hap_frequencies_advanced( ds, ): assert isinstance(ds, xr.Dataset) + test_plot_frequencies_time_series(api, ds) + test_plot_frequencies_time_series_with_taxa(api, ds) + test_plot_frequencies_time_series_with_areas(api, ds) + test_plot_frequencies_interactive_map(api, ds) assert set(ds.dims) == {"cohorts", "variants"} expected_cohort_vars = [ @@ -154,6 +186,8 @@ def test_hap_frequencies_with_str_cohorts( # Run the function under test. df_hap = api.haplotypes_frequencies(**params) + test_plot_frequencies_heatmap(api, df_hap) + # Standard checks. check_hap_frequencies( api=api, diff --git a/tests/anoph/test_snp_frq.py b/tests/anoph/test_snp_frq.py index e18d9a19..412fdc21 100644 --- a/tests/anoph/test_snp_frq.py +++ b/tests/anoph/test_snp_frq.py @@ -7,12 +7,18 @@ from pytest_cases import parametrize_with_cases import xarray as xr from numpy.testing import assert_allclose, assert_array_equal -import plotly.graph_objects as go # type: ignore from malariagen_data import af1 as _af1 from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.snp_frq import AnophelesSnpFrequencyAnalysis from malariagen_data.util import compare_series_like +from .test_frq import ( + test_plot_frequencies_heatmap, + test_plot_frequencies_time_series, + test_plot_frequencies_time_series_with_taxa, + test_plot_frequencies_time_series_with_areas, + test_plot_frequencies_interactive_map, +) @pytest.fixture @@ -315,6 +321,8 @@ def test_allele_frequencies_with_str_cohorts( # Run the function under test. df_snp = api.snp_allele_frequencies(**params) + test_plot_frequencies_heatmap(api, df_snp) + # Figure out expected cohort labels. df_samples = api.sample_metadata(sample_sets=sample_sets) if "cohort_" + cohorts in df_samples: @@ -335,6 +343,8 @@ def test_allele_frequencies_with_str_cohorts( # Run the function under test. df_aa = api.aa_allele_frequencies(**params) + test_plot_frequencies_heatmap(api, df_aa) + # Standard checks. check_aa_allele_frequencies( df=df_aa, @@ -387,6 +397,8 @@ def test_allele_frequencies_with_min_cohort_size( # Run the function under test. df_snp = api.snp_allele_frequencies(**params) + test_plot_frequencies_heatmap(api, df_snp) + # Standard checks. check_snp_allele_frequencies( api=api, @@ -398,6 +410,8 @@ def test_allele_frequencies_with_min_cohort_size( # Run the function under test. df_aa = api.aa_allele_frequencies(**params) + test_plot_frequencies_heatmap(api, df_aa) + # Standard checks. check_aa_allele_frequencies( df=df_aa, @@ -444,6 +458,8 @@ def test_allele_frequencies_with_str_cohorts_and_sample_query( # Run the function under test. df_snp = api.snp_allele_frequencies(**params) + test_plot_frequencies_heatmap(api, df_snp) + # Standard checks. check_snp_allele_frequencies( api=api, @@ -455,6 +471,8 @@ def test_allele_frequencies_with_str_cohorts_and_sample_query( # Run the function under test. df_aa = api.aa_allele_frequencies(**params) + test_plot_frequencies_heatmap(api, df_aa) + # Standard checks. check_aa_allele_frequencies( df=df_aa, @@ -511,6 +529,8 @@ def test_allele_frequencies_with_str_cohorts_and_sample_query_options( # Run the function under test. df_snp = api.snp_allele_frequencies(**params) + test_plot_frequencies_heatmap(api, df_snp) + # Standard checks. check_snp_allele_frequencies( api=api, @@ -522,6 +542,8 @@ def test_allele_frequencies_with_str_cohorts_and_sample_query_options( # Run the function under test. df_aa = api.aa_allele_frequencies(**params) + test_plot_frequencies_heatmap(api, df_aa) + # Standard checks. check_aa_allele_frequencies( df=df_aa, @@ -559,6 +581,8 @@ def test_allele_frequencies_with_dict_cohorts( # Run the function under test. df_snp = api.snp_allele_frequencies(**params) + test_plot_frequencies_heatmap(api, df_snp) + # Standard checks. check_snp_allele_frequencies( api=api, @@ -570,6 +594,8 @@ def test_allele_frequencies_with_dict_cohorts( # Run the function under test. df_aa = api.aa_allele_frequencies(**params) + test_plot_frequencies_heatmap(api, df_aa) + # Standard checks. check_aa_allele_frequencies( df=df_aa, @@ -613,6 +639,9 @@ def test_allele_frequencies_without_drop_invariant( df_snp_a = api.snp_allele_frequencies(drop_invariant=True, **params) df_snp_b = api.snp_allele_frequencies(drop_invariant=False, **params) + test_plot_frequencies_heatmap(api, df_snp_a) + test_plot_frequencies_heatmap(api, df_snp_b) + # Standard checks. check_snp_allele_frequencies( api=api, @@ -667,6 +696,9 @@ def test_allele_frequencies_without_effects( df_snp_a = api.snp_allele_frequencies(effects=True, **params) df_snp_b = api.snp_allele_frequencies(effects=False, **params) + test_plot_frequencies_heatmap(api, df_snp_a) + test_plot_frequencies_heatmap(api, df_snp_b) + # Standard checks. check_snp_allele_frequencies( api=api, @@ -766,6 +798,8 @@ def test_allele_frequencies_with_region( # Run the function under test. df_snp = api.snp_allele_frequencies(**params) + test_plot_frequencies_heatmap(api, df_snp) + # Basic checks. assert isinstance(df_snp, pd.DataFrame) assert len(df_snp) > 0 @@ -820,6 +854,9 @@ def test_allele_frequencies_with_dup_samples( sample_sets=[sample_set, sample_set], **params ) + test_plot_frequencies_heatmap(api, df_snp_a) + test_plot_frequencies_heatmap(api, df_snp_b) + # Expect automatically deduplicate sample sets. assert_frame_equal(df_snp_b, df_snp_a) @@ -876,6 +913,10 @@ def check_snp_allele_frequencies_advanced( # Check the result. assert isinstance(ds, xr.Dataset) + test_plot_frequencies_time_series(api, ds) + test_plot_frequencies_time_series_with_taxa(api, ds) + test_plot_frequencies_time_series_with_areas(api, ds) + test_plot_frequencies_interactive_map(api, ds) assert set(ds.dims) == {"cohorts", "variants"} # Check variant variables. @@ -1063,6 +1104,10 @@ def check_aa_allele_frequencies_advanced( # Check the result. assert isinstance(ds, xr.Dataset) + test_plot_frequencies_time_series(api, ds) + test_plot_frequencies_time_series_with_taxa(api, ds) + test_plot_frequencies_time_series_with_areas(api, ds) + test_plot_frequencies_interactive_map(api, ds) assert set(ds.dims) == {"cohorts", "variants"} expected_variant_vars = ( @@ -1429,248 +1474,3 @@ def test_allele_frequencies_advanced_with_dup_samples( api=api, sample_sets=sample_sets, ) - - -@parametrize_with_cases("fixture,api", cases=".") -def test_plot_frequencies_heatmap( - fixture, - api: AnophelesSnpFrequencyAnalysis, -): - # Pick test parameters at random. - all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) - transcript = random_transcript(api=api).name - cohorts = random.choice( - ["admin1_year", "admin1_month", "admin2_year", "admin2_month"] - ) - - # Set up call params. - params = dict( - transcript=transcript, - cohorts=cohorts, - min_cohort_size=min_cohort_size, - site_mask=site_mask, - sample_sets=sample_sets, - ) - - # Test SNP allele frequencies. - df_snp = api.snp_allele_frequencies(**params) - fig = api.plot_frequencies_heatmap(df_snp, show=False, max_len=None) - assert isinstance(fig, go.Figure) - - # Test amino acid change allele frequencies. - df_aa = api.aa_allele_frequencies(**params) - fig = api.plot_frequencies_heatmap(df_aa, show=False, max_len=None) - assert isinstance(fig, go.Figure) - - # Test max_len behaviour. - with pytest.raises(ValueError): - api.plot_frequencies_heatmap(df_snp, show=False, max_len=len(df_snp) - 1) - - # Test index parameter - if None, should use dataframe index. - fig = api.plot_frequencies_heatmap(df_snp, show=False, index=None, max_len=None) - # Not unique. - with pytest.raises(ValueError): - api.plot_frequencies_heatmap(df_snp, show=False, index="contig", max_len=None) - - -@parametrize_with_cases("fixture,api", cases=".") -def test_plot_frequencies_time_series( - fixture, - api: AnophelesSnpFrequencyAnalysis, -): - # Pick test parameters at random. - all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) - transcript = random_transcript(api=api).name - area_by = random.choice(["country", "admin1_iso", "admin2_name"]) - period_by = random.choice(["year", "quarter", "month"]) - - # Compute SNP frequencies. - ds = api.snp_allele_frequencies_advanced( - transcript=transcript, - area_by=area_by, - period_by=period_by, - sample_sets=sample_sets, - min_cohort_size=min_cohort_size, - site_mask=site_mask, - ) - - # Trim things down a bit for speed. - ds = ds.isel(variants=slice(0, 100)) - - # Plot. - fig = api.plot_frequencies_time_series(ds, show=False) - - # Test. - assert isinstance(fig, go.Figure) - - # Compute amino acid change frequencies. - ds = api.aa_allele_frequencies_advanced( - transcript=transcript, - area_by=area_by, - period_by=period_by, - sample_sets=sample_sets, - min_cohort_size=min_cohort_size, - ) - - # Trim things down a bit for speed. - ds = ds.isel(variants=slice(0, 100)) - - # Plot. - fig = api.plot_frequencies_time_series(ds, show=False) - - # Test. - assert isinstance(fig, go.Figure) - - -@parametrize_with_cases("fixture,api", cases=".") -def test_plot_frequencies_time_series_with_taxa( - fixture, - api: AnophelesSnpFrequencyAnalysis, -): - # Pick test parameters at random. - all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice(api.site_mask_ids + (None,)) - transcript = random_transcript(api=api).name - area_by = random.choice(["country", "admin1_iso", "admin2_name"]) - period_by = random.choice(["year", "quarter", "month"]) - - # Pick a random taxon and taxa from valid taxa. - sample_sets_taxa = ( - api.sample_metadata(sample_sets=sample_sets)["taxon"].dropna().unique().tolist() - ) - taxon = random.choice(sample_sets_taxa) - taxa = random.sample(sample_sets_taxa, random.randint(1, len(sample_sets_taxa))) - - # Compute SNP frequencies. - ds = api.snp_allele_frequencies_advanced( - transcript=transcript, - area_by=area_by, - period_by=period_by, - sample_sets=sample_sets, - min_cohort_size=1, # Don't exclude any samples. - site_mask=site_mask, - ) - - # Trim things down a bit for speed. - ds = ds.isel(variants=slice(0, 100)) - - # Plot with taxon. - fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxon) - - # Test taxon plot. - assert isinstance(fig, go.Figure) - - # Plot with taxa. - fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxa) - - # Test taxa plot. - assert isinstance(fig, go.Figure) - - -@parametrize_with_cases("fixture,api", cases=".") -def test_plot_frequencies_time_series_with_areas( - fixture, - api: AnophelesSnpFrequencyAnalysis, -): - # Pick test parameters at random. - all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice(api.site_mask_ids + (None,)) - transcript = random_transcript(api=api).name - area_by = random.choice(["country", "admin1_iso", "admin2_name"]) - period_by = random.choice(["year", "quarter", "month"]) - - # Compute SNP frequencies. - ds = api.snp_allele_frequencies_advanced( - transcript=transcript, - area_by=area_by, - period_by=period_by, - sample_sets=sample_sets, - min_cohort_size=1, # Don't exclude any samples. - site_mask=site_mask, - ) - - # Trim things down a bit for speed. - ds = ds.isel(variants=slice(0, 100)) - - # Extract cohorts into a DataFrame. - cohort_vars = [v for v in ds if str(v).startswith("cohort_")] - df_cohorts = ds[cohort_vars].to_dataframe() - - # Pick a random area and areas from valid areas. - cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist() - area = random.choice(cohorts_areas) - areas = random.sample(cohorts_areas, random.randint(1, len(cohorts_areas))) - - # Plot with area. - fig = api.plot_frequencies_time_series(ds, show=False, areas=area) - - # Test areas plot. - assert isinstance(fig, go.Figure) - - # Plot with areas. - fig = api.plot_frequencies_time_series(ds, show=False, areas=areas) - - # Test area plot. - assert isinstance(fig, go.Figure) - - -@parametrize_with_cases("fixture,api", cases=".") -def test_plot_frequencies_interactive_map( - fixture, - api: AnophelesSnpFrequencyAnalysis, -): - import ipywidgets # type: ignore - - # Pick test parameters at random. - all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice(api.site_mask_ids + (None,)) - min_cohort_size = random.randint(0, 2) - transcript = random_transcript(api=api).name - area_by = random.choice(["country", "admin1_iso", "admin2_name"]) - period_by = random.choice(["year", "quarter", "month"]) - - # Compute SNP frequencies. - ds = api.snp_allele_frequencies_advanced( - transcript=transcript, - area_by=area_by, - period_by=period_by, - sample_sets=sample_sets, - min_cohort_size=min_cohort_size, - site_mask=site_mask, - ) - - # Trim things down a bit for speed. - ds = ds.isel(variants=slice(0, 100)) - - # Plot. - fig = api.plot_frequencies_interactive_map(ds) - - # Test. - assert isinstance(fig, ipywidgets.Widget) - - # Compute amino acid change frequencies. - ds = api.aa_allele_frequencies_advanced( - transcript=transcript, - area_by=area_by, - period_by=period_by, - sample_sets=sample_sets, - min_cohort_size=min_cohort_size, - ) - - # Trim things down a bit for speed. - ds = ds.isel(variants=slice(0, 100)) - - # Plot. - fig = api.plot_frequencies_interactive_map(ds) - - # Test. - assert isinstance(fig, ipywidgets.Widget)