"""
SHAP Visualization Utilities
============================
Overview
--------
This module provides a set of clean, publication-ready visualization utilities for comparing
SHAP attributions against ground-truth or across multiple explainers. The visualizations support
both **sequential** (2D) and **tabular** (1D) input formats and offer multiple views such as bar plots,
3D surface plots, and 3D bar plots.
The focus is on **clarity**, **aesthetics**, and **comparability**, making the tools well-suited
for research papers, presentations, and internal model audits.
Key Functions
^^^^^^^^^^^^^
- **plot_mse_pearson**:
  Bar chart comparing MSE and Pearson correlation of each explainer vs ground-truth SHAP.
- **plot_3d_surface**:
  Side-by-side 3D surface plots for ground-truth and predicted SHAP values over time and features.
- **plot_3d_bars**:
  Paired 3D bar plots for SHAP values, visually appealing and easy to compare height/direction.
- **plot_feature_comparison**:
  Side-by-side bar plots for SHAP values from different explainers (1D/tabular inputs only).
Customization
-------------
- **Color maps**: `plot_mse_pearson` and `plot_3d_surface` accept a custom colormap via `cmap` (default: `'viridis'`); `plot_3d_bars` takes a single `bar_color` instead.
- **Saving**: All plots can be saved using the `save` argument (PDF/PNG via `matplotlib`).
- **Interactivity**: Plots are shown by default, but this can be toggled with `show=False`.
Use Case
--------
These tools are especially useful for:
- Benchmarking explainers on synthetic datasets.
- Visualizing time-series explanations (e.g., SHAP over `(T, F)` inputs).
- Comparing surrogate vs exact SHAP explainers.
- Producing clean visuals for publications and reports.
Example
-------
.. code-block:: python

    plot_mse_pearson(mse_dict, pearson_dict, save="comparison.pdf")
    plot_3d_surface(gt_shap, shap_outputs, seq_len=10, n_features=5)
    plot_feature_comparison(gt_tab, shap_dict_tabular)
"""
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
# Public API of this module: the four plotting entry points defined below.
# These are the only names exported by a star-import of the module.
__all__ = [
"plot_mse_pearson",
"plot_3d_surface",
"plot_3d_bars",
"plot_feature_comparison",
]
def plot_mse_pearson(
    results,
    pearson_results,
    save=None,
    bar_threshold=4,
    cmap="viridis",
    show=True,
):
    r"""
    Plot MSE and Pearson correlation of each SHAP explainer against ground truth.

    Draws two bar charts side by side: MSE on the left, Pearson correlation on
    the right. Bars are vertical while the larger of the two dictionaries has at
    most ``bar_threshold`` entries and horizontal otherwise. Every bar is
    annotated with its value (3 significant digits) and coloured by its absolute
    value; both panels share one colour normalisation that spans the smallest to
    the largest absolute value found in either dictionary.

    The plotting style is applied only for the duration of the call, so the
    caller's global matplotlib style is left untouched.

    :param dict results: Dictionary mapping explainer names to MSE values.
    :param dict pearson_results: Dictionary mapping explainer names to Pearson correlation scores.
    :param str save: Optional filename to save the figure (written at 300 dpi).
    :param int bar_threshold: Orientation switches to horizontal if number of bars exceeds this threshold.
    :param str cmap: Colormap name for value encoding (default: "viridis").
    :param bool show: Whether to display the plot. When False the figure is closed after optional saving.
    :raises ValueError: If ``results`` or ``pearson_results`` is empty.
    """
    # Fail early with a readable message instead of numpy's
    # "zero-size array to reduction operation" error further down.
    if not results or not pearson_results:
        raise ValueError(
            "plot_mse_pearson needs at least one entry in both `results` and "
            "`pearson_results`."
        )

    mse_vals = np.array(list(results.values()))
    pearson_vals = np.array(list(pearson_results.values()))
    vertical = max(len(results), len(pearson_results)) <= bar_threshold

    # A single normalisation is shared by both panels so that colours are
    # comparable across them.
    vmax = max(np.abs(mse_vals).max(), np.abs(pearson_vals).max())
    vmin = min(np.abs(mse_vals).min(), np.abs(pearson_vals).min())
    norm = plt.Normalize(vmin, vmax)
    colormap = plt.get_cmap(cmap)

    panels = (
        (mse_vals, list(results), "MSE (Explainer vs Ground Truth)", "MSE"),
        (
            pearson_vals,
            list(pearson_results),
            "Pearson Correlation (Explainer vs Ground Truth)",
            "Pearson Correlation",
        ),
    )

    # Scope the style to this figure instead of mutating global rcParams.
    with plt.style.context("seaborn-v0_8-whitegrid"):
        figsize = (9, 4.5) if vertical else (11, 5.5)
        fig, axes = plt.subplots(1, 2, figsize=figsize)

        for ax, (vals, names, title, value_label) in zip(axes, panels):
            colors = colormap(norm(np.abs(vals)))
            if vertical:
                bars = ax.bar(
                    names, vals, color=colors, edgecolor="#222", alpha=0.88, zorder=2
                )
                ax.set_ylabel(value_label, fontsize=12)
                ax.set_xlabel("Explainer", fontsize=11)
            else:
                bars = ax.barh(
                    names, vals, color=colors, edgecolor="#222", alpha=0.88, zorder=2
                )
                ax.set_xlabel(value_label, fontsize=12)
                ax.set_ylabel("Explainer", fontsize=11)
            _annotate_metric_bars(ax, bars, vertical)
            # Grid lines only along the value axis, drawn behind the bars.
            ax.grid(
                axis="y" if vertical else "x", linestyle="--", alpha=0.7, zorder=0
            )
            ax.set_title(title, fontsize=13, fontweight="bold", pad=6)
            ax.set_facecolor("#fafafa")
            ax.tick_params(axis="both", labelsize=9)

        fig.tight_layout(pad=1.2)
        if save:
            fig.savefig(save, bbox_inches="tight", pad_inches=0.02, dpi=300)
            print(f"Figure saved to {save}")
        if show:
            plt.show()
        else:
            plt.close(fig)


