diff --git a/malariagen_data/anoph/cnv_frq.py b/malariagen_data/anoph/cnv_frq.py
index 29e83361..fda01aa6 100644
--- a/malariagen_data/anoph/cnv_frq.py
+++ b/malariagen_data/anoph/cnv_frq.py
@@ -10,6 +10,11 @@
from numpydoc_decorator import doc # type: ignore
from . import base_params, cnv_params, frq_params
+from .frq_funcs import (
+ prep_samples_for_cohort_grouping,
+ build_cohorts_from_sample_grouping,
+ add_frequency_ci,
+)
from ..util import (
check_types,
pandas_apply,
@@ -17,17 +22,13 @@
parse_multi_region,
region_str,
simple_xarray_concat,
- prep_samples_for_cohort_grouping,
- build_cohorts_from_sample_grouping,
- add_frequency_ci,
)
from .cnv_data import AnophelesCnvData
+from .frq_funcs import AnophelesFrequency
from .sample_metadata import locate_cohorts
-class AnophelesCnvFrequencyAnalysis(
- AnophelesCnvData,
-):
+class AnophelesCnvFrequencyAnalysis(AnophelesCnvData, AnophelesFrequency):
def __init__(
self,
**kwargs,
diff --git a/malariagen_data/anoph/frq_funcs.py b/malariagen_data/anoph/frq_funcs.py
new file mode 100644
index 00000000..425911a6
--- /dev/null
+++ b/malariagen_data/anoph/frq_funcs.py
@@ -0,0 +1,584 @@
+import numpy as np
+import pandas as pd
+import xarray as xr
+import plotly.express as px
+from textwrap import dedent
+from typing import Optional, Union, List
+from numpydoc_decorator import doc # type: ignore
+from . import (
+ plotly_params,
+ frq_params,
+ map_params,
+)
+from ..util import check_types
+from .base import AnophelesBase
+
+
+def prep_samples_for_cohort_grouping(*, df_samples, area_by, period_by):
+    """Prepare a copy of the sample metadata for grouping into cohorts.
+
+    Parameters
+    ----------
+    df_samples : pandas.DataFrame
+        Sample metadata. Must have a "taxon" column, the column named by
+        `area_by`, and the columns read by the selected period maker
+        ("year", plus "month" for quarterly or monthly periods).
+    area_by : str
+        Name of the column to copy into the "area" column.
+    period_by : {"year", "quarter", "month"}
+        Granularity of the "period" column.
+
+    Returns
+    -------
+    pandas.DataFrame
+        A copy of `df_samples` in which taxon values starting with
+        "intermediate" or "unassigned" are set to None, and with "period"
+        and "area" columns added.
+
+    Raises
+    ------
+    ValueError
+        If `period_by` is not one of the supported values.
+    """
+    # Work on a copy so the caller's dataframe is left untouched.
+    df_samples = df_samples.copy()
+
+    # Only clean taxon calls should define cohorts, so blank out anything
+    # flagged as intermediate or unassigned.
+    for unclean_prefix in ("intermediate", "unassigned"):
+        is_unclean = df_samples["taxon"].str.startswith(unclean_prefix).fillna(False)
+        df_samples.loc[is_unclean, "taxon"] = None
+
+    # Pick the function that turns a sample row into a period.
+    period_makers = (
+        ("year", _make_sample_period_year),
+        ("quarter", _make_sample_period_quarter),
+        ("month", _make_sample_period_month),
+    )
+    make_period = next(
+        (maker for name, maker in period_makers if period_by == name), None
+    )
+    if make_period is None:  # pragma: no cover
+        raise ValueError(
+            f"Value for period_by parameter must be one of 'year', 'quarter', 'month'; found {period_by!r}."
+        )
+    df_samples["period"] = df_samples.apply(make_period, axis="columns")
+
+    # Expose the chosen area column under a consistent name.
+    df_samples["area"] = df_samples[area_by]
+
+    return df_samples
+
+
+def build_cohorts_from_sample_grouping(*, group_samples_by_cohort, min_cohort_size):
+    """Summarise grouped samples into a dataframe with one row per cohort.
+
+    Parameters
+    ----------
+    group_samples_by_cohort : pandas groupby object
+        Samples grouped by cohort. The grouped frame must provide
+        "sample_id", "latitude" and "longitude" columns, and the grouping
+        keys must include "taxon", "area" and "period".
+    min_cohort_size : int
+        Cohorts with fewer samples than this are dropped.
+
+    Returns
+    -------
+    pandas.DataFrame
+        One row per retained cohort, with the grouping keys, the cohort
+        size, latitude/longitude summaries, the period start and end, and
+        a text label.
+
+    Raises
+    ------
+    ValueError
+        If no cohort reaches `min_cohort_size`.
+    """
+    # Aggregate per cohort, then turn the grouping keys back into columns.
+    aggregations = {
+        "size": ("sample_id", len),
+        "lat_mean": ("latitude", "mean"),
+        "lat_max": ("latitude", "max"),
+        "lat_min": ("latitude", "min"),
+        "lon_mean": ("longitude", "mean"),
+        "lon_max": ("longitude", "max"),
+        "lon_min": ("longitude", "min"),
+    }
+    df_cohorts = group_samples_by_cohort.agg(**aggregations).reset_index()
+
+    # Helper columns giving the bounds of each cohort's period.
+    periods = df_cohorts["period"]
+    df_cohorts["period_start"] = periods.apply(lambda p: p.start_time)
+    df_cohorts["period_end"] = periods.apply(lambda p: p.end_time)
+
+    # The label resembles the cohort metadata labels, but is only an
+    # approximation of them.
+    def _cohort_label(row):
+        return f"{row.area}_{row.taxon[:4]}_{row.period}"
+
+    df_cohorts["label"] = df_cohorts.apply(_cohort_label, axis="columns")
+
+    # Drop cohorts that are too small.
+    df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(drop=True)
+
+    # Fail early if nothing is left.
+    if not len(df_cohorts):
+        raise ValueError(
+            "No cohorts available for the given sample selection parameters and minimum cohort size."
+        )
+
+    return df_cohorts
+
+
+def add_frequency_ci(*, ds, ci_method):
+    """Add frequency confidence-interval variables to `ds` in place.
+
+    Parameters
+    ----------
+    ds : dataset
+        Must provide "event_count" and "event_nobs" variables. When a
+        method is given, "event_frequency_ci_low" and
+        "event_frequency_ci_upp" are added with dimensions
+        ("variants", "cohorts").
+    ci_method : str or None
+        Method name passed to statsmodels' `proportion_confint`. If None,
+        `ds` is left unchanged.
+
+    Returns
+    -------
+    None
+    """
+    if ci_method is None:
+        # Nothing to compute, so avoid importing statsmodels at all.
+        return
+
+    from statsmodels.stats.proportion import proportion_confint  # type: ignore
+
+    count = ds["event_count"].values
+    nobs = ds["event_nobs"].values
+    # Cohorts with zero observations would otherwise emit divide/invalid
+    # warnings; silence them for this computation only.
+    with np.errstate(divide="ignore", invalid="ignore"):
+        frq_ci_low, frq_ci_upp = proportion_confint(
+            count=count, nobs=nobs, method=ci_method
+        )
+    ds["event_frequency_ci_low"] = ("variants", "cohorts"), frq_ci_low
+    ds["event_frequency_ci_upp"] = ("variants", "cohorts"), frq_ci_upp
+
+
+def _make_sample_period_month(row):
+    """Return the monthly period for a sample row, or NaT if undated.
+
+    A row counts as dated only when both `row.year` and `row.month` are
+    greater than zero.
+    """
+    sample_year, sample_month = row.year, row.month
+    if not (sample_year > 0) or not (sample_month > 0):
+        return pd.NaT
+    return pd.Period(freq="M", year=sample_year, month=sample_month)
+
+
+def _make_sample_period_quarter(row):
+    """Return the quarterly period for a sample row, or NaT if undated.
+
+    A row counts as dated only when both `row.year` and `row.month` are
+    greater than zero.
+    """
+    sample_year, sample_month = row.year, row.month
+    if not (sample_year > 0) or not (sample_month > 0):
+        return pd.NaT
+    return pd.Period(freq="Q", year=sample_year, month=sample_month)
+
+
+def _make_sample_period_year(row):
+    """Return the yearly period for a sample row, or NaT if `row.year` is not > 0."""
+    sample_year = row.year
+    if not (sample_year > 0):
+        return pd.NaT
+    return pd.Period(freq="Y", year=sample_year)
+
+
+class AnophelesFrequency(AnophelesBase):
+    def __init__(self, **kwargs):
+        """Initialise cooperatively.
+
+        This class takes part in cooperative multiple inheritance, so every
+        keyword argument it receives is forwarded unchanged to the next
+        constructor in the method resolution order.
+        """
+        super().__init__(**kwargs)
+
+ @check_types
+ @doc(
+ summary="""
+ Plot a heatmap from a pandas DataFrame of frequencies, e.g., output
+ from `snp_allele_frequencies()` or `gene_cnv_frequencies()`.
+ """,
+ parameters=dict(
+ df="""
+ A DataFrame of frequencies, e.g., output from
+ `snp_allele_frequencies()` or `gene_cnv_frequencies()`.
+ """,
+ index="""
+ One or more column headers that are present in the input dataframe.
+ This becomes the heatmap y-axis row labels. The column/s must
+ produce a unique index.
+ """,
+ max_len="""
+ Displaying large styled dataframes may cause ipython notebooks to
+ crash. If the input dataframe is larger than this value, an error
+ will be raised.
+ """,
+ col_width="""
+ Plot width per column in pixels (px).
+ """,
+ row_height="""
+ Plot height per row in pixels (px).
+ """,
+ kwargs="""
+ Passed through to `px.imshow()`.
+ """,
+ ),
+ notes="""
+ It's recommended to filter the input DataFrame to just rows of interest,
+ i.e., fewer rows than `max_len`.
+ """,
+ )
+    def plot_frequencies_heatmap(
+        self,
+        df: pd.DataFrame,
+        index: Optional[Union[str, List[str]]] = "label",
+        max_len: Optional[int] = 100,
+        col_width: int = 40,
+        row_height: int = 20,
+        x_label: plotly_params.x_label = "Cohorts",
+        y_label: plotly_params.y_label = "Variants",
+        colorbar: plotly_params.colorbar = True,
+        width: plotly_params.fig_width = None,
+        height: plotly_params.fig_height = None,
+        text_auto: plotly_params.text_auto = ".0%",
+        aspect: plotly_params.aspect = "auto",
+        color_continuous_scale: plotly_params.color_continuous_scale = "Reds",
+        title: plotly_params.title = True,
+        show: plotly_params.show = True,
+        renderer: plotly_params.renderer = None,
+        **kwargs,
+    ) -> plotly_params.figure:
+        # Refuse very long inputs unless the caller opted out via max_len=None.
+        if max_len and len(df) > max_len:
+            raise ValueError(
+                dedent(
+                    f"""
+                    Input DataFrame is longer than max_len parameter value {max_len}, which means
+                    that the plot is likely to be very large. If you really want to go ahead,
+                    please rerun the function with max_len=None.
+                    """
+                )
+            )
+
+        # title=True means "use the title stored on the dataframe, if any".
+        if title is True:
+            title = df.attrs.get("title", None)
+
+        # Build the row labels. With index=None the dataframe's own index
+        # levels are used.
+        if index is None:
+            index = list(df.index.names)
+        df = df.reset_index().copy()
+        if isinstance(index, list):
+            # Several columns are joined into one comma-separated label.
+            index_col = (
+                df[index]
+                .astype(str)
+                .apply(
+                    lambda row: ", ".join([o for o in row if o is not None]),
+                    axis="columns",
+                )
+            )
+        else:
+            assert isinstance(index, str)
+            index_col = df[index].astype(str)
+
+        # Row labels must be unique.
+        if not index_col.is_unique:
+            raise ValueError(f"{index} does not produce a unique index")
+
+        # Keep only the frequency columns, labelled by row.
+        frq_prefix = "frq_"
+        frq_cols = [col for col in df.columns if col.startswith(frq_prefix)]
+        heatmap_df = df[frq_cols].copy()
+        heatmap_df.set_index(index_col, inplace=True)
+
+        # Remove exactly the "frq_" prefix from each column name. A
+        # character-set strip such as str.lstrip("frq_") would also eat
+        # leading "f", "r", "q" and "_" characters of the cohort label itself.
+        heatmap_df.columns = [col[len(frq_prefix) :] for col in frq_cols]
+
+        # Derive the figure size from the table shape unless given explicitly.
+        if width is None:
+            width = 400 + col_width * len(heatmap_df.columns)
+            if colorbar:
+                width += 40
+        if height is None:
+            height = 200 + row_height * len(heatmap_df)
+            if title is not None:
+                height += 40
+
+        # Draw the heatmap; frequencies are shown on a fixed 0-1 scale.
+        fig = px.imshow(
+            img=heatmap_df,
+            zmin=0,
+            zmax=1,
+            width=width,
+            height=height,
+            text_auto=text_auto,
+            aspect=aspect,
+            color_continuous_scale=color_continuous_scale,
+            title=title,
+            **kwargs,
+        )
+
+        fig.update_xaxes(side="bottom", tickangle=30)
+        if x_label is not None:
+            fig.update_xaxes(title=x_label)
+        if y_label is not None:
+            fig.update_yaxes(title=y_label)
+        fig.update_layout(
+            coloraxis_colorbar=dict(
+                title="Frequency",
+                tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1.0],
+                ticktext=["0%", "20%", "40%", "60%", "80%", "100%"],
+            )
+        )
+        if not colorbar:
+            fig.update(layout_coloraxis_showscale=False)
+
+        if show:  # pragma: no cover
+            fig.show(renderer=renderer)
+            return None
+        else:
+            return fig
+
+ @check_types
+ @doc(
+ summary="Create a time series plot of variant frequencies using plotly.",
+ parameters=dict(
+ ds="""
+ A dataset of variant frequencies, such as returned by
+ `snp_allele_frequencies_advanced()`,
+ `aa_allele_frequencies_advanced()` or
+ `gene_cnv_frequencies_advanced()`.
+ """,
+ kwargs="Passed through to `px.line()`.",
+ ),
+ returns="""
+ A plotly figure containing line graphs. The resulting figure will
+ have one panel per cohort, grouped into columns by taxon, and
+ grouped into rows by area. Markers and lines show frequencies of
+ variants.
+ """,
+ )
+    def plot_frequencies_time_series(
+        self,
+        ds: xr.Dataset,
+        height: plotly_params.fig_height = None,
+        width: plotly_params.fig_width = None,
+        title: plotly_params.title = True,
+        legend_sizing: plotly_params.legend_sizing = "constant",
+        show: plotly_params.show = True,
+        renderer: plotly_params.renderer = None,
+        taxa: frq_params.taxa = None,
+        areas: frq_params.areas = None,
+        **kwargs,
+    ) -> plotly_params.figure:
+        # title=True means "use the title stored on the dataset, if any".
+        if title is True:
+            title = ds.attrs.get("title", None)
+
+        # Extract cohorts into a dataframe, dropping the "cohort_" prefix
+        # from the column names.
+        cohort_vars = [v for v in ds if str(v).startswith("cohort_")]
+        df_cohorts = ds[cohort_vars].to_dataframe()
+        df_cohorts.columns = [c.split("cohort_")[1] for c in df_cohorts.columns]  # type: ignore
+
+        # Build a boolean selection over the cohorts instead of filtering the
+        # dataframe in place. Each selected cohort's original position is
+        # needed below to pull the matching slice out of the dataset.
+        # Assumes the dataframe rows follow the order of the "cohorts"
+        # dimension in `ds`.
+        selected = np.ones(len(df_cohorts), dtype=bool)
+
+        # If specified, restrict the selection by taxa.
+        if isinstance(taxa, str):
+            selected &= (df_cohorts["taxon"] == taxa).to_numpy()
+        elif isinstance(taxa, (list, tuple)):
+            selected &= df_cohorts["taxon"].isin(taxa).to_numpy()
+
+        # If specified, restrict the selection by areas.
+        if isinstance(areas, str):
+            selected &= (df_cohorts["area"] == areas).to_numpy()
+        elif isinstance(areas, (list, tuple)):
+            selected &= df_cohorts["area"].isin(areas).to_numpy()
+
+        cohort_positions = np.flatnonzero(selected)
+        df_cohorts = df_cohorts[selected]
+
+        # Extract variant labels.
+        variant_labels = ds["variant_label"].values
+
+        # Build a long-form dataframe from the dataset, one chunk per cohort.
+        dfs = []
+        for cohort_index, cohort in zip(cohort_positions, df_cohorts.itertuples()):
+            # Index the dataset by the cohort's original position, not by
+            # its rank within the filtered selection.
+            ds_cohort = ds.isel(cohorts=int(cohort_index))
+            df = pd.DataFrame(
+                {
+                    "taxon": cohort.taxon,
+                    "area": cohort.area,
+                    "date": cohort.period_start,
+                    "period": str(
+                        cohort.period
+                    ),  # use string representation for hover label
+                    "sample_size": cohort.size,
+                    "variant": variant_labels,
+                    "count": ds_cohort["event_count"].values,
+                    "nobs": ds_cohort["event_nobs"].values,
+                    "frequency": ds_cohort["event_frequency"].values,
+                    "frequency_ci_low": ds_cohort["event_frequency_ci_low"].values,
+                    "frequency_ci_upp": ds_cohort["event_frequency_ci_upp"].values,
+                }
+            )
+            dfs.append(df)
+
+        # Give a clear error rather than letting pd.concat fail on an empty list.
+        if not dfs:
+            raise ValueError(
+                "No cohorts available for the given taxa and areas selection."
+            )
+        df_events = pd.concat(dfs, axis=0).reset_index(drop=True)
+
+        # Remove events with no observations.
+        df_events = df_events.query("nobs > 0").copy()
+
+        # Plotly wants error bars as distances from the point estimate.
+        frq = df_events["frequency"]
+        frq_ci_low = df_events["frequency_ci_low"]
+        frq_ci_upp = df_events["frequency_ci_upp"]
+        df_events["frequency_error"] = frq_ci_upp - frq
+        df_events["frequency_error_minus"] = frq - frq_ci_low
+
+        # Make a plot: columns by taxon, rows by area, one line per variant.
+        fig = px.line(
+            df_events,
+            facet_col="taxon",
+            facet_row="area",
+            x="date",
+            y="frequency",
+            error_y="frequency_error",
+            error_y_minus="frequency_error_minus",
+            color="variant",
+            markers=True,
+            hover_name="variant",
+            hover_data={
+                "frequency": ":.0%",
+                "period": True,
+                "area": True,
+                "taxon": True,
+                "sample_size": True,
+                "date": False,
+                "variant": False,
+            },
+            height=height,
+            width=width,
+            title=title,
+            labels={
+                "date": "Date",
+                "frequency": "Frequency",
+                "variant": "Variant",
+                "taxon": "Taxon",
+                "area": "Area",
+                "period": "Period",
+                "sample_size": "Sample size",
+            },
+            **kwargs,
+        )
+
+        # Tidy plot.
+        fig.update_layout(
+            yaxis_range=[-0.05, 1.05],
+            legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
+        )
+
+        if show:  # pragma: no cover
+            fig.show(renderer=renderer)
+            return None
+        else:
+            return fig
+
+ @check_types
+ @doc(
+ summary="""
+ Plot markers on a map showing variant frequencies for cohorts grouped
+ by area (space), period (time) and taxon.
+ """,
+ parameters=dict(
+ m="The map on which to add the markers.",
+ variant="Index or label of variant to plot.",
+ taxon="Taxon to show markers for.",
+ period="Time period to show markers for.",
+ clear="""
+ If True, clear all layers (except the base layer) from the map
+ before adding new markers.
+ """,
+ ),
+ )
+    def plot_frequencies_map_markers(
+        self,
+        m,
+        ds: frq_params.ds_frequencies_advanced,
+        variant: Union[int, str],
+        taxon: str,
+        period: pd.Period,
+        clear: bool = True,
+    ):
+        # Only import here because of some problems importing globally.
+        import ipyleaflet  # type: ignore
+        import ipywidgets  # type: ignore
+
+        # Slice dataset to variant of interest, by position or by label.
+        if isinstance(variant, int):
+            ds_variant = ds.isel(variants=variant)
+            variant_label = ds["variant_label"].values[variant]
+        else:
+            assert isinstance(variant, str)
+            ds_variant = ds.set_index(variants="variant_label").sel(variants=variant)
+            variant_label = variant
+
+        # Convert to a dataframe for convenience.
+        df_markers = ds_variant[
+            [
+                "cohort_taxon",
+                "cohort_area",
+                "cohort_period",
+                "cohort_lat_mean",
+                "cohort_lon_mean",
+                "cohort_size",
+                "event_frequency",
+                "event_frequency_ci_low",
+                "event_frequency_ci_upp",
+            ]
+        ].to_dataframe()
+
+        # Select data matching taxon and period parameters.
+        df_markers = df_markers.loc[
+            (
+                (df_markers["cohort_taxon"] == taxon)
+                & (df_markers["cohort_period"] == period)
+            )
+        ]
+
+        # Clear existing layers in the map, keeping the base layer. Uses
+        # Map.remove(), the counterpart of the Map.add() call below, rather
+        # than the deprecated remove_layer().
+        if clear:
+            for layer in m.layers[1:]:
+                m.remove(layer)
+
+        # Add one circle marker per cohort; fill opacity encodes frequency.
+        for x in df_markers.itertuples():
+            marker = ipyleaflet.CircleMarker()
+            marker.location = (x.cohort_lat_mean, x.cohort_lon_mean)
+            marker.radius = 20
+            marker.color = "black"
+            marker.weight = 1
+            marker.fill_color = "red"
+            marker.fill_opacity = x.event_frequency
+            # The popup is rendered as HTML, where plain newlines collapse,
+            # so explicit <br> tags are needed to keep one field per line.
+            popup_html = f"""
+                <strong>{variant_label}</strong> <br>
+                Taxon: {x.cohort_taxon} <br>
+                Area: {x.cohort_area} <br>
+                Period: {x.cohort_period} <br>
+                Sample size: {x.cohort_size} <br>
+                Frequency: {x.event_frequency:.0%}
+                (95% CI: {x.event_frequency_ci_low:.0%} - {x.event_frequency_ci_upp:.0%})
+            """
+            marker.popup = ipyleaflet.Popup(
+                child=ipywidgets.HTML(popup_html),
+                auto_pan=False,
+            )
+            m.add(marker)
+
+ @check_types
+ @doc(
+ summary="""
+ Create an interactive map with markers showing variant frequencies or
+ cohorts grouped by area (space), period (time) and taxon.
+ """,
+ parameters=dict(
+ title="""
+ If True, attempt to use metadata from input dataset as a plot
+ title. Otherwise, use supplied value as a title.
+ """,
+ epilogue="Additional text to display below the map.",
+ ),
+ returns="""
+ An interactive map with widgets for selecting which variant, taxon
+ and time period to display.
+ """,
+ )
+    def plot_frequencies_interactive_map(
+        self,
+        ds: frq_params.ds_frequencies_advanced,
+        center: map_params.center = map_params.center_default,
+        zoom: map_params.zoom = map_params.zoom_default,
+        title: Optional[Union[bool, str]] = True,
+        epilogue: Union[bool, str] = True,
+    ):
+        import ipyleaflet
+        import ipywidgets
+
+        # title=True means "use the title stored on the dataset, if any".
+        if title is True:
+            title = ds.attrs.get("title", None)
+
+        # Create a map.
+        freq_map = ipyleaflet.Map(center=center, zoom=zoom)
+
+        # Set up interactive controls. Dropdown options come from the
+        # dataset; the map and dataset are passed through as fixed values.
+        variants = ds["variant_label"].values
+        taxa = ds["cohort_taxon"].to_pandas().dropna().unique()  # type: ignore
+        periods = ds["cohort_period"].to_pandas().dropna().unique()  # type: ignore
+        controls = ipywidgets.interactive(
+            self.plot_frequencies_map_markers,
+            m=ipywidgets.fixed(freq_map),
+            ds=ipywidgets.fixed(ds),
+            variant=ipywidgets.Dropdown(options=variants, description="Variant: "),
+            taxon=ipywidgets.Dropdown(options=taxa, description="Taxon: "),
+            period=ipywidgets.Dropdown(options=periods, description="Period: "),
+            clear=ipywidgets.fixed(True),
+        )
+
+        # Lay out widgets: optional heading, controls, map, optional epilogue.
+        components = []
+        # Only add a heading for a real, non-empty title. title=False or
+        # title=None means no heading, not a literal "False" heading.
+        if title:
+            components.append(ipywidgets.HTML(value=f"<h3>{title}</h3>"))
+        components.append(controls)
+        components.append(freq_map)
+        if epilogue is True:
+            epilogue = """
+                Variant frequencies are shown as coloured markers. Opacity of color
+                denotes frequency. Click on a marker for more information.
+            """
+        if epilogue:
+            components.append(ipywidgets.HTML(value=f"{epilogue}"))
+
+        out = ipywidgets.VBox(components)
+
+        return out
diff --git a/malariagen_data/anoph/hap_frq.py b/malariagen_data/anoph/hap_frq.py
index 5dcf6ddb..b30e5318 100644
--- a/malariagen_data/anoph/hap_frq.py
+++ b/malariagen_data/anoph/hap_frq.py
@@ -9,18 +9,19 @@
from ..util import (
check_types,
haplotype_frequencies,
+)
+from .hap_data import AnophelesHapData
+from .frq_funcs import (
prep_samples_for_cohort_grouping,
build_cohorts_from_sample_grouping,
add_frequency_ci,
)
-from .hap_data import AnophelesHapData
from .sample_metadata import locate_cohorts
+from .frq_funcs import AnophelesFrequency
from . import base_params, frq_params
-class AnophelesHapFrequencyAnalysis(
- AnophelesHapData,
-):
+class AnophelesHapFrequencyAnalysis(AnophelesHapData, AnophelesFrequency):
def __init__(
self,
**kwargs,
@@ -119,7 +120,7 @@ def haplotypes_frequencies(
df_haps_sorted["label"] = ["H" + str(i) for i in range(len(df_haps_sorted))]
# Reset index after filtering.
- df_haps_sorted.set_index(keys="label", drop=True)
+ df_haps_sorted.set_index(keys="label", drop=True, inplace=True)
return df_haps_sorted
diff --git a/malariagen_data/anoph/snp_frq.py b/malariagen_data/anoph/snp_frq.py
index cd5304d7..0c1f05d3 100644
--- a/malariagen_data/anoph/snp_frq.py
+++ b/malariagen_data/anoph/snp_frq.py
@@ -1,6 +1,5 @@
from typing import Optional, Dict, Union, Callable, List
import warnings
-from textwrap import dedent
import allel # type: ignore
import numpy as np
@@ -8,19 +7,21 @@
from numpydoc_decorator import doc # type: ignore
import xarray as xr
import numba # type: ignore
-import plotly.express as px # type: ignore
from .. import veff
from ..util import (
check_types,
pandas_apply,
+)
+from .snp_data import AnophelesSnpData
+from .frq_funcs import (
prep_samples_for_cohort_grouping,
build_cohorts_from_sample_grouping,
add_frequency_ci,
)
-from .snp_data import AnophelesSnpData
from .sample_metadata import locate_cohorts
-from . import base_params, frq_params, map_params, plotly_params
+from .frq_funcs import AnophelesFrequency
+from . import base_params, frq_params
AA_CHANGE_QUERY = (
@@ -28,9 +29,7 @@
)
-class AnophelesSnpFrequencyAnalysis(
- AnophelesSnpData,
-):
+class AnophelesSnpFrequencyAnalysis(AnophelesSnpData, AnophelesFrequency):
def __init__(
self,
**kwargs,
@@ -775,453 +774,6 @@ def aa_allele_frequencies_advanced(
return ds_aa_frq
- @check_types
- @doc(
- summary="""
- Plot a heatmap from a pandas DataFrame of frequencies, e.g., output
- from `snp_allele_frequencies()` or `gene_cnv_frequencies()`.
- """,
- parameters=dict(
- df="""
- A DataFrame of frequencies, e.g., output from
- `snp_allele_frequencies()` or `gene_cnv_frequencies()`.
- """,
- index="""
- One or more column headers that are present in the input dataframe.
- This becomes the heatmap y-axis row labels. The column/s must
- produce a unique index.
- """,
- max_len="""
- Displaying large styled dataframes may cause ipython notebooks to
- crash. If the input dataframe is larger than this value, an error
- will be raised.
- """,
- col_width="""
- Plot width per column in pixels (px).
- """,
- row_height="""
- Plot height per row in pixels (px).
- """,
- kwargs="""
- Passed through to `px.imshow()`.
- """,
- ),
- notes="""
- It's recommended to filter the input DataFrame to just rows of interest,
- i.e., fewer rows than `max_len`.
- """,
- )
- def plot_frequencies_heatmap(
- self,
- df: pd.DataFrame,
- index: Optional[Union[str, List[str]]] = "label",
- max_len: Optional[int] = 100,
- col_width: int = 40,
- row_height: int = 20,
- x_label: plotly_params.x_label = "Cohorts",
- y_label: plotly_params.y_label = "Variants",
- colorbar: plotly_params.colorbar = True,
- width: plotly_params.fig_width = None,
- height: plotly_params.fig_height = None,
- text_auto: plotly_params.text_auto = ".0%",
- aspect: plotly_params.aspect = "auto",
- color_continuous_scale: plotly_params.color_continuous_scale = "Reds",
- title: plotly_params.title = True,
- show: plotly_params.show = True,
- renderer: plotly_params.renderer = None,
- **kwargs,
- ) -> plotly_params.figure:
- # Check len of input.
- if max_len and len(df) > max_len:
- raise ValueError(
- dedent(
- f"""
- Input DataFrame is longer than max_len parameter value {max_len}, which means
- that the plot is likely to be very large. If you really want to go ahead,
- please rerun the function with max_len=None.
- """
- )
- )
-
- # Handle title.
- if title is True:
- title = df.attrs.get("title", None)
-
- # Indexing.
- if index is None:
- index = list(df.index.names)
- df = df.reset_index().copy()
- if isinstance(index, list):
- index_col = (
- df[index]
- .astype(str)
- .apply(
- lambda row: ", ".join([o for o in row if o is not None]),
- axis="columns",
- )
- )
- else:
- assert isinstance(index, str)
- index_col = df[index].astype(str)
-
- # Check that index is unique.
- if not index_col.is_unique:
- raise ValueError(f"{index} does not produce a unique index")
-
- # Drop and re-order columns.
- frq_cols = [col for col in df.columns if col.startswith("frq_")]
-
- # Keep only freq cols.
- heatmap_df = df[frq_cols].copy()
-
- # Set index.
- heatmap_df.set_index(index_col, inplace=True)
-
- # Clean column names.
- heatmap_df.columns = heatmap_df.columns.str.lstrip("frq_")
-
- # Deal with width and height.
- if width is None:
- width = 400 + col_width * len(heatmap_df.columns)
- if colorbar:
- width += 40
- if height is None:
- height = 200 + row_height * len(heatmap_df)
- if title is not None:
- height += 40
-
- # Plotly heatmap styling.
- fig = px.imshow(
- img=heatmap_df,
- zmin=0,
- zmax=1,
- width=width,
- height=height,
- text_auto=text_auto,
- aspect=aspect,
- color_continuous_scale=color_continuous_scale,
- title=title,
- **kwargs,
- )
-
- fig.update_xaxes(side="bottom", tickangle=30)
- if x_label is not None:
- fig.update_xaxes(title=x_label)
- if y_label is not None:
- fig.update_yaxes(title=y_label)
- fig.update_layout(
- coloraxis_colorbar=dict(
- title="Frequency",
- tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1.0],
- ticktext=["0%", "20%", "40%", "60%", "80%", "100%"],
- )
- )
- if not colorbar:
- fig.update(layout_coloraxis_showscale=False)
-
- if show: # pragma: no cover
- fig.show(renderer=renderer)
- return None
- else:
- return fig
-
- @check_types
- @doc(
- summary="Create a time series plot of variant frequencies using plotly.",
- parameters=dict(
- ds="""
- A dataset of variant frequencies, such as returned by
- `snp_allele_frequencies_advanced()`,
- `aa_allele_frequencies_advanced()` or
- `gene_cnv_frequencies_advanced()`.
- """,
- kwargs="Passed through to `px.line()`.",
- ),
- returns="""
- A plotly figure containing line graphs. The resulting figure will
- have one panel per cohort, grouped into columns by taxon, and
- grouped into rows by area. Markers and lines show frequencies of
- variants.
- """,
- )
- def plot_frequencies_time_series(
- self,
- ds: xr.Dataset,
- height: plotly_params.fig_height = None,
- width: plotly_params.fig_width = None,
- title: plotly_params.title = True,
- legend_sizing: plotly_params.legend_sizing = "constant",
- show: plotly_params.show = True,
- renderer: plotly_params.renderer = None,
- taxa: frq_params.taxa = None,
- areas: frq_params.areas = None,
- **kwargs,
- ) -> plotly_params.figure:
- # Handle title.
- if title is True:
- title = ds.attrs.get("title", None)
-
- # Extract cohorts into a dataframe.
- cohort_vars = [v for v in ds if str(v).startswith("cohort_")]
- df_cohorts = ds[cohort_vars].to_dataframe()
- df_cohorts.columns = [c.split("cohort_")[1] for c in df_cohorts.columns] # type: ignore
-
- # If specified, restrict the dataframe by taxa.
- if isinstance(taxa, str):
- df_cohorts = df_cohorts[df_cohorts["taxon"] == taxa]
- elif isinstance(taxa, (list, tuple)):
- df_cohorts = df_cohorts[df_cohorts["taxon"].isin(taxa)]
-
- # If specified, restrict the dataframe by areas.
- if isinstance(areas, str):
- df_cohorts = df_cohorts[df_cohorts["area"] == areas]
- elif isinstance(areas, (list, tuple)):
- df_cohorts = df_cohorts[df_cohorts["area"].isin(areas)]
-
- # Extract variant labels.
- variant_labels = ds["variant_label"].values
-
- # Build a long-form dataframe from the dataset.
- dfs = []
- for cohort_index, cohort in enumerate(df_cohorts.itertuples()):
- ds_cohort = ds.isel(cohorts=cohort_index)
- df = pd.DataFrame(
- {
- "taxon": cohort.taxon,
- "area": cohort.area,
- "date": cohort.period_start,
- "period": str(
- cohort.period
- ), # use string representation for hover label
- "sample_size": cohort.size,
- "variant": variant_labels,
- "count": ds_cohort["event_count"].values,
- "nobs": ds_cohort["event_nobs"].values,
- "frequency": ds_cohort["event_frequency"].values,
- "frequency_ci_low": ds_cohort["event_frequency_ci_low"].values,
- "frequency_ci_upp": ds_cohort["event_frequency_ci_upp"].values,
- }
- )
- dfs.append(df)
- df_events = pd.concat(dfs, axis=0).reset_index(drop=True)
-
- # Remove events with no observations.
- df_events = df_events.query("nobs > 0").copy()
-
- # Calculate error bars.
- frq = df_events["frequency"]
- frq_ci_low = df_events["frequency_ci_low"]
- frq_ci_upp = df_events["frequency_ci_upp"]
- df_events["frequency_error"] = frq_ci_upp - frq
- df_events["frequency_error_minus"] = frq - frq_ci_low
-
- # Make a plot.
- fig = px.line(
- df_events,
- facet_col="taxon",
- facet_row="area",
- x="date",
- y="frequency",
- error_y="frequency_error",
- error_y_minus="frequency_error_minus",
- color="variant",
- markers=True,
- hover_name="variant",
- hover_data={
- "frequency": ":.0%",
- "period": True,
- "area": True,
- "taxon": True,
- "sample_size": True,
- "date": False,
- "variant": False,
- },
- height=height,
- width=width,
- title=title,
- labels={
- "date": "Date",
- "frequency": "Frequency",
- "variant": "Variant",
- "taxon": "Taxon",
- "area": "Area",
- "period": "Period",
- "sample_size": "Sample size",
- },
- **kwargs,
- )
-
- # Tidy plot.
- fig.update_layout(
- yaxis_range=[-0.05, 1.05],
- legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
- )
-
- if show: # pragma: no cover
- fig.show(renderer=renderer)
- return None
- else:
- return fig
-
- @check_types
- @doc(
- summary="""
- Plot markers on a map showing variant frequencies for cohorts grouped
- by area (space), period (time) and taxon.
- """,
- parameters=dict(
- m="The map on which to add the markers.",
- variant="Index or label of variant to plot.",
- taxon="Taxon to show markers for.",
- period="Time period to show markers for.",
- clear="""
- If True, clear all layers (except the base layer) from the map
- before adding new markers.
- """,
- ),
- )
- def plot_frequencies_map_markers(
- self,
- m,
- ds: frq_params.ds_frequencies_advanced,
- variant: Union[int, str],
- taxon: str,
- period: pd.Period,
- clear: bool = True,
- ):
- # Only import here because of some problems importing globally.
- import ipyleaflet # type: ignore
- import ipywidgets # type: ignore
-
- # Slice dataset to variant of interest.
- if isinstance(variant, int):
- ds_variant = ds.isel(variants=variant)
- variant_label = ds["variant_label"].values[variant]
- else:
- assert isinstance(variant, str)
- ds_variant = ds.set_index(variants="variant_label").sel(variants=variant)
- variant_label = variant
-
- # Convert to a dataframe for convenience.
- df_markers = ds_variant[
- [
- "cohort_taxon",
- "cohort_area",
- "cohort_period",
- "cohort_lat_mean",
- "cohort_lon_mean",
- "cohort_size",
- "event_frequency",
- "event_frequency_ci_low",
- "event_frequency_ci_upp",
- ]
- ].to_dataframe()
-
- # Select data matching taxon and period parameters.
- df_markers = df_markers.loc[
- (
- (df_markers["cohort_taxon"] == taxon)
- & (df_markers["cohort_period"] == period)
- )
- ]
-
- # Clear existing layers in the map.
- if clear:
- for layer in m.layers[1:]:
- m.remove_layer(layer)
-
- # Add markers.
- for x in df_markers.itertuples():
- marker = ipyleaflet.CircleMarker()
- marker.location = (x.cohort_lat_mean, x.cohort_lon_mean)
- marker.radius = 20
- marker.color = "black"
- marker.weight = 1
- marker.fill_color = "red"
- marker.fill_opacity = x.event_frequency
- popup_html = f"""
- {variant_label}
- Taxon: {x.cohort_taxon}
- Area: {x.cohort_area}
- Period: {x.cohort_period}
- Sample size: {x.cohort_size}
- Frequency: {x.event_frequency:.0%}
- (95% CI: {x.event_frequency_ci_low:.0%} - {x.event_frequency_ci_upp:.0%})
- """
- marker.popup = ipyleaflet.Popup(
- child=ipywidgets.HTML(popup_html),
- auto_pan=False,
- )
- m.add(marker)
-
- @check_types
- @doc(
- summary="""
- Create an interactive map with markers showing variant frequencies or
- cohorts grouped by area (space), period (time) and taxon.
- """,
- parameters=dict(
- title="""
- If True, attempt to use metadata from input dataset as a plot
- title. Otherwise, use supplied value as a title.
- """,
- epilogue="Additional text to display below the map.",
- ),
- returns="""
- An interactive map with widgets for selecting which variant, taxon
- and time period to display.
- """,
- )
- def plot_frequencies_interactive_map(
- self,
- ds: frq_params.ds_frequencies_advanced,
- center: map_params.center = map_params.center_default,
- zoom: map_params.zoom = map_params.zoom_default,
- title: Union[bool, str] = True,
- epilogue: Union[bool, str] = True,
- ):
- import ipyleaflet
- import ipywidgets
-
- # Handle title.
- if title is True:
- title = ds.attrs.get("title", None)
-
- # Create a map.
- freq_map = ipyleaflet.Map(center=center, zoom=zoom)
-
- # Set up interactive controls.
- variants = ds["variant_label"].values
- taxa = ds["cohort_taxon"].to_pandas().dropna().unique() # type: ignore
- periods = ds["cohort_period"].to_pandas().dropna().unique() # type: ignore
- controls = ipywidgets.interactive(
- self.plot_frequencies_map_markers,
- m=ipywidgets.fixed(freq_map),
- ds=ipywidgets.fixed(ds),
- variant=ipywidgets.Dropdown(options=variants, description="Variant: "),
- taxon=ipywidgets.Dropdown(options=taxa, description="Taxon: "),
- period=ipywidgets.Dropdown(options=periods, description="Period: "),
- clear=ipywidgets.fixed(True),
- )
-
- # Lay out widgets.
- components = []
- if title is not None:
- components.append(ipywidgets.HTML(value=f"{title}
"))
- components.append(controls)
- components.append(freq_map)
- if epilogue is True:
- epilogue = """
- Variant frequencies are shown as coloured markers. Opacity of color
- denotes frequency. Click on a marker for more information.
- """
- if epilogue:
- components.append(ipywidgets.HTML(value=f"{epilogue}"))
-
- out = ipywidgets.VBox(components)
-
- return out
-
def snp_genotype_allele_counts(
self,
transcript: base_params.transcript,
diff --git a/tests/anoph/conftest.py b/tests/anoph/conftest.py
index 1564b029..5927debb 100644
--- a/tests/anoph/conftest.py
+++ b/tests/anoph/conftest.py
@@ -1013,6 +1013,14 @@ def contigs(self) -> Tuple[str, ...]:
def random_contig(self):
return choice(self.contigs)
+ def random_transcript_id(self):
+ df_transcripts = self.genome_features.query("type == 'mRNA'")
+ transcript_ids = [
+ t.split(";")[0].split("=")[1] for t in df_transcripts.loc[:, "attributes"]
+ ]
+ transcript_id = choice(transcript_ids)
+ return transcript_id
+
def random_region_str(self, region_size=None):
contig = self.random_contig()
contig_size = self.contig_sizes[contig]
diff --git a/tests/anoph/test_cnv_frq.py b/tests/anoph/test_cnv_frq.py
index ef96a400..a96cc0bf 100644
--- a/tests/anoph/test_cnv_frq.py
+++ b/tests/anoph/test_cnv_frq.py
@@ -11,6 +11,13 @@
from malariagen_data import ag3 as _ag3
from malariagen_data.anoph.cnv_frq import AnophelesCnvFrequencyAnalysis
from malariagen_data.util import compare_series_like
+from .test_frq import (
+ test_plot_frequencies_heatmap,
+ test_plot_frequencies_time_series,
+ test_plot_frequencies_time_series_with_taxa,
+ test_plot_frequencies_time_series_with_areas,
+ test_plot_frequencies_interactive_map,
+)
@pytest.fixture
@@ -109,6 +116,8 @@ def test_gene_cnv_frequencies_with_str_cohorts(
# Run the function under test.
df_cnv = api.gene_cnv_frequencies(**params)
+ test_plot_frequencies_heatmap(api, df_cnv)
+
# Figure out expected cohort labels.
df_samples = api.sample_metadata(sample_sets=sample_sets)
if "cohort_" + cohorts in df_samples:
@@ -166,12 +175,14 @@ def test_gene_cnv_frequencies_with_min_cohort_size(
return
# Run the function under test.
- df_snp = api.gene_cnv_frequencies(**params)
+ df_cnv = api.gene_cnv_frequencies(**params)
+
+ test_plot_frequencies_heatmap(api, df_cnv)
# Standard checks.
check_gene_cnv_frequencies(
api=api,
- df=df_snp,
+ df=df_cnv,
cohort_labels=cohort_labels,
region=region,
)
@@ -212,12 +223,14 @@ def test_gene_cnv_frequencies_with_str_cohorts_and_sample_query(
)
# Run the function under test.
- df_snp = api.gene_cnv_frequencies(**params)
+ df_cnv = api.gene_cnv_frequencies(**params)
+
+ test_plot_frequencies_heatmap(api, df_cnv)
# Standard checks.
check_gene_cnv_frequencies(
api=api,
- df=df_snp,
+ df=df_cnv,
cohort_labels=cohort_labels,
region=region,
)
@@ -268,12 +281,14 @@ def test_gene_cnv_frequencies_with_str_cohorts_and_sample_query_options(
)
# Run the function under test.
- df_snp = api.gene_cnv_frequencies(**params)
+ df_cnv = api.gene_cnv_frequencies(**params)
+
+ test_plot_frequencies_heatmap(api, df_cnv)
# Standard checks.
check_gene_cnv_frequencies(
api=api,
- df=df_snp,
+ df=df_cnv,
cohort_labels=cohort_labels,
region=region,
)
@@ -305,12 +320,14 @@ def test_gene_cnv_frequencies_with_dict_cohorts(
)
# Run the function under test.
- df_snp = api.gene_cnv_frequencies(**params)
+ df_cnv = api.gene_cnv_frequencies(**params)
+
+ test_plot_frequencies_heatmap(api, df_cnv)
# Standard checks.
check_gene_cnv_frequencies(
api=api,
- df=df_snp,
+ df=df_cnv,
cohort_labels=cohort_labels,
region=region,
)
@@ -350,6 +367,9 @@ def test_gene_cnv_frequencies_without_drop_invariant(
df_cnv_a = api.gene_cnv_frequencies(drop_invariant=True, **params)
df_cnv_b = api.gene_cnv_frequencies(drop_invariant=False, **params)
+ test_plot_frequencies_heatmap(api, df_cnv_a)
+ test_plot_frequencies_heatmap(api, df_cnv_b)
+
# Standard checks.
check_gene_cnv_frequencies(
api=api,
@@ -418,6 +438,8 @@ def test_gene_cnv_frequencies_with_max_coverage_variance(
# checks.
df_cnv = api.gene_cnv_frequencies(**params)
+ test_plot_frequencies_heatmap(api, df_cnv)
+
# Figure out expected cohort labels.
df_samples = api.sample_metadata(sample_sets=sample_sets)
if "cohort_" + cohorts in df_samples:
@@ -711,6 +733,10 @@ def check_gene_cnv_frequencies_advanced(
# Check the result.
assert isinstance(ds, xr.Dataset)
+ test_plot_frequencies_time_series(api, ds)
+ test_plot_frequencies_time_series_with_taxa(api, ds)
+ test_plot_frequencies_time_series_with_areas(api, ds)
+ test_plot_frequencies_interactive_map(api, ds)
assert set(ds.dims) == {"cohorts", "variants"}
# Check variant variables.
diff --git a/tests/anoph/test_frq.py b/tests/anoph/test_frq.py
new file mode 100644
index 00000000..dfba620c
--- /dev/null
+++ b/tests/anoph/test_frq.py
@@ -0,0 +1,98 @@
+import pytest
+import plotly.graph_objects as go # type: ignore
+
+import random
+
+
+@pytest.mark.skip
+def test_plot_frequencies_heatmap(api, frq_df):
+ fig = api.plot_frequencies_heatmap(frq_df, show=False, max_len=None)
+ assert isinstance(fig, go.Figure)
+
+ # Test max_len behaviour.
+ with pytest.raises(ValueError):
+ api.plot_frequencies_heatmap(frq_df, show=False, max_len=len(frq_df) - 1)
+
+ # Test index parameter - if None, should use dataframe index.
+ fig = api.plot_frequencies_heatmap(frq_df, show=False, index=None, max_len=None)
+
+ if "contig" in list(frq_df.columns):
+ # Not unique.
+ with pytest.raises(ValueError):
+ api.plot_frequencies_heatmap(
+ frq_df, show=False, index="contig", max_len=None
+ )
+
+
+@pytest.mark.skip
+def test_plot_frequencies_time_series(api, ds):
+ # Trim things down a bit for speed.
+ ds = ds.isel(variants=slice(0, 100))
+
+ # Plot.
+ fig = api.plot_frequencies_time_series(ds, show=False)
+
+ # Test.
+ assert isinstance(fig, go.Figure)
+
+
+@pytest.mark.skip
+def test_plot_frequencies_time_series_with_taxa(api, ds):
+ # Trim things down a bit for speed.
+ ds = ds.isel(variants=slice(0, 100))
+
+ taxa = list(ds.cohort_taxon.to_dataframe()["cohort_taxon"].unique())
+ taxon = random.choice(taxa)
+
+ # Plot with taxon.
+ fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxon)
+
+ # Test taxon plot.
+ assert isinstance(fig, go.Figure)
+
+ # Plot with taxa.
+ fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxa)
+
+ # Test taxa plot.
+ assert isinstance(fig, go.Figure)
+
+
+@pytest.mark.skip
+def test_plot_frequencies_time_series_with_areas(api, ds):
+ # Trim things down a bit for speed.
+ ds = ds.isel(variants=slice(0, 100))
+
+ # Extract cohorts into a DataFrame.
+ cohort_vars = [v for v in ds if str(v).startswith("cohort_")]
+ df_cohorts = ds[cohort_vars].to_dataframe()
+
+ # Pick a random area and areas from valid areas.
+ cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist()
+ area = random.choice(cohorts_areas)
+ areas = random.sample(cohorts_areas, random.randint(1, len(cohorts_areas)))
+
+ # Plot with area.
+ fig = api.plot_frequencies_time_series(ds, show=False, areas=area)
+
+ # Test area plot.
+ assert isinstance(fig, go.Figure)
+
+ # Plot with areas.
+ fig = api.plot_frequencies_time_series(ds, show=False, areas=areas)
+
+ # Test areas plot.
+ assert isinstance(fig, go.Figure)
+
+
+@pytest.mark.skip
+def test_plot_frequencies_interactive_map(api, ds):
+ import ipywidgets # type: ignore
+
+ # Trim things down a bit for speed.
+ ds = ds.isel(variants=slice(0, 100))
+
+ # Plot.
+ fig = api.plot_frequencies_interactive_map(ds)
+
+ # Test.
+ assert isinstance(fig, ipywidgets.Widget)
diff --git a/tests/anoph/test_hap_frq.py b/tests/anoph/test_hap_frq.py
index 7dc54346..a2077d6c 100644
--- a/tests/anoph/test_hap_frq.py
+++ b/tests/anoph/test_hap_frq.py
@@ -7,7 +7,15 @@
from pytest_cases import parametrize_with_cases
from malariagen_data import ag3 as _ag3
+from malariagen_data import af1 as _af1
from malariagen_data.anoph.hap_frq import AnophelesHapFrequencyAnalysis
+from .test_frq import (
+ test_plot_frequencies_heatmap,
+ test_plot_frequencies_time_series,
+ test_plot_frequencies_time_series_with_taxa,
+ test_plot_frequencies_time_series_with_areas,
+ test_plot_frequencies_interactive_map,
+)
@pytest.fixture
@@ -36,6 +44,23 @@ def ag3_sim_api(ag3_sim_fixture):
)
+@pytest.fixture
+def af1_sim_api(af1_sim_fixture):
+ return AnophelesHapFrequencyAnalysis(
+ url=af1_sim_fixture.url,
+ config_path=_af1.CONFIG_PATH,
+ major_version_number=_af1.MAJOR_VERSION_NUMBER,
+ major_version_path=_af1.MAJOR_VERSION_PATH,
+ pre=False,
+ gff_gene_type="protein_coding_gene",
+ gff_gene_name_attribute="Note",
+ gff_default_attributes=("ID", "Parent", "Note", "description"),
+ results_cache=af1_sim_fixture.results_cache_path.as_posix(),
+ taxon_colors=_af1.TAXON_COLORS,
+ default_phasing_analysis="funestus",
+ )
+
+
# N.B., here we use pytest_cases to parametrize tests. Each
# function whose name begins with "case_" defines a set of
# inputs to the test functions. See the documentation for
@@ -52,6 +77,10 @@ def case_ag3_sim(ag3_sim_fixture, ag3_sim_api):
return ag3_sim_fixture, ag3_sim_api
+def case_af1_sim(af1_sim_fixture, af1_sim_api):
+ return af1_sim_fixture, af1_sim_api
+
+
def check_frequency(x):
loc_nan = np.isnan(x)
assert np.all(x[~loc_nan] >= 0)
@@ -70,9 +99,8 @@ def check_hap_frequencies(*, api, df, sample_sets, cohorts, min_cohort_size):
cohort_counts = df_samples[cohort_column].value_counts()
cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list()
- universal_fields = ["label"]
frq_fields = ["frq_" + s for s in cohort_labels] + ["max_af"]
- expected_fields = universal_fields + frq_fields
+ expected_fields = frq_fields
assert sorted(df.columns.tolist()) == sorted(expected_fields)
@@ -82,6 +110,10 @@ def check_hap_frequencies_advanced(
ds,
):
assert isinstance(ds, xr.Dataset)
+ test_plot_frequencies_time_series(api, ds)
+ test_plot_frequencies_time_series_with_taxa(api, ds)
+ test_plot_frequencies_time_series_with_areas(api, ds)
+ test_plot_frequencies_interactive_map(api, ds)
assert set(ds.dims) == {"cohorts", "variants"}
expected_cohort_vars = [
@@ -154,6 +186,8 @@ def test_hap_frequencies_with_str_cohorts(
# Run the function under test.
df_hap = api.haplotypes_frequencies(**params)
+ test_plot_frequencies_heatmap(api, df_hap)
+
# Standard checks.
check_hap_frequencies(
api=api,
diff --git a/tests/anoph/test_snp_frq.py b/tests/anoph/test_snp_frq.py
index e18d9a19..412fdc21 100644
--- a/tests/anoph/test_snp_frq.py
+++ b/tests/anoph/test_snp_frq.py
@@ -7,12 +7,18 @@
from pytest_cases import parametrize_with_cases
import xarray as xr
from numpy.testing import assert_allclose, assert_array_equal
-import plotly.graph_objects as go # type: ignore
from malariagen_data import af1 as _af1
from malariagen_data import ag3 as _ag3
from malariagen_data.anoph.snp_frq import AnophelesSnpFrequencyAnalysis
from malariagen_data.util import compare_series_like
+from .test_frq import (
+ test_plot_frequencies_heatmap,
+ test_plot_frequencies_time_series,
+ test_plot_frequencies_time_series_with_taxa,
+ test_plot_frequencies_time_series_with_areas,
+ test_plot_frequencies_interactive_map,
+)
@pytest.fixture
@@ -315,6 +321,8 @@ def test_allele_frequencies_with_str_cohorts(
# Run the function under test.
df_snp = api.snp_allele_frequencies(**params)
+ test_plot_frequencies_heatmap(api, df_snp)
+
# Figure out expected cohort labels.
df_samples = api.sample_metadata(sample_sets=sample_sets)
if "cohort_" + cohorts in df_samples:
@@ -335,6 +343,8 @@ def test_allele_frequencies_with_str_cohorts(
# Run the function under test.
df_aa = api.aa_allele_frequencies(**params)
+ test_plot_frequencies_heatmap(api, df_aa)
+
# Standard checks.
check_aa_allele_frequencies(
df=df_aa,
@@ -387,6 +397,8 @@ def test_allele_frequencies_with_min_cohort_size(
# Run the function under test.
df_snp = api.snp_allele_frequencies(**params)
+ test_plot_frequencies_heatmap(api, df_snp)
+
# Standard checks.
check_snp_allele_frequencies(
api=api,
@@ -398,6 +410,8 @@ def test_allele_frequencies_with_min_cohort_size(
# Run the function under test.
df_aa = api.aa_allele_frequencies(**params)
+ test_plot_frequencies_heatmap(api, df_aa)
+
# Standard checks.
check_aa_allele_frequencies(
df=df_aa,
@@ -444,6 +458,8 @@ def test_allele_frequencies_with_str_cohorts_and_sample_query(
# Run the function under test.
df_snp = api.snp_allele_frequencies(**params)
+ test_plot_frequencies_heatmap(api, df_snp)
+
# Standard checks.
check_snp_allele_frequencies(
api=api,
@@ -455,6 +471,8 @@ def test_allele_frequencies_with_str_cohorts_and_sample_query(
# Run the function under test.
df_aa = api.aa_allele_frequencies(**params)
+ test_plot_frequencies_heatmap(api, df_aa)
+
# Standard checks.
check_aa_allele_frequencies(
df=df_aa,
@@ -511,6 +529,8 @@ def test_allele_frequencies_with_str_cohorts_and_sample_query_options(
# Run the function under test.
df_snp = api.snp_allele_frequencies(**params)
+ test_plot_frequencies_heatmap(api, df_snp)
+
# Standard checks.
check_snp_allele_frequencies(
api=api,
@@ -522,6 +542,8 @@ def test_allele_frequencies_with_str_cohorts_and_sample_query_options(
# Run the function under test.
df_aa = api.aa_allele_frequencies(**params)
+ test_plot_frequencies_heatmap(api, df_aa)
+
# Standard checks.
check_aa_allele_frequencies(
df=df_aa,
@@ -559,6 +581,8 @@ def test_allele_frequencies_with_dict_cohorts(
# Run the function under test.
df_snp = api.snp_allele_frequencies(**params)
+ test_plot_frequencies_heatmap(api, df_snp)
+
# Standard checks.
check_snp_allele_frequencies(
api=api,
@@ -570,6 +594,8 @@ def test_allele_frequencies_with_dict_cohorts(
# Run the function under test.
df_aa = api.aa_allele_frequencies(**params)
+ test_plot_frequencies_heatmap(api, df_aa)
+
# Standard checks.
check_aa_allele_frequencies(
df=df_aa,
@@ -613,6 +639,9 @@ def test_allele_frequencies_without_drop_invariant(
df_snp_a = api.snp_allele_frequencies(drop_invariant=True, **params)
df_snp_b = api.snp_allele_frequencies(drop_invariant=False, **params)
+ test_plot_frequencies_heatmap(api, df_snp_a)
+ test_plot_frequencies_heatmap(api, df_snp_b)
+
# Standard checks.
check_snp_allele_frequencies(
api=api,
@@ -667,6 +696,9 @@ def test_allele_frequencies_without_effects(
df_snp_a = api.snp_allele_frequencies(effects=True, **params)
df_snp_b = api.snp_allele_frequencies(effects=False, **params)
+ test_plot_frequencies_heatmap(api, df_snp_a)
+ test_plot_frequencies_heatmap(api, df_snp_b)
+
# Standard checks.
check_snp_allele_frequencies(
api=api,
@@ -766,6 +798,8 @@ def test_allele_frequencies_with_region(
# Run the function under test.
df_snp = api.snp_allele_frequencies(**params)
+ test_plot_frequencies_heatmap(api, df_snp)
+
# Basic checks.
assert isinstance(df_snp, pd.DataFrame)
assert len(df_snp) > 0
@@ -820,6 +854,9 @@ def test_allele_frequencies_with_dup_samples(
sample_sets=[sample_set, sample_set], **params
)
+ test_plot_frequencies_heatmap(api, df_snp_a)
+ test_plot_frequencies_heatmap(api, df_snp_b)
+
# Expect automatically deduplicate sample sets.
assert_frame_equal(df_snp_b, df_snp_a)
@@ -876,6 +913,10 @@ def check_snp_allele_frequencies_advanced(
# Check the result.
assert isinstance(ds, xr.Dataset)
+ test_plot_frequencies_time_series(api, ds)
+ test_plot_frequencies_time_series_with_taxa(api, ds)
+ test_plot_frequencies_time_series_with_areas(api, ds)
+ test_plot_frequencies_interactive_map(api, ds)
assert set(ds.dims) == {"cohorts", "variants"}
# Check variant variables.
@@ -1063,6 +1104,10 @@ def check_aa_allele_frequencies_advanced(
# Check the result.
assert isinstance(ds, xr.Dataset)
+ test_plot_frequencies_time_series(api, ds)
+ test_plot_frequencies_time_series_with_taxa(api, ds)
+ test_plot_frequencies_time_series_with_areas(api, ds)
+ test_plot_frequencies_interactive_map(api, ds)
assert set(ds.dims) == {"cohorts", "variants"}
expected_variant_vars = (
@@ -1429,248 +1474,3 @@ def test_allele_frequencies_advanced_with_dup_samples(
api=api,
sample_sets=sample_sets,
)
-
-
-@parametrize_with_cases("fixture,api", cases=".")
-def test_plot_frequencies_heatmap(
- fixture,
- api: AnophelesSnpFrequencyAnalysis,
-):
- # Pick test parameters at random.
- all_sample_sets = api.sample_sets()["sample_set"].to_list()
- sample_sets = random.choice(all_sample_sets)
- site_mask = random.choice(api.site_mask_ids + (None,))
- min_cohort_size = random.randint(0, 2)
- transcript = random_transcript(api=api).name
- cohorts = random.choice(
- ["admin1_year", "admin1_month", "admin2_year", "admin2_month"]
- )
-
- # Set up call params.
- params = dict(
- transcript=transcript,
- cohorts=cohorts,
- min_cohort_size=min_cohort_size,
- site_mask=site_mask,
- sample_sets=sample_sets,
- )
-
- # Test SNP allele frequencies.
- df_snp = api.snp_allele_frequencies(**params)
- fig = api.plot_frequencies_heatmap(df_snp, show=False, max_len=None)
- assert isinstance(fig, go.Figure)
-
- # Test amino acid change allele frequencies.
- df_aa = api.aa_allele_frequencies(**params)
- fig = api.plot_frequencies_heatmap(df_aa, show=False, max_len=None)
- assert isinstance(fig, go.Figure)
-
- # Test max_len behaviour.
- with pytest.raises(ValueError):
- api.plot_frequencies_heatmap(df_snp, show=False, max_len=len(df_snp) - 1)
-
- # Test index parameter - if None, should use dataframe index.
- fig = api.plot_frequencies_heatmap(df_snp, show=False, index=None, max_len=None)
- # Not unique.
- with pytest.raises(ValueError):
- api.plot_frequencies_heatmap(df_snp, show=False, index="contig", max_len=None)
-
-
-@parametrize_with_cases("fixture,api", cases=".")
-def test_plot_frequencies_time_series(
- fixture,
- api: AnophelesSnpFrequencyAnalysis,
-):
- # Pick test parameters at random.
- all_sample_sets = api.sample_sets()["sample_set"].to_list()
- sample_sets = random.choice(all_sample_sets)
- site_mask = random.choice(api.site_mask_ids + (None,))
- min_cohort_size = random.randint(0, 2)
- transcript = random_transcript(api=api).name
- area_by = random.choice(["country", "admin1_iso", "admin2_name"])
- period_by = random.choice(["year", "quarter", "month"])
-
- # Compute SNP frequencies.
- ds = api.snp_allele_frequencies_advanced(
- transcript=transcript,
- area_by=area_by,
- period_by=period_by,
- sample_sets=sample_sets,
- min_cohort_size=min_cohort_size,
- site_mask=site_mask,
- )
-
- # Trim things down a bit for speed.
- ds = ds.isel(variants=slice(0, 100))
-
- # Plot.
- fig = api.plot_frequencies_time_series(ds, show=False)
-
- # Test.
- assert isinstance(fig, go.Figure)
-
- # Compute amino acid change frequencies.
- ds = api.aa_allele_frequencies_advanced(
- transcript=transcript,
- area_by=area_by,
- period_by=period_by,
- sample_sets=sample_sets,
- min_cohort_size=min_cohort_size,
- )
-
- # Trim things down a bit for speed.
- ds = ds.isel(variants=slice(0, 100))
-
- # Plot.
- fig = api.plot_frequencies_time_series(ds, show=False)
-
- # Test.
- assert isinstance(fig, go.Figure)
-
-
-@parametrize_with_cases("fixture,api", cases=".")
-def test_plot_frequencies_time_series_with_taxa(
- fixture,
- api: AnophelesSnpFrequencyAnalysis,
-):
- # Pick test parameters at random.
- all_sample_sets = api.sample_sets()["sample_set"].to_list()
- sample_sets = random.choice(all_sample_sets)
- site_mask = random.choice(api.site_mask_ids + (None,))
- transcript = random_transcript(api=api).name
- area_by = random.choice(["country", "admin1_iso", "admin2_name"])
- period_by = random.choice(["year", "quarter", "month"])
-
- # Pick a random taxon and taxa from valid taxa.
- sample_sets_taxa = (
- api.sample_metadata(sample_sets=sample_sets)["taxon"].dropna().unique().tolist()
- )
- taxon = random.choice(sample_sets_taxa)
- taxa = random.sample(sample_sets_taxa, random.randint(1, len(sample_sets_taxa)))
-
- # Compute SNP frequencies.
- ds = api.snp_allele_frequencies_advanced(
- transcript=transcript,
- area_by=area_by,
- period_by=period_by,
- sample_sets=sample_sets,
- min_cohort_size=1, # Don't exclude any samples.
- site_mask=site_mask,
- )
-
- # Trim things down a bit for speed.
- ds = ds.isel(variants=slice(0, 100))
-
- # Plot with taxon.
- fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxon)
-
- # Test taxon plot.
- assert isinstance(fig, go.Figure)
-
- # Plot with taxa.
- fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxa)
-
- # Test taxa plot.
- assert isinstance(fig, go.Figure)
-
-
-@parametrize_with_cases("fixture,api", cases=".")
-def test_plot_frequencies_time_series_with_areas(
- fixture,
- api: AnophelesSnpFrequencyAnalysis,
-):
- # Pick test parameters at random.
- all_sample_sets = api.sample_sets()["sample_set"].to_list()
- sample_sets = random.choice(all_sample_sets)
- site_mask = random.choice(api.site_mask_ids + (None,))
- transcript = random_transcript(api=api).name
- area_by = random.choice(["country", "admin1_iso", "admin2_name"])
- period_by = random.choice(["year", "quarter", "month"])
-
- # Compute SNP frequencies.
- ds = api.snp_allele_frequencies_advanced(
- transcript=transcript,
- area_by=area_by,
- period_by=period_by,
- sample_sets=sample_sets,
- min_cohort_size=1, # Don't exclude any samples.
- site_mask=site_mask,
- )
-
- # Trim things down a bit for speed.
- ds = ds.isel(variants=slice(0, 100))
-
- # Extract cohorts into a DataFrame.
- cohort_vars = [v for v in ds if str(v).startswith("cohort_")]
- df_cohorts = ds[cohort_vars].to_dataframe()
-
- # Pick a random area and areas from valid areas.
- cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist()
- area = random.choice(cohorts_areas)
- areas = random.sample(cohorts_areas, random.randint(1, len(cohorts_areas)))
-
- # Plot with area.
- fig = api.plot_frequencies_time_series(ds, show=False, areas=area)
-
- # Test areas plot.
- assert isinstance(fig, go.Figure)
-
- # Plot with areas.
- fig = api.plot_frequencies_time_series(ds, show=False, areas=areas)
-
- # Test area plot.
- assert isinstance(fig, go.Figure)
-
-
-@parametrize_with_cases("fixture,api", cases=".")
-def test_plot_frequencies_interactive_map(
- fixture,
- api: AnophelesSnpFrequencyAnalysis,
-):
- import ipywidgets # type: ignore
-
- # Pick test parameters at random.
- all_sample_sets = api.sample_sets()["sample_set"].to_list()
- sample_sets = random.choice(all_sample_sets)
- site_mask = random.choice(api.site_mask_ids + (None,))
- min_cohort_size = random.randint(0, 2)
- transcript = random_transcript(api=api).name
- area_by = random.choice(["country", "admin1_iso", "admin2_name"])
- period_by = random.choice(["year", "quarter", "month"])
-
- # Compute SNP frequencies.
- ds = api.snp_allele_frequencies_advanced(
- transcript=transcript,
- area_by=area_by,
- period_by=period_by,
- sample_sets=sample_sets,
- min_cohort_size=min_cohort_size,
- site_mask=site_mask,
- )
-
- # Trim things down a bit for speed.
- ds = ds.isel(variants=slice(0, 100))
-
- # Plot.
- fig = api.plot_frequencies_interactive_map(ds)
-
- # Test.
- assert isinstance(fig, ipywidgets.Widget)
-
- # Compute amino acid change frequencies.
- ds = api.aa_allele_frequencies_advanced(
- transcript=transcript,
- area_by=area_by,
- period_by=period_by,
- sample_sets=sample_sets,
- min_cohort_size=min_cohort_size,
- )
-
- # Trim things down a bit for speed.
- ds = ds.isel(variants=slice(0, 100))
-
- # Plot.
- fig = api.plot_frequencies_interactive_map(ds)
-
- # Test.
- assert isinstance(fig, ipywidgets.Widget)