Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamp survival analysis interface #842

Merged
merged 27 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b1d36b8
cox_ph add all arguments
aGuyLearning Dec 18, 2024
35dbacf
updated test to use keywords
aGuyLearning Dec 18, 2024
22d190a
weibull_aft arguments update
aGuyLearning Dec 18, 2024
742d38c
log_logistic update
aGuyLearning Dec 18, 2024
02e343d
updated log logistic example
aGuyLearning Dec 18, 2024
6038c7a
Merge branch 'main' into enhancement/issue-840
Zethson Dec 20, 2024
8e1baa5
store summary df in adata.uns
aGuyLearning Jan 8, 2025
119947b
Merge branch 'main' into enhancement/issue-840
aGuyLearning Jan 8, 2025
eb9daba
try moving np
eroell Jan 8, 2025
e340a28
omit inplace keyword
aGuyLearning Jan 8, 2025
c6a81df
added explanation, as to where the results are stored
aGuyLearning Jan 8, 2025
38f4efb
corrected spelling
aGuyLearning Jan 8, 2025
501b864
updated tests to check for .uns ( should be removed later, when the u…
aGuyLearning Jan 8, 2025
eb0b404
fix argument order, doc fixes
eroell Jan 8, 2025
cffed4d
slightly simpler wording
eroell Jan 8, 2025
3b21988
fiexed spelling
aGuyLearning Jan 9, 2025
58ce157
Update ehrapy/tools/_sa.py
eroell Jan 9, 2025
ee97f31
Update ehrapy/tools/_sa.py
eroell Jan 9, 2025
6dc7831
Update ehrapy/tools/_sa.py
eroell Jan 9, 2025
540b79f
Update ehrapy/tools/_sa.py
eroell Jan 9, 2025
09484d9
Update ehrapy/tools/_sa.py
eroell Jan 9, 2025
96db288
Update ehrapy/tools/_sa.py
eroell Jan 9, 2025
568b84b
Update ehrapy/tools/_sa.py
eroell Jan 9, 2025
cf00a3f
renamed function to be clearer
aGuyLearning Jan 10, 2025
b574978
Add uns_key parameter to Kaplan-Meier, Nelson-Aalen, and Weibull func…
aGuyLearning Jan 10, 2025
ace1baf
Update test assertions in TestSA for event_table handling and pass ad…
aGuyLearning Jan 10, 2025
3de5678
uns to in doc
eroell Jan 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 231 additions & 20 deletions ehrapy/tools/_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import warnings
from typing import TYPE_CHECKING, Literal

import numpy as np # This package is implicitly used
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
Expand All @@ -23,6 +22,7 @@
if TYPE_CHECKING:
from collections.abc import Iterable

import numpy as np
from anndata import AnnData
from statsmodels.genmod.generalized_linear_model import GLMResultsWrapper

Expand Down Expand Up @@ -347,23 +347,43 @@ def anova_glm(result_1: GLMResultsWrapper, result_2: GLMResultsWrapper, formula_
return dataframe


def _regression_model(
model_class, adata: AnnData, duration_col: str, event_col: str, entry_col: str = None, accept_zero_duration=True
):
def _regression_model_data_frame_preparation(adata: AnnData, duration_col: str, accept_zero_duration: bool = True):
    """Prepare a DataFrame for lifelines regression models.

    Converts the AnnData object to a pandas DataFrame, drops rows containing
    missing values, and — when the downstream fitter cannot handle zero
    durations — nudges zero durations by a small epsilon so fitting succeeds.

    Args:
        adata: AnnData object containing the survival data.
        duration_col: Name of the column that contains the subjects' lifetimes.
        accept_zero_duration: Whether the downstream fitter accepts zero durations.
            If False, zero durations are replaced with 1e-5.

    Returns:
        A cleaned pandas DataFrame ready to be passed to a lifelines fitter.
    """
    df = anndata_to_df(adata)
    df = df.dropna()

    if not accept_zero_duration:
        # AFT-family fitters reject non-positive durations; shift exact zeros slightly.
        df.loc[df[duration_col] == 0, duration_col] += 1e-5

    return df


def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> CoxPHFitter:
def cox_ph(
adata: AnnData,
duration_col: str,
*,
inplace: bool = True,
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
uns_key: str = "cox_ph",
alpha: float = 0.05,
label: str | None = None,
eroell marked this conversation as resolved.
Show resolved Hide resolved
baseline_estimation_method: Literal["breslow", "spline", "piecewise"] = "breslow",
penalizer: float | np.ndarray = 0.0,
l1_ratio: float = 0.0,
strata: list[str] | str | None = None,
n_baseline_knots: int = 4,
knots: list[float] | None = None,
breakpoints: list[float] | None = None,
event_col: str = None,
weights_col: str | None = None,
cluster_col: str | None = None,
entry_col: str = None,
robust: bool = False,
formula: str = None,
batch_mode: bool = None,
show_progress: bool = False,
initial_point: np.ndarray | None = None,
fit_options: dict | None = None,
) -> CoxPHFitter:
"""Fit the Cox’s proportional hazard for the survival function.

The Cox proportional hazards model (CoxPH) examines the relationship between the survival time of subjects and one or more predictor variables.
Expand All @@ -376,7 +396,27 @@ def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = N
duration_col: The name of the column in the AnnData objects that contains the subjects’ lifetimes.
event_col: The name of the column in anndata that contains the subjects’ death observation.
If left as None, assume all individuals are uncensored.
inplace: Whether to modify the AnnData object in place.
uns_key: The key to use for the uns slot in the AnnData object.
alpha: The alpha value in the confidence intervals.
label: A string to name the column of the estimate.
eroell marked this conversation as resolved.
Show resolved Hide resolved
baseline_estimation_method: The method used to estimate the baseline hazard. Options are 'breslow', 'spline', and 'piecewise'.
penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates.
l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. See penalizer above.
strata: specify a list of columns to use in stratification. This is useful if a categorical covariate does not obey the proportional hazard assumption. This is used similar to the strata expression in R. See http://courses.washington.edu/b515/l17.pdf.
n_baseline_knots: Used when baseline_estimation_method="spline". Set the number of knots (interior & exterior) in the baseline hazard, which will be placed evenly along the time axis. Should be at least 2. Royston et. al, the authors of this model, suggest 4 to start, but any values between 2 and 8 are reasonable. If you need to customize the timestamps used to calculate the curve, use the knots parameter instead.
knots: When baseline_estimation_method="spline", this allows customizing the points in the time axis for the baseline hazard curve. To use evenly-spaced points in time, the n_baseline_knots parameter can be employed instead.
breakpoints: Used when baseline_estimation_method="piecewise". Set the positions of the baseline hazard breakpoints.
event_col: The name of the column in DataFrame that contains the subjects’ death observation. If left as None, assume all individuals are uncensored.
weights_col: The name of the column in DataFrame that contains the weights for each subject.
cluster_col: The name of the column in DataFrame that contains the cluster variable. Using this forces the sandwich estimator (robust variance estimator) to be used.
entry_col: Column denoting when a subject entered the study, i.e. left-truncation.
robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there are high number of ties, results may significantly differ.
formula: A Wilkinson formula, like in R and statsmodels, for the right-hand-side. If left as None, all columns not assigned as durations, weights, etc. are used. Uses the library Formulaic for parsing.
batch_mode: enabling batch_mode can be faster for datasets with a large number of ties. If left as None, lifelines will choose the best option.
eroell marked this conversation as resolved.
Show resolved Hide resolved
show_progress: since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing.
eroell marked this conversation as resolved.
Show resolved Hide resolved
initial_point: set the starting point for the iterative solver.
fit_options: Additional keyword arguments to pass into the estimator.

Returns:
Fitted CoxPHFitter.
Expand All @@ -388,10 +428,62 @@ def cox_ph(adata: AnnData, duration_col: str, event_col: str, entry_col: str = N
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> cph = ep.tl.cox_ph(adata, "mort_day_censored", "censor_flg")
"""
return _regression_model(CoxPHFitter, adata, duration_col, event_col, entry_col)
df = _regression_model_data_frame_preparation(adata, duration_col)
cox_ph = CoxPHFitter(
alpha=alpha,
label=label,
strata=strata,
baseline_estimation_method=baseline_estimation_method,
penalizer=penalizer,
l1_ratio=l1_ratio,
n_baseline_knots=n_baseline_knots,
knots=knots,
breakpoints=breakpoints,
)
cox_ph.fit(
df,
duration_col=duration_col,
event_col=event_col,
entry_col=entry_col,
robust=robust,
initial_point=initial_point,
weights_col=weights_col,
cluster_col=cluster_col,
batch_mode=batch_mode,
formula=formula,
fit_options=fit_options,
show_progress=show_progress,
)

# Add the results to the AnnData object
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
if inplace:
summary = cox_ph.summary
adata.uns[uns_key] = summary

def weibull_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> WeibullAFTFitter:
return cox_ph


def weibull_aft(
adata: AnnData,
duration_col: str,
*,
inplace: bool = True,
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
uns_key: str = "weibull_aft",
alpha: float = 0.05,
fit_intercept: bool = True,
penalizer: float | np.ndarray = 0.0,
l1_ratio: float = 0.0,
model_ancillary: bool = True,
event_col: str | None = None,
ancillary: bool | pd.DataFrame | str | None = None,
show_progress: bool = False,
weights_col: str | None = None,
robust: bool = False,
initial_point=None,
entry_col: str | None = None,
formula: str | None = None,
fit_options: dict | None = None,
) -> WeibullAFTFitter:
"""Fit the Weibull accelerated failure time regression for the survival function.

The Weibull Accelerated Failure Time (AFT) survival regression model is a statistical method used to analyze time-to-event data,
Expand All @@ -403,24 +495,95 @@ def weibull_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: st
Args:
adata: AnnData object with necessary columns `duration_col` and `event_col`.
duration_col: Name of the column in the AnnData objects that contains the subjects’ lifetimes.
event_col: Name of the column in anndata that contains the subjects’ death observation.
inplace: Whether to modify the AnnData object in place.
uns_key: The key to use for the uns slot in the AnnData object.
alpha: The alpha value in the confidence intervals.
fit_intercept: Whether to fit an intercept term in the model.
penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates.
l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. See penalizer above.
model_ancillary: set the model instance to always model the ancillary parameter with the supplied Dataframe. This is useful for grid-search optimization.
event_col: Name of the column in anndata that contains the subjects’ death observation. 1 if observed, 0 else (censored).
If left as None, assume all individuals are uncensored.
ancillary: Choose to model the ancillary parameters.
If None or False, explicitly do not fit the ancillary parameters using any covariates.
If True, model the ancillary parameters with the same covariates as ``df``.
If DataFrame, provide covariates to model the ancillary parameters. Must be the same row count as ``df``.
If str, should be a formula
show_progress: since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing.
weights_col: The name of the column in DataFrame that contains the weights for each subject.
robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there are high number of ties, results may significantly differ.
initial_point: set the starting point for the iterative solver.
entry_col: Column denoting when a subject entered the study, i.e. left-truncation.
formula: Use an R-style formula for modeling the dataset. See formula syntax: https://matthewwardrop.github.io/formulaic/basic/grammar/
If a formula is not provided, all variables in the dataframe are used (minus those used for other purposes like event_col, etc.)
fit_options: Additional keyword arguments to pass into the estimator.


Returns:
Fitted WeibullAFTFitter.

Examples:
>>> import ehrapy as ep
>>> adata = ep.dt.mimic_2(encoded=False)
>>> # Flip 'censor_fl' because 0 = death and 1 = censored
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> aft = ep.tl.weibull_aft(adata, "mort_day_censored", "censor_flg")
>>> adata = adata[:, ["mort_day_censored", "censor_flg"]]
>>> aft = ep.tl.weibull_aft(adata, duration_col="mort_day_censored", event_col="censor_flg")
>>> aft.print_summary()
"""
return _regression_model(WeibullAFTFitter, adata, duration_col, event_col, entry_col, accept_zero_duration=False)

df = _regression_model_data_frame_preparation(adata, duration_col, accept_zero_duration=False)

def log_logistic_aft(adata: AnnData, duration_col: str, event_col: str, entry_col: str = None) -> LogLogisticAFTFitter:
weibull_aft = WeibullAFTFitter(
alpha=alpha,
fit_intercept=fit_intercept,
penalizer=penalizer,
l1_ratio=l1_ratio,
model_ancillary=model_ancillary,
)

weibull_aft.fit(
df,
duration_col=duration_col,
event_col=event_col,
entry_col=entry_col,
ancillary=ancillary,
show_progress=show_progress,
weights_col=weights_col,
robust=robust,
initial_point=initial_point,
formula=formula,
fit_options=fit_options,
)

# Add the results to the AnnData object
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
if inplace:
summary = weibull_aft.summary
adata.uns[uns_key] = summary

return weibull_aft


def log_logistic_aft(
adata: AnnData,
duration_col: str,
*,
inplace: bool = True,
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
uns_key: str = "log_logistic_aft",
alpha: float = 0.05,
fit_intercept: bool = True,
penalizer: float | np.ndarray = 0.0,
l1_ratio: float = 0.0,
model_ancillary: bool = False,
event_col: str | None = None,
ancillary: bool | pd.DataFrame | str | None = None,
show_progress: bool = False,
weights_col: str | None = None,
robust: bool = False,
initial_point=None,
entry_col: str | None = None,
formula: str | None = None,
fit_options: dict | None = None,
) -> LogLogisticAFTFitter:
"""Fit the log logistic accelerated failure time regression for the survival function.
The Log-Logistic Accelerated Failure Time (AFT) survival regression model is a powerful statistical tool employed in the analysis of time-to-event data.
This model operates under the assumption that the logarithm of survival time adheres to a log-logistic distribution, offering a flexible framework for understanding the impact of covariates on survival times.
Expand All @@ -431,9 +594,29 @@ def log_logistic_aft(adata: AnnData, duration_col: str, event_col: str, entry_co
Args:
adata: AnnData object with necessary columns `duration_col` and `event_col`.
duration_col: Name of the column in the AnnData objects that contains the subjects’ lifetimes.
event_col: Name of the column in anndata that contains the subjects’ death observation.
inplace: Whether to modify the AnnData object in place.
uns_key: The key to use for the uns slot in the AnnData object.
alpha: The alpha value in the confidence intervals.
fit_intercept: Whether to fit an intercept term in the model.
penalizer: Attach a penalty to the size of the coefficients during regression. This improves stability of the estimates and controls for high correlation between covariates.
l1_ratio: Specify what ratio to assign to a L1 vs L2 penalty. Same as scikit-learn. See penalizer above.
model_ancillary: set the model instance to always model the ancillary parameter with the supplied Dataframe. This is useful for grid-search optimization.
eroell marked this conversation as resolved.
Show resolved Hide resolved
event_col: Name of the column in anndata that contains the subjects’ death observation. 1 if observed, 0 else (censored).
If left as None, assume all individuals are uncensored.
ancillary: Choose to model the ancillary parameters.
If None or False, explicitly do not fit the ancillary parameters using any covariates.
If True, model the ancillary parameters with the same covariates as ``df``.
If DataFrame, provide covariates to model the ancillary parameters. Must be the same row count as ``df``.
If str, should be a formula
show_progress: since the fitter is iterative, show convergence diagnostics. Useful if convergence is failing.
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
weights_col: The name of the column in DataFrame that contains the weights for each subject.
robust: Compute the robust errors using the Huber sandwich estimator, aka Wei-Lin estimate. This does not handle ties, so if there are high number of ties, results may significantly differ.
initial_point: set the starting point for the iterative solver.
entry_col: Column denoting when a subject entered the study, i.e. left-truncation.
formula: Use an R-style formula for modeling the dataset. See formula syntax: https://matthewwardrop.github.io/formulaic/basic/grammar/
If a formula is not provided, all variables in the dataframe are used (minus those used for other purposes like event_col, etc.)
fit_options: Additional keyword arguments to pass into the estimator.

Returns:
Fitted LogLogisticAFTFitter.
Expand All @@ -443,12 +626,40 @@ def log_logistic_aft(adata: AnnData, duration_col: str, event_col: str, entry_co
>>> adata = ep.dt.mimic_2(encoded=False)
>>> # Flip 'censor_fl' because 0 = death and 1 = censored
>>> adata[:, ["censor_flg"]].X = np.where(adata[:, ["censor_flg"]].X == 0, 1, 0)
>>> llf = ep.tl.log_logistic_aft(adata, "mort_day_censored", "censor_flg")
>>> adata = adata[:, ["mort_day_censored", "censor_flg"]]
>>> llf = ep.tl.log_logistic_aft(adata, duration_col="mort_day_censored", event_col="censor_flg")
"""
return _regression_model(
LogLogisticAFTFitter, adata, duration_col, event_col, entry_col, accept_zero_duration=False
df = _regression_model_data_frame_preparation(adata, duration_col, accept_zero_duration=False)

log_logistic_aft = LogLogisticAFTFitter(
alpha=alpha,
fit_intercept=fit_intercept,
penalizer=penalizer,
l1_ratio=l1_ratio,
model_ancillary=model_ancillary,
)

log_logistic_aft.fit(
df,
duration_col=duration_col,
event_col=event_col,
entry_col=entry_col,
ancillary=ancillary,
show_progress=show_progress,
weights_col=weights_col,
robust=robust,
initial_point=initial_point,
formula=formula,
fit_options=fit_options,
)

# Add the results to the AnnData object
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
if inplace:
summary = log_logistic_aft.summary
adata.uns[uns_key] = summary

return log_logistic_aft


def _univariate_model(
adata: AnnData,
Expand Down
2 changes: 1 addition & 1 deletion tests/tools/test_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _sa_function_assert(self, model, model_class):
def _sa_func_test(self, sa_function, sa_class, mimic_2_sa):
aGuyLearning marked this conversation as resolved.
Show resolved Hide resolved
adata, duration_col, event_col = mimic_2_sa

sa = sa_function(adata, duration_col, event_col)
sa = sa_function(adata, duration_col=duration_col, event_col=event_col)
self._sa_function_assert(sa, sa_class)

def test_kmf(self, mimic_2_sa):
Expand Down
Loading