def _annotate_metric_bars(ax, bars, vertical):
    """Write each bar's value (3 significant digits) just beyond its tip."""
    for bar in bars:
        if vertical:
            value = bar.get_height()
            ax.annotate(
                f"{value:.3g}",
                xy=(bar.get_x() + bar.get_width() / 2, value),
                # Negative bars get their label below the tip.
                xytext=(0, 4 if value >= 0 else -14),
                textcoords="offset points",
                ha="center",
                va="bottom" if value >= 0 else "top",
                fontsize=10,
                color="#222",
            )
        else:
            value = bar.get_width()
            ax.annotate(
                f"{value:.3g}",
                xy=(value, bar.get_y() + bar.get_height() / 2),
                # Negative bars get their label to the left of the tip.
                xytext=(4 if value >= 0 else -32, 0),
                textcoords="offset points",
                ha="left" if value >= 0 else "right",
                va="center",
                fontsize=10,
                color="#222",
            )
def plot_3d_surface(
    shap_gt,
    shap_models,
    seq_len,
    n_features,
    save=None,
    cmap="viridis",
    vlim=None,
    show=True,
):
    r"""
    Generate side-by-side 3D surface plots for SHAP values from ground-truth and explainers.

    One row is drawn per explainer: the ground-truth surface on the left and the
    explainer's surface on the right, over a (feature, time) grid. All panels
    share the same colour and z-axis range.

    Arrays with a leading batch dimension of size 1 are squeezed, and flat 1D
    arrays are reshaped to ``(seq_len, n_features)``.

    The plotting style is applied only for the duration of the call, so the
    caller's global matplotlib style is left untouched.

    :param np.ndarray shap_gt: Ground-truth SHAP array (T, F), (1, T, F) or flat (T * F,).
    :param dict shap_models: Dictionary mapping explainer names to SHAP arrays
        (same accepted shapes as ``shap_gt``).
    :param int seq_len: Sequence length (T).
    :param int n_features: Number of input features (F).
    :param str save: Optional filename to save the figure (written at 300 dpi).
    :param str cmap: Colormap for surface shading.
    :param tuple vlim: Optional ``(vmin, vmax)``. The range actually used is
        symmetric around zero: ``(-m, m)`` with ``m = max(|vmin|, |vmax|)``.
        When omitted, ``m`` is the largest absolute value across all arrays.
    :param bool show: Whether to display the figure. When False the figure is
        closed after optional saving.
    :raises ValueError: If ``shap_models`` is empty.
    """
    # Fail early with a readable message; an empty dict would otherwise die in
    # ``max()`` or in figure creation with a zero-height figsize.
    if not shap_models:
        raise ValueError("`shap_models` must contain at least one explainer.")
    n_models = len(shap_models)

    # Symmetric colour / z range so that zero sits in the middle of the colormap.
    if vlim is None:
        vmax = max(
            np.abs(shap_gt).max(),
            max(np.abs(np.asarray(arr)).max() for arr in shap_models.values()),
        )
    else:
        vmax = max(abs(vlim[0]), abs(vlim[1]))
    vlim = (-vmax, vmax)

    # Everything below is identical for each explainer, so compute it once.
    gt_grid = _to_surface_grid(shap_gt, seq_len, n_features)
    feature_ticks = np.arange(n_features)
    time_ticks = np.arange(seq_len)
    xx, yy = np.meshgrid(feature_ticks, time_ticks)

    # Scope the style to this figure instead of mutating global rcParams.
    with plt.style.context("seaborn-v0_8-whitegrid"):
        fig = plt.figure(figsize=(10, 3.8 * n_models))
        for row, (name, values) in enumerate(shap_models.items()):
            panels = (
                (gt_grid, f"{name} – Ground Truth"),
                (
                    _to_surface_grid(values, seq_len, n_features),
                    f"{name} – SHAP Explanation",
                ),
            )
            # Subplot indices are 1-based: left column = 2*row + 1, right = 2*row + 2.
            for col, (zz, title) in enumerate(panels, start=1):
                ax = fig.add_subplot(n_models, 2, 2 * row + col, projection="3d")
                _draw_surface_panel(
                    ax, xx, yy, zz, title, cmap, vlim, feature_ticks, time_ticks
                )

        fig.subplots_adjust(wspace=0.18, hspace=0.18)
        fig.tight_layout(rect=[0, 0, 1, 1])
        if save:
            fig.savefig(save, bbox_inches="tight", pad_inches=0.01, dpi=300)
            print(f"Figure saved to {save}")
        if show:
            plt.show()
        else:
            plt.close(fig)


