Skip to content

Commit

Permalink
fix: Cache all possible predictions in cache_predictions (#1098)
Browse files Browse the repository at this point in the history
fixes #1093 

Make sure that we cache all possible types of predictions when calling
`reporter.cache_predictions`. We verify this by checking that the number of
keys in the cache dictionary matches the expected count.
  • Loading branch information
glemaitre authored Jan 14, 2025
1 parent 8f1835e commit 0354774
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
12 changes: 6 additions & 6 deletions skore/src/skore/sklearn/_estimator/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ def cache_predictions(self, response_methods="auto", n_jobs=None):
if response_methods == "auto":
response_methods = ("predict",)
if hasattr(self._estimator, "predict_proba"):
response_methods = ("predict_proba",)
response_methods += ("predict_proba",)
if hasattr(self._estimator, "decision_function"):
response_methods = ("decision_function",)
response_methods += ("decision_function",)
pos_labels = self._estimator.classes_
else:
if response_methods == "auto":
Expand All @@ -175,8 +175,8 @@ def cache_predictions(self, response_methods="auto", n_jobs=None):
data_sources = ("test",)
Xs = (self._X_test,)
if self._X_train is not None:
data_sources = ("train",)
Xs = (self._X_train,)
data_sources += ("train",)
Xs += (self._X_train,)

parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator_unordered")
generator = parallel(
Expand All @@ -188,8 +188,8 @@ def cache_predictions(self, response_methods="auto", n_jobs=None):
pos_label=pos_label,
data_source=data_source,
)
for response_method, pos_label, data_source, X in product(
response_methods, pos_labels, data_sources, Xs
for response_method, pos_label, (data_source, X) in product(
response_methods, pos_labels, zip(data_sources, Xs)
)
)
# trigger the computation
Expand Down
26 changes: 21 additions & 5 deletions skore/tests/unit/sklearn/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,17 +289,33 @@ def test_estimator_report_repr(binary_classification_data):


@pytest.mark.parametrize(
"fixture_name", ["binary_classification_data", "regression_data"]
"fixture_name, pass_train_data, expected_n_keys",
[
("binary_classification_data", True, 6),
("binary_classification_data_svc", True, 6),
("multiclass_classification_data", True, 8),
("regression_data", True, 2),
("binary_classification_data", False, 3),
("binary_classification_data_svc", False, 3),
("multiclass_classification_data", False, 4),
("regression_data", False, 1),
],
)
def test_estimator_report_cache_predictions(request, fixture_name):
def test_estimator_report_cache_predictions(
request, fixture_name, pass_train_data, expected_n_keys
):
"""Check that calling cache_predictions fills the cache."""
estimator, X_test, y_test = request.getfixturevalue(fixture_name)
report = EstimatorReport(
estimator, X_train=X_test, y_train=y_test, X_test=X_test, y_test=y_test
)
if pass_train_data:
report = EstimatorReport(
estimator, X_train=X_test, y_train=y_test, X_test=X_test, y_test=y_test
)
else:
report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)

assert report._cache == {}
report.cache_predictions()
assert len(report._cache) == expected_n_keys
assert report._cache != {}
stored_cache = deepcopy(report._cache)
report.cache_predictions()
Expand Down

0 comments on commit 0354774

Please sign in to comment.