diff --git a/skore/src/skore/sklearn/_estimator/report.py b/skore/src/skore/sklearn/_estimator/report.py
index bd9dbe6d1..745d73fbc 100644
--- a/skore/src/skore/sklearn/_estimator/report.py
+++ b/skore/src/skore/sklearn/_estimator/report.py
@@ -163,9 +163,9 @@ def cache_predictions(self, response_methods="auto", n_jobs=None):
             if response_methods == "auto":
                 response_methods = ("predict",)
                 if hasattr(self._estimator, "predict_proba"):
-                    response_methods = ("predict_proba",)
+                    response_methods += ("predict_proba",)
                 if hasattr(self._estimator, "decision_function"):
-                    response_methods = ("decision_function",)
+                    response_methods += ("decision_function",)
             pos_labels = self._estimator.classes_
         else:
             if response_methods == "auto":
@@ -175,8 +175,8 @@ def cache_predictions(self, response_methods="auto", n_jobs=None):
         data_sources = ("test",)
         Xs = (self._X_test,)
         if self._X_train is not None:
-            data_sources = ("train",)
-            Xs = (self._X_train,)
+            data_sources += ("train",)
+            Xs += (self._X_train,)
 
         parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator_unordered")
         generator = parallel(
@@ -188,8 +188,8 @@ def cache_predictions(self, response_methods="auto", n_jobs=None):
                 pos_label=pos_label,
                 data_source=data_source,
             )
-            for response_method, pos_label, data_source, X in product(
-                response_methods, pos_labels, data_sources, Xs
+            for response_method, pos_label, (data_source, X) in product(
+                response_methods, pos_labels, zip(data_sources, Xs)
             )
         )
         # trigger the computation
diff --git a/skore/tests/unit/sklearn/test_estimator.py b/skore/tests/unit/sklearn/test_estimator.py
index b8f9b3633..4e0e6e4c2 100644
--- a/skore/tests/unit/sklearn/test_estimator.py
+++ b/skore/tests/unit/sklearn/test_estimator.py
@@ -289,17 +289,33 @@ def test_estimator_report_repr(binary_classification_data):
 
 
 @pytest.mark.parametrize(
-    "fixture_name", ["binary_classification_data", "regression_data"]
+    "fixture_name, pass_train_data, expected_n_keys",
+    [
+        ("binary_classification_data", True, 6),
+        ("binary_classification_data_svc", True, 6),
+        ("multiclass_classification_data", True, 8),
+        ("regression_data", True, 2),
+        ("binary_classification_data", False, 3),
+        ("binary_classification_data_svc", False, 3),
+        ("multiclass_classification_data", False, 4),
+        ("regression_data", False, 1),
+    ],
 )
-def test_estimator_report_cache_predictions(request, fixture_name):
+def test_estimator_report_cache_predictions(
+    request, fixture_name, pass_train_data, expected_n_keys
+):
     """Check that calling cache_predictions fills the cache."""
     estimator, X_test, y_test = request.getfixturevalue(fixture_name)
-    report = EstimatorReport(
-        estimator, X_train=X_test, y_train=y_test, X_test=X_test, y_test=y_test
-    )
+    if pass_train_data:
+        report = EstimatorReport(
+            estimator, X_train=X_test, y_train=y_test, X_test=X_test, y_test=y_test
+        )
+    else:
+        report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
 
     assert report._cache == {}
     report.cache_predictions()
+    assert len(report._cache) == expected_n_keys
     assert report._cache != {}
     stored_cache = deepcopy(report._cache)
     report.cache_predictions()
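
The report.py hunks fix two related bugs in cache_predictions: the tuples of response methods and data sources were overwritten instead of extended (hence the switch to +=), and the cartesian product crossed every data_source with every X, pairing "test" with the training data and vice versa. Replacing the last two product arguments with a single zip(data_sources, Xs) keeps each label attached to its own array. Below is a minimal standalone sketch of that second fix; the names are stand-ins for illustration, not skore's internals.

# Illustrative sketch only (stand-in names, not skore code): why
# product(..., zip(data_sources, Xs)) is the correct pairing.
from itertools import product

response_methods = ("predict", "predict_proba")
pos_labels = (0, 1)
data_sources = ("test", "train")
Xs = ("X_test", "X_train")  # placeholders for the real arrays

# Before the fix: the full cartesian product also pairs "test" with X_train
# and "train" with X_test -- twice the work, and mislabelled cache entries.
old = list(product(response_methods, pos_labels, data_sources, Xs))
assert len(old) == 16

# After the fix: zip keeps each data source glued to its own X, so only
# valid (data_source, X) pairs enter the product.
new = list(product(response_methods, pos_labels, zip(data_sources, Xs)))
assert len(new) == 8
assert all(
    (ds, x) in {("test", "X_test"), ("train", "X_train")}
    for _, _, (ds, x) in new
)

The parametrized expected_n_keys values in the test follow the same arithmetic: number of response methods x number of pos_labels x number of (data_source, X) pairs actually available.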