def _to_surface_grid(values, seq_len, n_features):
    """Squeeze a size-1 batch dimension and reshape flat input to (seq_len, n_features)."""
    grid = np.asarray(values)
    if grid.ndim == 3 and grid.shape[0] == 1:
        grid = grid[0]
    if grid.ndim == 1:
        grid = grid.reshape(seq_len, n_features)
    return grid


def _draw_surface_panel(ax, xx, yy, zz, title, cmap, vlim, feature_ticks, time_ticks):
    """Draw one styled SHAP surface on a 3D axes with the shared value range."""
    ax.plot_surface(
        xx,
        yy,
        zz,
        cmap=cmap,
        vmin=vlim[0],
        vmax=vlim[1],
        edgecolor="none",
        alpha=0.95,
        antialiased=True,
    )
    ax.set_title(title, fontsize=14, fontweight="bold", pad=0)
    ax.set_xlabel("Feature", fontsize=11, labelpad=6)
    ax.set_ylabel("Time", fontsize=11, labelpad=6)
    ax.set_zlabel("SHAP Value", fontsize=11, labelpad=8)
    ax.view_init(elev=30, azim=130)
    ax.set_facecolor("#ffffff")
    ax.tick_params(axis="both", labelsize=9, pad=2)
    ax.set_xticks(feature_ticks)
    ax.set_yticks(time_ticks)
    # 10 % headroom so the extreme values do not touch the box.
    ax.set_zlim(vlim[0] * 1.1, vlim[1] * 1.1)
    ax.grid(False)
