Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bias detection to preprocessing #690

Merged
merged 35 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
744216d
Added correlation calculation
Lilly-May Apr 11, 2024
a1e6b2a
Standard. Mean Differences
Lilly-May Apr 11, 2024
68b1104
Added feature importances
Lilly-May Apr 12, 2024
0536586
Doc string improvements
Lilly-May Apr 13, 2024
c41ad45
Added correlations parameter
Lilly-May Apr 13, 2024
7233f96
Merge branch 'main' into feature/bias_detection
Zethson Apr 14, 2024
97b004b
PR Revisions
Lilly-May Apr 15, 2024
778c0c3
Added categorical value count calculation
Lilly-May Apr 15, 2024
7ad07ec
Added first test
Lilly-May Apr 16, 2024
7d483a3
docs clarifications
Lilly-May Apr 16, 2024
138860b
Test improvements
Lilly-May Apr 16, 2024
22f45ef
Merge branch 'main' into feature/bias_detection
Lilly-May Apr 25, 2024
c0bdcb1
Incorporate feature type detection
Lilly-May Apr 25, 2024
031808d
Finished tests
Lilly-May Apr 25, 2024
a863306
SMD improvements
Lilly-May Apr 25, 2024
cd44284
Merge branch 'main' into feature/bias_detection
Zethson Apr 25, 2024
efe6885
Merge branch 'main' into feature/bias_detection
Zethson Apr 25, 2024
eea9772
Test fixes
Lilly-May May 1, 2024
2895b41
Merge remote-tracking branch 'origin/feature/bias_detection' into fea…
Lilly-May May 1, 2024
ed9d8be
Merge branch 'main' into feature/bias_detection
Lilly-May May 1, 2024
5347c35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 1, 2024
f1f4b4d
Save SMD in uns subdict
Lilly-May May 1, 2024
6688bf6
Fix tests and silence test warnings
Lilly-May May 1, 2024
381b8b1
Introduced copy parameter
Lilly-May May 1, 2024
3ff2c65
Added encoding check
Lilly-May May 1, 2024
bcfe3a4
Fixed sensitive_features dtype
Lilly-May May 1, 2024
b11f5ea
Feature importances return docstring
Lilly-May May 1, 2024
e1aaaae
Improved docs explanations
Lilly-May May 2, 2024
c1d3916
Sort feature importances results
Lilly-May May 2, 2024
f2d11f8
Apply suggestions from code review
Lilly-May May 2, 2024
2e8d630
Review comments
Lilly-May May 2, 2024
9d8b74e
doc formating
eroell May 3, 2024
daef606
Apply suggestions from code review
Lilly-May May 4, 2024
c22ee85
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 4, 2024
5ec7f8a
Fixed error raising
Lilly-May May 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ehrapy/preprocessing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ehrapy.preprocessing._bias import bias_detection
from ehrapy.preprocessing._encoding import encode, undo_encoding
from ehrapy.preprocessing._highly_variable_features import highly_variable_features
from ehrapy.preprocessing._imputation import (
Expand Down
119 changes: 119 additions & 0 deletions ehrapy/preprocessing/_bias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from collections.abc import Iterable
from typing import Literal

import numpy as np
import pandas as pd
from anndata import AnnData

from ehrapy import logging as logg
from ehrapy.anndata import anndata_to_df


def bias_detection(
adata: AnnData,
sensitive_features: Iterable[str] | Literal["all"],
Lilly-May marked this conversation as resolved.
Show resolved Hide resolved
corr_threshold: float = 0.5,
smd_threshold: float = 0.5,
feature_importance_threshold: float = 0.1,
prediction_confidence_threshold: float = 0.5,
corr_method: Literal["pearson", "spearman"] = "spearman",
):
"""Detects bias in the data.
Lilly-May marked this conversation as resolved.
Show resolved Hide resolved

Args:
adata: An annotated data matrix containing patient data.
Lilly-May marked this conversation as resolved.
Show resolved Hide resolved
sensitive_features: A list of sensitive features to check for bias.
Lilly-May marked this conversation as resolved.
Show resolved Hide resolved
corr_threshold: The threshold for the correlation coefficient between two features to be considered of interest. Defaults to 0.5.
smd_threshold: The threshold for the standardized mean difference between two features to be considered of interest. Defaults to 0.5.
feature_importance_threshold: The threshold for the feature importance of a sensitive feature for predicting another feature to be considered
of interest. Defaults to 0.1.
prediction_confidence_threshold: The threshold for the prediction confidence (R2 or accuracy) of a sensitive feature for predicting another
feature to be considered of interest. Defaults to 0.5.
corr_method: The correlation method to use. Choose between "pearson" and "spearman". Defaults to "spearman".
Lilly-May marked this conversation as resolved.
Show resolved Hide resolved
"""
from ehrapy.tools import rank_features_supervised
Lilly-May marked this conversation as resolved.
Show resolved Hide resolved

if sensitive_features == "all":
sensitive_features = adata.var_names

correlations = _feature_correlations(adata, method=corr_method)
adata.varp["correlation"] = correlations
Lilly-May marked this conversation as resolved.
Show resolved Hide resolved

for feature in sensitive_features:
for comp_feature in adata.var_names:
if correlations.loc[feature, comp_feature] > corr_threshold:
Lilly-May marked this conversation as resolved.
Show resolved Hide resolved
logg.warning(
f"Feature {feature} is highly correlated with {comp_feature} (correlation coefficient ≈{correlations.loc[feature, comp_feature]:.3f})."
) # TODO: How do we print results?

smd_dict = _standardized_mean_differences(adata, sensitive_features)
for feature in sensitive_features:
abs_smd = smd_dict[feature].abs()
for comp_feature in adata.var_names:
if abs_smd[comp_feature].max() > smd_threshold:
logg.warning(
f"Feature {comp_feature} has a high standardized mean difference with {feature}."
) # TODO: Do we look at / print groups individually?

for prediction_feature in adata.var_names:
prediction_score = rank_features_supervised(
adata,
prediction_feature,
input_features="all",
model="rf",
key_added=f"{prediction_feature}_feature_importances",
percent_output=True,
logging=False,
return_score=True,
)
for feature in sensitive_features:
feature_importance = adata.var[f"{prediction_feature}_feature_importances"][feature] / 100
if feature_importance > feature_importance_threshold and prediction_score > prediction_confidence_threshold:
logg.warning(
f"Feature {feature} has a high feature importance for predicting {prediction_feature} (importance in %: {feature_importance:.3f}, prediction score: {prediction_score:.3f})."
)


def _feature_correlations(adata: AnnData, method: Literal["pearson", "spearman"] = "spearman"):
Lilly-May marked this conversation as resolved.
Show resolved Hide resolved
"""Computes pairwise correlations between features in the AnnData object.

Args:
adata: An annotated data matrix containing patient data.
method: The correlation method to use. Choose between "pearson" and "spearman". Defaults to "spearman".

Returns:
A pandas DataFrame containing the correlation matrix.
"""
corr_matrix = anndata_to_df(adata).corr(method=method)
Lilly-May marked this conversation as resolved.
Show resolved Hide resolved
return corr_matrix


def _standardized_mean_differences(adata: AnnData, features: Iterable[str]) -> dict:
Lilly-May marked this conversation as resolved.
Show resolved Hide resolved
"""Computes the standardized mean differences between sensitive features.

Args:
adata: An annotated data matrix containing patient data.
features: A list of features to compute the standardized mean differences (SMD) for. For each listed feature, the SMD is computed for each
Lilly-May marked this conversation as resolved.
Show resolved Hide resolved
feature, comparing one group to the rest. Thus, we obtain a n_groups_in_feature x n_features matrix of SMDs for each listed feature.

Returns:
A dictionary mapping each feature to a pandas DataFrame containing the standardized mean differences.
"""
df = anndata_to_df(adata)
smd_results = {} # type: ignore

for group_feature in features: # TODO: Restrict to categorical features (wait for other PR)
smd_results[group_feature] = {}
for group in df[group_feature].unique():
group_mean = df[df[group_feature] == group].mean()
group_std = df[df[group_feature] == group].std()

comparison_mean = df[df[group_feature] != group].mean()
comparison_std = df[df[group_feature] != group].std()

smd = (group_mean - comparison_mean) / np.sqrt((group_std**2 + comparison_std**2) / 2)
smd_results[group_feature][group] = smd

smd_results[group_feature] = pd.DataFrame(smd_results[group_feature]).T[adata.var_names]

return smd_results
2 changes: 1 addition & 1 deletion ehrapy/preprocessing/_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def knn_impute(
imputation ran successfully.

Args:
adata: An annotated data matrix containing gene expression values.
adata: An annotated data matrix containing patient data.
Lilly-May marked this conversation as resolved.
Show resolved Hide resolved
var_names: A list of variable names indicating which columns to impute.
If `None`, all columns are imputed. Default is `None`.
n_neighbours: Number of neighbors to use when performing the imputation. Defaults to 5.
Expand Down
24 changes: 16 additions & 8 deletions ehrapy/tools/feature_ranking/_feature_importances.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@ def rank_features_supervised(
adata: AnnData,
predicted_feature: str,
prediction_type: Literal["continuous", "categorical", "auto"] = "auto",
model: Literal["regression", "svm", "rf"] = "regression",
model: Literal["regression", "svm", "rf"] = "rf",
input_features: Iterable[str] | Literal["all"] = "all",
layer: str | None = None,
test_split_size: float = 0.2,
key_added: str = "feature_importances",
feature_scaling: Literal["standard", "minmax"] | None = "standard",
percent_output: bool = False,
logging: bool = True,
return_score: bool = False,
**kwargs,
):
) -> float | None:
"""Calculate feature importances for predicting a specified feature in adata.var.

Args:
Expand All @@ -49,6 +51,8 @@ def rank_features_supervised(
for each feature individually. Defaults to 'standard'.
percent_output: Set to True to output the feature importances as percentages. Note that information about positive or negative
coefficients for regression models will be lost. Defaults to False.
logging: Set to False to disable logging. Defaults to True.
Lilly-May marked this conversation as resolved.
Show resolved Hide resolved
return_score: Set to True to return the R2 score / the accuracy of the model. Defaults to False.
**kwargs: Additional keyword arguments to pass to the model. See the documentation of the respective model in scikit-learn for details.

Examples:
Expand Down Expand Up @@ -92,9 +96,10 @@ def rank_features_supervised(
prediction_type = "categorical"
else:
prediction_type = "continuous"
logg.info(
f"Predicted feature {predicted_feature} was detected as {prediction_type}. If this is incorrect, please specify in the prediction_type argument."
)
if logging:
logg.info(
f"Predicted feature {predicted_feature} was detected as {prediction_type}. If this is incorrect, please specify in the prediction_type argument."
)

elif prediction_type == "continuous":
if pd.api.types.is_categorical_dtype(data[predicted_feature].dtype):
Expand Down Expand Up @@ -167,9 +172,10 @@ def rank_features_supervised(

score = predictor.score(x_test, y_test)
evaluation_metric = "R2 score" if prediction_type == "continuous" else "accuracy"
logg.info(
f"Training completed. The model achieved an {evaluation_metric} of {score:.2f} on the test set, consisting of {len(y_test)} samples."
)
if logging:
logg.info(
f"Training completed. The model achieved an {evaluation_metric} of {score:.2f} on the test set, consisting of {len(y_test)} samples."
)

if model == "regression" or model == "svm":
feature_importances = pd.Series(predictor.coef_.squeeze(), index=input_data.columns)
Expand All @@ -182,3 +188,5 @@ def rank_features_supervised(
# Reorder feature importances to match adata.var order and save importances in adata.var
feature_importances = feature_importances.reindex(adata.var_names)
adata.var[key_added] = feature_importances

return score if return_score else None
Loading