# Source code for shap_enhanced.tools.comparison
"""
Attribution Comparison Utility for SHAP Explainers
==================================================

Overview
--------

This module provides a utility class for quantitatively comparing SHAP attributions
from multiple explainers against a reference ground truth. It is intended for use in benchmarking
or evaluating new SHAP-based methods by computing standard performance metrics.

Currently supported evaluation metrics include:

- **Mean Squared Error (MSE)**: Measures the squared deviation between predicted and ground-truth attributions.
- **Pearson Correlation**: Measures the linear correlation between flattened attribution arrays.

Key Components
^^^^^^^^^^^^^^

- **Comparison Class**:

  - Accepts ground-truth SHAP values and a dictionary of predicted attribution maps.
  - Computes MSE and Pearson correlation for each explainer.
  - Handles flattened comparison over all timesteps and features.

Use Case
--------

This utility is ideal for:

- Benchmarking SHAP-style explainers on synthetic datasets with known ground truth.
- Evaluating the effect of surrogate or approximation methods.
- Comparing different explainer strategies in terms of attribution consistency.

Example
-------

.. code-block:: python

    import numpy as np
    from shap_enhanced.tools.comparison import Comparison

    gt = np.random.rand(10, 5)  # Ground-truth SHAP values of shape (T, F)
    pred1 = gt + np.random.normal(0, 0.1, size=gt.shape)
    pred2 = gt + np.random.normal(0, 0.2, size=gt.shape)

    comp = Comparison(ground_truth=gt, shap_models={"ExplainerA": pred1, "ExplainerB": pred2})
    mse_scores, pearson_scores = comp.calculate_kpis()
"""
import numpy as np
from scipy.stats import pearsonr
__all__ = ["Comparison"]
class Comparison:
    r"""
    Comparison: SHAP Attribution Evaluation Utility

    Provides evaluation metrics for comparing predicted SHAP attributions against a ground truth reference.
    Designed for benchmarking SHAP-based explainers using quantitative metrics.

    Supported Metrics
    -----------------

    - **Mean Squared Error (MSE)**: Measures squared deviation between predicted and true SHAP values.
    - **Pearson Correlation**: Measures linear correlation between flattened attribution vectors.
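
    Both metrics are computed on flattened arrays. With the flattened ground truth
    :math:`g` and a flattened predicted attribution :math:`a`, each of length
    :math:`n = T \cdot F`, the values reported are the standard definitions:

    .. math::

        \mathrm{MSE}(a, g) = \frac{1}{n} \sum_{i=1}^{n} (a_i - g_i)^2

    .. math::

        \rho(a, g) = \frac{\sum_{i=1}^{n} (a_i - \bar{a})(g_i - \bar{g})}
                          {\sqrt{\sum_{i=1}^{n} (a_i - \bar{a})^2} \, \sqrt{\sum_{i=1}^{n} (g_i - \bar{g})^2}}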

    :param np.ndarray ground_truth: Ground-truth SHAP values of shape (T, F).
    :param dict shap_models: Dictionary mapping explainer names to their SHAP attribution arrays.
    """

    def __init__(self, ground_truth, shap_models):
        self.ground_truth = ground_truth
        self.shap_models = shap_models
        self.results = {}          # MSE per explainer, populated by calculate_kpis()
        self.pearson_results = {}  # Pearson correlation per explainer

    def calculate_kpis(self):
        r"""
        Compute evaluation metrics (MSE and Pearson correlation) for each SHAP explainer.

        .. note::
            Flattened comparisons are used for both MSE and correlation.

        :return: Tuple of dictionaries:

            - MSE values for each explainer.
            - Pearson correlation values for each explainer.

        :rtype: (dict[str, float], dict[str, float])
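
        The returned dictionaries are keyed by explainer name. For instance (a sketch
        continuing the module-level example; keys depend on ``shap_models``):

        .. code-block:: python

            mse_scores, pearson_scores = comp.calculate_kpis()
            # mse_scores     -> {"ExplainerA": <float>, "ExplainerB": <float>}
            # pearson_scores -> {"ExplainerA": <float>, "ExplainerB": <float>}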
"""
for name, arr in self.shap_models.items():
mse = np.mean((arr - self.ground_truth) ** 2)
gt_flat = self.ground_truth.flatten()
arr_flat = arr.flatten()
try:
pearson, _ = pearsonr(gt_flat, arr_flat)
except Exception:
pearson = np.nan
self.results[name] = mse
self.pearson_results[name] = pearson
return self.results, self.pearson_results
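

if __name__ == "__main__":
    # Minimal usage sketch (not part of the library API): build synthetic ground-truth
    # attributions plus two noisy "explainers" and print the resulting metrics.
    rng = np.random.default_rng(0)
    gt = rng.random((10, 5))  # ground-truth SHAP values, shape (T, F)
    shap_models = {
        "ExplainerA": gt + rng.normal(0.0, 0.1, size=gt.shape),  # low-noise predictions
        "ExplainerB": gt + rng.normal(0.0, 0.2, size=gt.shape),  # higher-noise predictions
    }
    mse_scores, pearson_scores = Comparison(gt, shap_models).calculate_kpis()
    for name in shap_models:
        print(f"{name}: MSE={mse_scores[name]:.4f}, Pearson={pearson_scores[name]:.4f}")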