Source code for shap_enhanced.explainers.MBSHAP

"""
MB-SHAP: Multi-Baseline SHAP Explainer
======================================

Theoretical Explanation
-----------------------

Multi-Baseline SHAP (MB-SHAP) enhances the robustness of SHAP-based feature attribution
by computing SHAP values with respect to multiple baselines rather than a single reference.
This addresses a key limitation in standard SHAP explainers: their sensitivity to baseline selection.

By averaging attributions from diverse or locally-relevant baselines (e.g., nearest neighbors, mean, k-means centroids),
MB-SHAP produces more stable, reliable, and representative explanations—particularly useful in domains with
heterogeneous data distributions or models that exhibit local nonlinearity.

Key Concepts
^^^^^^^^^^^^

- **Multiple Baselines**:
    Each input is explained with respect to a set of baselines instead of just one. Baseline options include:
        - Random background samples.
        - Mean or centroid-based references.
        - K nearest neighbors (local context).
        - User-specified selections.

- **Explainer Flexibility**:
    MB-SHAP is compatible with any SHAP-style explainer, including `DeepExplainer`, `GradientExplainer`, and `KernelExplainer`.
    It wraps the base explainer and runs it separately for each baseline.

- **Attribution Averaging**:
    For each input sample:
        - SHAP values are computed with respect to each baseline.
        - The resulting attribution vectors are averaged to yield a final, smoothed explanation (see the sketch after this list).

- **Local Fidelity**:
    Using per-input nearest neighbors as baselines helps improve explanation fidelity for local model behavior.
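
As a minimal sketch of the attribution-averaging idea above (``explain_fn`` is a
placeholder for any per-baseline SHAP call, not an API of this module):

.. code-block:: python

    import numpy as np

    def average_shap_over_baselines(explain_fn, x, baselines):
        # explain_fn(x, baseline) -> attribution array shaped like x
        attrs = [explain_fn(x, b) for b in baselines]
        return np.mean(np.stack(attrs, axis=0), axis=0)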

Algorithm
---------

1. **Initialization**:
    - Accepts a model, background dataset, number of baselines, SHAP explainer class (e.g., `shap.DeepExplainer`),
      and device context. MB-SHAP in general admits several baseline selection strategies (`'random'`, `'nearest'`,
      `'mean'`, `'kmeans'`, etc.); the class implemented in this module uses the nearest-neighbor strategy.

2. **Baseline Selection**:
    - For each input sample:
        - Select multiple baseline samples from the background using the chosen strategy (a nearest-neighbor sketch follows this list).

3. **SHAP Value Computation**:
    - For each selected baseline:
        - Instantiate the base SHAP explainer.
        - Compute SHAP values for the input sample with respect to that baseline.
    - Average the SHAP results across all baselines.

4. **Output**:
    - Return the final attributions as averaged SHAP values, preserving shape and semantics of the model input.
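
For the nearest-neighbor strategy, the per-sample baseline selection in step 2 reduces to an
L2 lookup over the flattened background. A sketch (``background`` and ``x`` are assumed to be
NumPy arrays; this mirrors, but is not identical to, the class defined below):

.. code-block:: python

    import numpy as np

    def nearest_neighbor_baselines(x, background, k=5):
        # Return the k background samples closest to x in flattened L2 distance.
        bg_flat = background.reshape(background.shape[0], -1)
        dists = np.linalg.norm(bg_flat - x.reshape(-1), axis=1)
        return background[np.argsort(dists)[:k]]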

References
----------

- **Lundberg & Lee (2017), “A Unified Approach to Interpreting Model Predictions”**
  [SHAP foundation—coalitional feature attribution framework]

- **Chen et al. (2022), “Explaining a Series of Models by Propagating Shapley Values” (G-DeepSHAP)**
  [Uses multiple baselines and shows that averaging explanations across them improves consistency and fidelity]

- **Google Vertex AI documentation (2025)**
  [Allows multiple baseline specifications (e.g., min, max, random) to improve attribution context and stability]

- **Sundararajan & Najmi (2020), “The Many Shapley Values for Model Explanation”**
  [Discusses how baseline selection influences SHAP-value interpretations, and the implications of multiple baseline settings]

- **Kelodjou et al. (2023), “Shaping Up SHAP: Enhancing Stability through Layer-Wise Neighbor Selection”**
  [Highlights instability in KernelSHAP and proposes neighbor sampling strategies to stabilize results, underscoring the need for ensemble or multi-baseline approaches]
"""

import inspect

import numpy as np
import torch

from shap_enhanced.base_explainer import BaseExplainer


class NearestNeighborMultiBaselineSHAP(BaseExplainer):
    r"""NearestNeighborMultiBaselineSHAP: Multi-Baseline SHAP Explainer

    This explainer improves attribution robustness by selecting the K nearest neighbors
    from a background dataset as baselines for each input sample, computing SHAP values
    individually for each baseline, and then averaging the results.

    It is compatible with various SHAP explainers (e.g., `DeepExplainer`, `GradientExplainer`,
    `KernelExplainer`) and automatically adapts input types and parameter formats accordingly.

    .. note::
        Baseline selection is input-dependent and done per sample using L2 distance
        in flattened input space.

    :param base_explainer_class: The SHAP explainer class to use (e.g., `shap.DeepExplainer`).
    :param model: The predictive model to explain.
    :type model: Any
    :param background: Background dataset (N, ...) for nearest neighbor selection.
    :type background: np.ndarray
    :param int n_baselines: Number of nearest neighbor baselines to use per sample.
    :param base_explainer_kwargs: Additional keyword arguments passed to the SHAP explainer.
    :type base_explainer_kwargs: dict or None
    :param str device: Device context for torch-based explainers ('cpu' or 'cuda').
    """

    def __init__(
        self,
        base_explainer_class,
        model,
        background,
        n_baselines=5,
        base_explainer_kwargs=None,
        device=None,
    ):
        self.base_explainer_class = base_explainer_class
        self.model = model
        self.background = np.asarray(background)
        self.n_baselines = n_baselines
        self.base_explainer_kwargs = base_explainer_kwargs or {}
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    def _to_torch(self, arr):
        # Convert array-like input to a float32 tensor on the configured device.
        if isinstance(arr, torch.Tensor):
            return arr.to(self.device)
        arr = np.asarray(arr, dtype=np.float32)
        return torch.tensor(arr, dtype=torch.float32, device=self.device)

    def _make_explainer(self, baseline):
        # Handle different explainer parameter names by inspecting the constructor.
        cls = self.base_explainer_class
        sig = inspect.signature(cls.__init__)
        params = [p for p in sig.parameters.keys() if p != "self"]
        if len(params) == 1:
            return cls(self.model, **self.base_explainer_kwargs)
        elif len(params) > 1:
            param2 = params[1].lower()
            if param2 in ("data", "background"):
                return cls(self.model, baseline, **self.base_explainer_kwargs)
            else:
                return cls(
                    self.model, **{param2: baseline}, **self.base_explainer_kwargs
                )
        else:
            raise RuntimeError("Cannot infer how to call explainer_class!")

    def _safe_shap_values(self, explainer, x, **kwargs):
        # Adds check_additivity=False if the explainer supports it.
        sig = inspect.signature(explainer.shap_values)
        if "check_additivity" in sig.parameters:
            return explainer.shap_values(x, check_additivity=False, **kwargs)
        return explainer.shap_values(x, **kwargs)

    def shap_values(self, X, **kwargs):
        r"""Compute SHAP values using per-sample nearest neighbor baselines.

        For each sample in `X`, this method:

        1. Selects the `n_baselines` nearest neighbors from the background.
        2. Instantiates the explainer with the selected baselines.
        3. Computes SHAP values with respect to each baseline.
        4. Averages SHAP values across baselines to produce a robust explanation.

        .. math::
            \phi(x) = \frac{1}{K} \sum_{k=1}^{K} \text{SHAP}(x \mid b_k)

        :param X: Input samples to explain, shape (N, ...) or single sample (...).
        :type X: np.ndarray
        :param kwargs: Additional keyword arguments forwarded to the SHAP explainer.
        :return: Averaged SHAP attributions, shape (N, ...) or (...) for single input.
        :rtype: np.ndarray
        """
        X = np.asarray(X)
        if X.ndim == 2:
            X = X[None]
        n_samples = X.shape[0]
        bg_flat = self.background.reshape(self.background.shape[0], -1)
        attributions = []
        for i in range(n_samples):
            x = X[i]
            x_flat = x.reshape(-1)
            # K nearest neighbors in background
            dists = np.linalg.norm(bg_flat - x_flat, axis=1)
            idx = np.argsort(dists)[: self.n_baselines]
            nn_bases = self.background[idx]  # (n_baselines, T, F)
            # DeepExplainer wants a batch of backgrounds as tensor!
            expl_name = self.base_explainer_class.__name__.lower()
            if "deep" in expl_name or "gradient" in expl_name:
                baseline_ = self._to_torch(nn_bases)
                x_torch = self._to_torch(x[None])
                expl = self._make_explainer(baseline_)
                attr = self._safe_shap_values(expl, x_torch, **kwargs)
                if isinstance(attr, list):
                    attr = attr[0]
                if torch.is_tensor(attr):
                    attr = attr.detach().cpu().numpy()
                sample_avg = attr[0]  # first sample in batch
            else:
                # For KernelExplainer, background should be np.ndarray
                baseline_ = nn_bases
                x_in = x[None]
                expl = self._make_explainer(baseline_)
                attr = self._safe_shap_values(expl, x_in, **kwargs)
                if isinstance(attr, list):
                    attr = attr[0]
                sample_avg = attr[0]  # (T, F) or similar
            attributions.append(sample_avg)
        attributions = np.stack(attributions, axis=0)
        if attributions.shape[0] == 1:
            return attributions[0]
        return attributions
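

# The following guarded block is a minimal usage sketch, not part of the explainer
# itself. It assumes the optional ``shap`` package is installed and uses a throwaway
# PyTorch model and random data purely for illustration.
if __name__ == "__main__":
    import shap

    torch.manual_seed(0)
    rng = np.random.default_rng(0)

    # Toy sequence data: background of shape (N, T, F) plus two samples to explain.
    background = rng.normal(size=(32, 4, 3)).astype(np.float32)
    X = rng.normal(size=(2, 4, 3)).astype(np.float32)

    # Tiny model mapping a (T, F) sequence to a single output.
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(4 * 3, 1))

    explainer = NearestNeighborMultiBaselineSHAP(
        base_explainer_class=shap.GradientExplainer,
        model=model,
        background=background,
        n_baselines=5,
        device="cpu",  # keep the toy model and data on the same device
    )
    attributions = explainer.shap_values(X)
    print(np.shape(attributions))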