Skip to content

Commit

Permalink
move viz methods to seperate module
Browse files Browse the repository at this point in the history
  • Loading branch information
christinaexyou committed Feb 21, 2024
1 parent e40dc58 commit 0e95b69
Show file tree
Hide file tree
Showing 12 changed files with 536 additions and 447 deletions.
51 changes: 0 additions & 51 deletions src/trustyai/explainers/explanation_results.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
"""Generic class for Explanation and Saliency results"""
from abc import ABC, abstractmethod
from typing import Dict

import bokeh.models
import pandas as pd
from bokeh.io import show
from pandas.io.formats.style import Styler


Expand All @@ -27,51 +24,3 @@ class SaliencyResults(ExplanationResults):
@abstractmethod
def saliency_map(self):
"""Return the Saliencies as a dictionary, keyed by output name"""

@abstractmethod
def _matplotlib_plot(self, output_name: str, block: bool, call_show: bool) -> None:
"""Plot the saliencies of a particular output in matplotlib"""

@abstractmethod
def _get_bokeh_plot(self, output_name: str) -> bokeh.models.Plot:
"""Get a bokeh plot visualizing the saliencies of a particular output"""

def _get_bokeh_plot_dict(self) -> Dict[str, bokeh.models.Plot]:
"""Get a dictionary containing visualizations of the saliencies of all outputs,
keyed by output name"""
return {
output_name: self._get_bokeh_plot(output_name)
for output_name in self.saliency_map().keys()
}

def plot(
self, output_name=None, render_bokeh=False, block=True, call_show=True
) -> None:
"""
Plot the found feature saliencies.
Parameters
----------
output_name : str
(default= `None`) The name of the output to be explainer. If `None`, all outputs will
be displayed
render_bokeh : bool
(default= `False`) If true, render plot in bokeh, otherwise use matplotlib.
block: bool
(default= `True`) Whether displaying the plot blocks subsequent code execution
call_show: bool
(default= 'True') Whether plt.show() will be called by default at the end of the
plotting function. If `False`, the plot will be returned to the user for further
editing.
"""
if output_name is None:
for output_name_iterator in self.saliency_map().keys():
if render_bokeh:
show(self._get_bokeh_plot(output_name_iterator))
else:
self._matplotlib_plot(output_name_iterator, block, call_show)
else:
if render_bokeh:
show(self._get_bokeh_plot(output_name))
else:
self._matplotlib_plot(output_name, block, call_show)
105 changes: 1 addition & 104 deletions src/trustyai/explainers/lime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,12 @@
# pylint: disable = unused-argument, duplicate-code, consider-using-f-string, invalid-name
from typing import Dict, Union

import bokeh.models
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.plotting import figure
import pandas as pd
from matplotlib.colors import LinearSegmentedColormap

from trustyai import _default_initializer # pylint: disable=unused-import
from trustyai.utils._visualisation import (
DEFAULT_STYLE as ds,
DEFAULT_RC_PARAMS as drcp,
bold_red_html,
bold_green_html,
output_html,
feature_html,
)

from trustyai.utils._visualisation import DEFAULT_STYLE as ds
from trustyai.utils.data_conversions import (
OneInputUnionType,
data_conversion_docstring,
Expand Down Expand Up @@ -137,96 +124,6 @@ def as_html(self) -> pd.io.formats.style.Styler:
)
return htmls

def _matplotlib_plot(self, output_name: str, block=True, call_show=True) -> None:
"""Plot the LIME saliencies."""
with mpl.rc_context(drcp):
dictionary = {}
for feature_importance in (
self.saliency_map().get(output_name).getPerFeatureImportance()
):
dictionary[
feature_importance.getFeature().name
] = feature_importance.getScore()

colours = [
ds["negative_primary_colour"]
if i < 0
else ds["positive_primary_colour"]
for i in dictionary.values()
]
plt.title(f"LIME: Feature Importances to {output_name}")
plt.barh(
range(len(dictionary)),
dictionary.values(),
align="center",
color=colours,
)
plt.yticks(range(len(dictionary)), list(dictionary.keys()))
plt.tight_layout()

if call_show:
plt.show(block=block)