def plot_3d_bars(
    shap_gt,
    shap_models,
    seq_len,
    n_features,
    save=None,
    bar_alpha=0.88,
    bar_color="#3498db",  # Softer blue, can set e.g. "#43aa8b" for green
    show=True,
):
    r"""
    Visual comparison of SHAP values using 3D bar plots for ground-truth and explainer outputs.

    One row is drawn per explainer: ground-truth bars on the left, the
    explainer's bars on the right, over a (feature, time) grid. Non-negative
    values are drawn in ``bar_color`` and negative values in orange
    (``#f39c12``). All panels share a symmetric z range derived from the largest
    absolute value across all arrays.

    Inputs may be any array-like (lists included); each is flattened in
    row-major order, which is assumed to correspond to a ``(T, F)`` layout.

    The plotting style is applied only for the duration of the call, so the
    caller's global matplotlib style is left untouched.

    :param np.ndarray shap_gt: Ground-truth SHAP values.
    :param dict shap_models: Dictionary mapping explainer names to SHAP arrays.
    :param int seq_len: Length of the sequence (T).
    :param int n_features: Number of input features (F).
    :param str save: Optional path to save the plot (written at 300 dpi).
    :param float bar_alpha: Transparency of bars (default: 0.88).
    :param str bar_color: Hex color string for positive bars.
    :param bool show: Whether to display the figure. When False the figure is
        closed after optional saving.
    :raises ValueError: If ``shap_models`` is empty.
    """
    # Fail early with a readable message instead of "max() arg is an empty sequence".
    if not shap_models:
        raise ValueError("`shap_models` must contain at least one explainer.")
    n_models = len(shap_models)

    # Symmetric z range shared by every panel.
    vmax = max(
        np.abs(shap_gt).max(), max(np.abs(arr).max() for arr in shap_models.values())
    )
    zlim = (-vmax * 1.1, vmax * 1.1)

    # The bar grid is the same for every panel, so build it once.
    feature_ticks = np.arange(n_features)
    time_ticks = np.arange(seq_len)
    xx, yy = np.meshgrid(feature_ticks, time_ticks)
    grid = (xx, yy, feature_ticks, time_ticks)

    # np.asarray lets list inputs through; the original `.flatten()` call
    # required a real ndarray even though the vmax computation did not.
    gt_heights = np.asarray(shap_gt).ravel()

    # Scope the style to this figure instead of mutating global rcParams.
    with plt.style.context("seaborn-v0_8-whitegrid"):
        fig = plt.figure(figsize=(11, 3.5 * n_models))
        for row, (name, values) in enumerate(shap_models.items()):
            panels = (
                (gt_heights, f"{name} – Ground Truth"),
                (np.asarray(values).ravel(), f"{name} – SHAP Explanation"),
            )
            # Subplot indices are 1-based: left column = 2*row + 1, right = 2*row + 2.
            for col, (heights, title) in enumerate(panels, start=1):
                ax = fig.add_subplot(n_models, 2, 2 * row + col, projection="3d")
                _draw_bar_panel(ax, grid, heights, title, bar_color, bar_alpha, zlim)

        fig.subplots_adjust(wspace=0.18, hspace=0.18)
        fig.tight_layout(rect=[0, 0, 1, 1])
        if save:
            fig.savefig(save, bbox_inches="tight", pad_inches=0.01, dpi=300)
            print(f"Figure saved to {save}")
        if show:
            plt.show()
        else:
            plt.close(fig)


def _draw_bar_panel(ax, grid, heights, title, bar_color, bar_alpha, zlim):
    """Draw one styled 3D bar panel, coloured by sign, over a faint zero plane."""
    xx, yy, feature_ticks, time_ticks = grid
    xpos, ypos = xx.flatten(), yy.flatten()
    # Sign-based colouring: `bar_color` for >= 0, orange for negative values.
    colors = np.where(heights >= 0, bar_color, "#f39c12")
    ax.bar3d(
        xpos,
        ypos,
        np.zeros_like(xpos),  # every bar starts at z = 0
        0.75,
        0.75,
        heights,
        color=colors,
        edgecolor="k",
        alpha=bar_alpha,
        linewidth=0.2,
    )
    # Faint plane at z = 0 as a visual reference between positive and negative bars.
    ax.plot_surface(xx, yy, np.zeros_like(xx), color="gray", alpha=0.08, zorder=0)
    ax.set_title(title, fontsize=14, fontweight="bold", pad=12)
    ax.set_xlabel("Feature", fontsize=11, labelpad=6)
    ax.set_ylabel("Time", fontsize=11, labelpad=6)
    ax.set_zlabel("SHAP Value", fontsize=11, labelpad=8)
    ax.view_init(elev=28, azim=120)
    ax.grid(False)
    ax.set_facecolor("#fcfcfc")
    ax.tick_params(axis="both", labelsize=9, pad=2)
    ax.set_xticks(feature_ticks)
    ax.set_yticks(time_ticks)
    ax.set_zlim(*zlim)
