diff --git a/tests/insights/test_shap.py b/tests/insights/test_shap.py index 0fc3eabef..4d231999b 100644 --- a/tests/insights/test_shap.py +++ b/tests/insights/test_shap.py @@ -1,6 +1,7 @@ """Tests for insights subpackage.""" import inspect +from contextlib import nullcontext from unittest import mock import numpy as np @@ -80,26 +81,35 @@ def test_non_shap_signature(explainer_name): def _test_shap_insight(campaign, explainer_cls, use_comp_rep, is_shap): """Helper function for general SHAP explainer tests.""" - try: + context = nullcontext() + if ( + (not use_comp_rep) + and (explainer_cls != "KernelExplainer") + and any(not p.is_numerical for p in campaign.parameters) + ): + # We expect a validation error in case an explanation with an unsupported + # explainer type is attempted on a search space representation with + # non-numerical entries + context = pytest.raises(IncompatibleExplainerError) + + with context: shap_insight = SHAPInsight.from_campaign( campaign, explainer_cls=explainer_cls, use_comp_rep=use_comp_rep, ) - except IncompatibleExplainerError: - pytest.skip("Unsupported model/explainer combination.") - - # Sanity check explainer - assert isinstance(shap_insight, insights.SHAPInsight) - assert isinstance(shap_insight.explainer, _get_explainer_cls(explainer_cls)) - assert shap_insight.uses_shap_explainer == is_shap - # Sanity check explanation - df = campaign.measurements[[p.name for p in campaign.parameters]] - if use_comp_rep: - df = campaign.searchspace.transform(df) - shap_explanation = shap_insight.explain(df) - assert isinstance(shap_explanation, shap.Explanation) + # Sanity check explainer + assert isinstance(shap_insight, insights.SHAPInsight) + assert isinstance(shap_insight.explainer, _get_explainer_cls(explainer_cls)) + assert shap_insight.uses_shap_explainer == is_shap + + # Sanity check explanation + df = campaign.measurements[[p.name for p in campaign.parameters]] + if use_comp_rep: + df = campaign.searchspace.transform(df) + shap_explanation = shap_insight.explain(df) + assert isinstance(shap_explanation, shap.Explanation) @mark.slow