def _get_bokeh_plot(self, output_name) -> bokeh.models.Plot:
lime_data_source = pd.DataFrame(
[
{
"feature": str(pfi.getFeature().getName()),
"saliency": pfi.getScore(),
}
for pfi in self.saliency_map()[output_name].getPerFeatureImportance()
]
)
lime_data_source["color"] = lime_data_source["saliency"].apply(
lambda x: ds["positive_primary_colour"]
if x >= 0
else ds["negative_primary_colour"]
)
lime_data_source["saliency_colored"] = lime_data_source["saliency"].apply(
lambda x: (bold_green_html if x >= 0 else bold_red_html)("{:.2f}".format(x))
)

lime_data_source["color_faded"] = lime_data_source["saliency"].apply(
lambda x: ds["positive_primary_colour_faded"]
if x >= 0
else ds["negative_primary_colour_faded"]
)
source = ColumnDataSource(lime_data_source)
htool = HoverTool(
name="bars",
tooltips="<h3>LIME</h3> {} saliency to {}: @saliency_colored".format(
feature_html("@feature"), output_html(output_name)
),
)
bokeh_plot = figure(
sizing_mode="stretch_both",
title="Lime Feature Importances",
y_range=lime_data_source["feature"],
tools=[htool],
)
bokeh_plot.hbar(
y="feature",
left=0,
right="saliency",
fill_color="color_faded",
line_color="color",
hover_color="color",
color="color",
height=0.75,
name="bars",
source=source,
)
bokeh_plot.line([0, 0], [0, len(lime_data_source)], color="#000")
bokeh_plot.xaxis.axis_label = "Saliency Value"
bokeh_plot.yaxis.axis_label = "Feature"
return bokeh_plot

def _get_bokeh_plot_dict(self) -> Dict[str, bokeh.models.Plot]:
return {
output_name: self._get_bokeh_plot(output_name)
for output_name in self.saliency_map().keys()
}


class LimeExplainer:
"""*"Which features were most important to the results?"*
Expand Down
45 changes: 2 additions & 43 deletions src/trustyai/explainers/pdp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Explainers.pdp module"""

import math
import matplotlib.pyplot as plt
import pandas as pd
from pandas.io.formats.style import Styler

Expand Down Expand Up @@ -62,45 +60,6 @@ def as_html(self) -> Styler:
"""
return self.as_dataframe().style

def plot(self, output_name=None, block=True, call_show=True) -> None:
"""
Parameters
----------
output_name: str
name of the output to be plotted
Default to None
block: bool
whether the plotting operation
should be blocking or not
call_show: bool
(default= 'True') Whether plt.show() will be called by default at the end of
the plotting function. If `False`, the plot will be returned to the user for
further editing.
"""
fig, axs = plt.subplots(len(self.pdp_graphs), constrained_layout=True)
p_idx = 0
for pdp_graph in self.pdp_graphs:
if output_name is not None and output_name != str(
pdp_graph.getOutput().getName()
):
continue
fig.suptitle(str(pdp_graph.getOutput().getName()))
pdp_x = []
for i in range(len(pdp_graph.getX())):
pdp_x.append(self._to_plottable(pdp_graph.getX()[i]))
pdp_y = []
for i in range(len(pdp_graph.getY())):
pdp_y.append(self._to_plottable(pdp_graph.getY()[i]))
axs[p_idx].plot(pdp_x, pdp_y)
axs[p_idx].set_title(
str(pdp_graph.getFeature().getName()), loc="left", fontsize="small"
)
axs[p_idx].grid()
p_idx += 1
fig.supylabel("Partial Dependence Plot")
if call_show:
plt.show(block=block)

@staticmethod
def _to_plottable(datum: Value):
plottable = datum.asNumber()
Expand Down Expand Up @@ -187,12 +146,12 @@ def getInputShape(self):
"""
return self.data.sample()

# pylint: disable = invalid-name
# pylint: disable = invalid-name, missing-final-newline
@JOverride
def getOutputShape(self):
"""
Returns
--------
a PredictionOutput
"""
return self.pred_out
return self.pred_out
Loading

0 comments on commit 0e95b69

Please sign in to comment.