def plot_feature_comparison(
    shap_gt,
    shap_models,
    feature_names=None,
    save=None,
    show=True,
):
    r"""
    Plot bar charts comparing 1D/tabular SHAP attributions across explainers and ground truth.

    One subplot is drawn per explainer, stacked vertically with a shared x-axis.
    Each subplot shows the ground-truth values (blue) next to that explainer's
    values (orange) for every feature. Inputs are flattened, so a ``(1, F)``
    array is accepted as well as a plain ``(F,)`` vector.

    The plotting style is applied only for the duration of the call, so the
    caller's global matplotlib style is left untouched.

    :param np.ndarray shap_gt: Ground-truth SHAP values (1D).
    :param dict shap_models: Dictionary of SHAP arrays from different explainers.
    :param list feature_names: Optional list of feature names; defaults to the
        feature indices ``"0"``, ``"1"``, ...
    :param str save: Optional path to save the plot (written at 300 dpi).
    :param bool show: Whether to display the figure (default: True, the previous
        unconditional behaviour). When False the figure is closed after optional saving.
    :raises ValueError: If ``shap_models`` is empty.
    """
    # Fail early with a readable message; plt.subplots(0, 1) would raise a less
    # helpful ValueError about the number of rows.
    if not shap_models:
        raise ValueError("`shap_models` must contain at least one explainer.")
    n_explainers = len(shap_models)

    gt_values = np.asarray(shap_gt).ravel()
    n_features = gt_values.size
    positions = np.arange(n_features)
    bar_width = 0.35
    if feature_names is None:
        feature_names = [str(i) for i in range(n_features)]

    color_gt = "#4093c6"  # blue for GT
    color_exp = "#f99c2b"  # orange for Explainer

    # Scope the style to this figure instead of mutating global rcParams.
    with plt.style.context("seaborn-v0_8-whitegrid"):
        # squeeze=False always yields a 2D axes array, so a single explainer
        # needs no special-casing.
        fig, axes = plt.subplots(
            n_explainers,
            1,
            figsize=(10, 3.8 * n_explainers),
            sharex=True,
            squeeze=False,
        )
        axes = axes[:, 0]

        for ax, (name, vals) in zip(axes, shap_models.items()):
            # Ground truth and explainer bars sit side by side around each tick.
            ax.bar(
                positions - bar_width / 2,
                gt_values,
                bar_width,
                label="Monte Carlo GT",
                color=color_gt,
                edgecolor="#333",
                alpha=0.80,
            )
            ax.bar(
                positions + bar_width / 2,
                np.asarray(vals).ravel(),
                bar_width,
                label=f"{name} SHAP",
                color=color_exp,
                edgecolor="#333",
                alpha=0.80,
            )
            ax.set_title(f"{name} Explainer", fontsize=13, fontweight="bold", pad=14)
            ax.set_ylabel("Shapley Value", fontsize=11)
            ax.set_xticks(positions)
            ax.set_xticklabels(feature_names, fontsize=10)
            # Move legend outside plot for clarity, avoid overlap
            ax.legend(
                loc="upper center",
                bbox_to_anchor=(0.5, -0.20),
                ncol=2,
                fontsize=10,
                frameon=False,
            )
            ax.grid(axis="y", linestyle="--", linewidth=0.8, alpha=0.7)
            ax.set_facecolor("#fafafa")
            ax.tick_params(axis="both", labelsize=9)

        axes[-1].set_xlabel("Feature Index", fontsize=11)
        fig.tight_layout(rect=[0, 0.08, 1, 1])  # Leave space for legend
        if save:
            fig.savefig(save, bbox_inches="tight", dpi=300)
            print(f"Figure saved to {save}")
        if show:
            plt.show()
        else:
            plt.close(fig)