Captum_integration (#336)
* Captum integration for explainability
Fix device assignment in TabNetBackbone

* fixed a few bugs

* add test cases for captum

* added explainability notebook

* updated documentation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add captum library to requirements.txt and handle
captum import error

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove empty list from continuous_cols parameter
in test_captum_integration_regression

* fixed ruff line-too-long errors

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
manujosephv and pre-commit-ci[bot] authored Dec 8, 2023
1 parent 34bac20 commit 75976e1
Showing 8 changed files with 1,589 additions and 49 deletions.
43 changes: 43 additions & 0 deletions docs/explainability.md
@@ -0,0 +1,43 @@
The explainability features in PyTorch Tabular allow users to interpret and understand the predictions made by a tabular deep learning model. These features provide insight into the model's decision-making process and help identify the most influential features. Some of the explainability features are built into the models, and many others are based on the [Captum](https://captum.ai/) library.

## Native Feature Importance
One of the features of GBDT models that everybody loves is feature importance. It helps us understand which features are the most important for the model. PyTorch Tabular provides a similar feature for some of the models - GANDALF, GATE, and FTTransformer - which natively support the extraction of feature importance.

``` python
# tabular_model is the trained model of a supported model
tabular_model.feature_importance()
```
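
The returned importances can be ranked like any other tabular output. A minimal sketch, assuming `feature_importance()` returns a pandas DataFrame whose last column holds the importance scores (an assumption; inspect the actual columns in your version):

``` python
imp = tabular_model.feature_importance()
# Show the ten most important features.
# Sorting by the last column is an assumption about the schema.
print(imp.sort_values(imp.columns[-1], ascending=False).head(10))
```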

## Local Feature Attributions/Explanations
Local feature attributions/explanations help us understand the contribution of each feature towards the prediction for a particular sample. PyTorch Tabular provides this feature for all models except TabTransformer, TabNet, and Mixture Density Networks. It is based on the [Captum](https://captum.ai/) library, which provides many algorithms for computing feature attributions; PyTorch Tabular wraps them to make them easy to use. The following algorithms are supported:

- [GradientShap](https://captum.ai/api/gradient_shap.html)
- [IntegratedGradients](https://captum.ai/api/integrated_gradients.html)
- [DeepLift](https://captum.ai/api/deep_lift.html)
- [DeepLiftShap](https://captum.ai/api/deep_lift_shap.html)
- [InputXGradient](https://captum.ai/api/input_x_gradient.html)
- [FeaturePermutation](https://captum.ai/api/feature_permutation.html)
- [FeatureAblation](https://captum.ai/api/feature_ablation.html)
- [KernelShap](https://captum.ai/api/kernel_shap.html)

`PyTorch Tabular` supports explaining single instances as well as batches of instances, but larger batches will take longer to explain. Exceptions are the `FeaturePermutation` and `FeatureAblation` methods, which are only meaningful for large batches of instances.

Most of these explainability methods require a baseline. This is used to compare the attributions of the input with the attributions of the baseline. The baseline can be a scalar value, a tensor of the same shape as the input, or a special string like `"b|10000"`, which means 10000 samples from the training data. If the baseline is not provided, the default baseline (zero) is used.

``` python
# tabular_model is the trained model of a supported model

# Explain a single instance using the GradientShap method and baseline as 10000 samples from the training data
tabular_model.explain(test.head(1), method="GradientShap", baselines="b|10000")

# Explain a batch of instances using the IntegratedGradients method and baseline as 0
tabular_model.explain(test.head(10), method="IntegratedGradients", baselines=0)
```
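
The attributions returned for a batch can also be aggregated into a global view of feature influence. A minimal sketch, assuming `explain` returns a pandas DataFrame with one row per explained instance and one numeric column per feature:

``` python
# Rank features by mean absolute attribution across the batch.
attr = tabular_model.explain(test.head(100), method="GradientShap", baselines="b|1000")
global_importance = attr.abs().mean(axis=0).sort_values(ascending=False)
print(global_importance)
```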

Check out the [Captum documentation](https://captum.ai/docs/algorithms) for more details on the algorithms, and the [Explainability Tutorial](tutorials/11-Explainability.ipynb) for example usage.

## API Reference
::: pytorch_tabular.TabularModel.explain
options:
show_root_heading: yes
heading_level: 4
1,114 changes: 1,114 additions & 0 deletions docs/tutorials/11-Explainability.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions mkdocs.yml
@@ -13,6 +13,7 @@ nav:
- Self-Supervised Learning - Denoising Autoencoders: "tutorials/08-Self-Supervised Learning-DAE.ipynb"
- Low-level API Usecases: "tutorials/09-Low-level API Usecases.ipynb"
- Test Time Augmentation: "tutorials/10-Test Time Augmentation.ipynb"
- Explainability/Interpretability: "tutorials/11-Explainability.ipynb"
- Configuration:
- Data: data.md
- Supervised Models: models.md
@@ -22,6 +23,8 @@
- Experiment Tracking: experiment_tracking.md
- Tabular Model:
- TabularModel: tabular_model.md
- Explainability/Interpretability:
- Explainability/Interpretability: explainability.md
- Other Features:
- "Other Features": other_features.md
# - FAQ: faq.md
1 change: 1 addition & 0 deletions requirements/extra.txt
@@ -1,3 +1,4 @@
wandb >=0.15.0, <0.17.0
plotly>=5.13.0, <5.19.0
kaleido >=0.2.0, <0.3.0
captum >=0.5.0, <0.6.0
10 changes: 10 additions & 0 deletions src/pytorch_tabular/models/base_model.py
@@ -651,3 +651,13 @@ def _build_network(self):
# all components are initialized in the init function
self._backbone = self.kwargs.get("backbone")
self._head = self._get_head_from_config()


class _CaptumModel(nn.Module):
def __init__(self, model: BaseModel):
super().__init__()
self.model = model

def forward(self, x: Tensor):
x = self.model.compute_backbone(x)
return self.model.compute_head(x)["logits"]
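
This wrapper gives Captum what it expects: an `nn.Module` whose `forward` takes a single tensor and returns the logits. A hedged sketch of how an attribution method could be applied to it (assuming `x` is the already-embedded backbone input tensor; how it is assembled from the categorical and continuous columns is model-specific and not shown):

``` python
import torch
from captum.attr import IntegratedGradients

wrapped = _CaptumModel(model)  # model: a trained BaseModel subclass
ig = IntegratedGradients(wrapped)
# Zero baseline with the same shape as the (assumed) embedded input.
attributions = ig.attribute(x, baselines=torch.zeros_like(x))
```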
5 changes: 5 additions & 0 deletions src/pytorch_tabular/models/tabnet/tabnet_model.py
@@ -63,6 +63,11 @@ def unpack_input(self, x: Dict):
def forward(self, x: Dict):
# unpacking into a tuple
x = self.unpack_input(x)
# Move these two parameters to the same device as the input.
self.tabnet.embedder.embedding_group_matrix = self.tabnet.embedder.embedding_group_matrix.to(x.device)
self.tabnet.tabnet.encoder.group_attention_matrix = self.tabnet.tabnet.encoder.group_attention_matrix.to(
x.device
)
# Returns output and Masked Loss. We only need the output
x, _ = self.tabnet(x)
return x
