Source code for shap_enhanced.explainers.ERSHAP

"""
ER-SHAP: Ensemble of Random SHAP Explainer
==========================================

Theoretical Explanation
-----------------------

ER-SHAP is a computationally efficient, ensemble-based approximation of Shapley values, designed for
both sequential and tabular models. Instead of exhaustively enumerating all possible coalitions,
ER-SHAP repeatedly samples random subsets of feature–timestep positions and estimates their
marginal contributions to model output.

This stochastic approach significantly accelerates SHAP estimation while maintaining interpretability,
especially in high-dimensional or temporal settings. ER-SHAP also allows prior knowledge (e.g., feature importance)
to guide coalition sampling through weighted schemes.
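
Formally, for each position \\((t, f)\\), the explainer draws coalitions
\\( C_1, \\dots, C_M \\) from the remaining positions and applies the standard
Monte Carlo estimator of the Shapley value. As a sketch (with \\( v \\) a generic
value function, realized in this implementation by imputing masked positions):

.. math::

    \\hat{\\phi}_{t,f} \\approx \\frac{1}{M} \\sum_{m=1}^{M}
        \\left[ v\\big(C_m \\cup \\{(t, f)\\}\\big) - v(C_m) \\right]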

Key Concepts
^^^^^^^^^^^^

- **Random Coalition Sampling**:
    For each position \\((t, f)\\), sample coalitions \\( C \\subseteq (T \\times F) \\setminus \\{(t, f)\\} \\)
    and estimate the marginal contribution of \\((t, f)\\) by measuring its impact on the model output.

- **Weighted Sampling**:
    Coalition sampling can be uniform or weighted based on prior feature importance scores
    or positional frequency, allowing informed, efficient sampling.

- **Flexible Masking** (a minimal sketch follows this list):
    Masked features are imputed using:
        - Zeros (hard masking).
        - Per-position means from the background dataset (soft masking).

- **Additivity Normalization**:
    Final attributions are scaled so that their sum matches the model output difference
    between the original and fully-masked input.
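
To make the masking step concrete, the minimal sketch below (illustrative helper
name, assuming NumPy inputs of shape (T, F) and a per-position background mean of
the same shape) mirrors the behavior described above:

.. code-block:: python

    import numpy as np

    def impute_positions(x, positions, strategy="mean", background_mean=None):
        # Return a copy of x with the given (t, f) positions masked.
        x_masked = x.copy()
        for t, f in positions:
            # 'zero' hard-masks the position; 'mean' soft-masks it with the
            # per-position mean computed over the background dataset.
            x_masked[t, f] = 0.0 if strategy == "zero" else background_mean[t, f]
        return x_masked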

Algorithm
---------

1. **Initialization**:
    - Accepts a model, a background dataset for imputation, the number of sampled coalitions,
      a masking strategy (`'zero'` or `'mean'`), a weighting scheme, optional feature importances,
      and a device context.

2. **Coalition Sampling** (a code sketch follows this list):
    - For each feature–timestep pair \\((t, f)\\):
        - Sample coalitions \\( C \\subseteq (T \\times F) \\setminus \\{(t, f)\\} \\), either uniformly or using weights.
        - For each coalition:
            - Impute the coalition \\( C \\) in the input.
            - Impute the coalition \\( C \\cup \\{(t, f)\\} \\).
            - Compute the model output difference.
        - Average these differences to estimate the marginal contribution of \\((t, f)\\).

3. **Normalization**:
    - Scale the final attributions so that their total equals the difference in model output
      between the original input and a fully-masked baseline.

References
----------

- **Lundberg & Lee (2017), “A Unified Approach to Interpreting Model Predictions”**
  [SHAP foundation: the coalitional feature-attribution framework]

- **Castro et al. (2009) and Mann & Shapley (1960), Monte Carlo sampling for Shapley values**
  [Introduce simple uniform random sampling of permutations/coalitions for Shapley estimation]

- **Okhrati & Lipani (2020), “A Multilinear Sampling Algorithm to Estimate Shapley Values”**
  [Proposes variance-reduced sampling for Shapley value estimation via multilinear extensions]

- **Witter et al. (2025), “Regression-Adjusted Monte Carlo Estimators for Shapley Values and Probabilistic Values”**
  [Combines Monte Carlo with regression adjustments to achieve more efficient, low-variance Shapley approximations]

- **Utkin & Konstantinov (2022), “Ensembles of Random SHAPs” (ER-SHAP)**
  [Directly describes ER-SHAP: building ensembles of SHAP estimates over random subsets and averaging; also covers weighted sampling via preliminary importance]

- **Maleki et al. (2013), “Bounding the Estimation Error of Sampling-based Shapley Value Approximation”**
  [Provides theoretical error bounds for Monte Carlo approximation and discusses stratified sampling for variance reduction]
"""

import numpy as np
import torch

from shap_enhanced.base_explainer import BaseExplainer


class ERSHAPExplainer(BaseExplainer):
    """ER-SHAP: Ensemble of Random SHAP Explainer.

    An efficient approximation of Shapley values using random coalition sampling
    over time-feature positions. Supports uniform and weighted sampling strategies
    and flexible masking (zero or mean) to generate perturbed inputs.

    :param model: Model to explain, compatible with PyTorch tensors.
    :type model: Any
    :param background: Background dataset for mean imputation; shape (N, T, F).
    :type background: np.ndarray or torch.Tensor
    :param n_coalitions: Number of coalitions to sample per (t, f) position.
    :type n_coalitions: int
    :param mask_strategy: Masking method: 'zero' or 'mean'.
    :type mask_strategy: str
    :param weighting: Sampling scheme: 'uniform', 'frequency', or 'importance'.
    :type weighting: str
    :param feature_importance: Prior feature importances for weighted sampling; shape (T, F).
    :type feature_importance: Optional[np.ndarray]
    :param device: Device identifier, 'cpu' or 'cuda'.
    :type device: str
    """

    def __init__(
        self,
        model,
        background,
        n_coalitions=100,
        mask_strategy="mean",
        weighting="uniform",
        feature_importance=None,
        device=None,
    ):
        super().__init__(model, background)
        self.n_coalitions = n_coalitions
        self.mask_strategy = mask_strategy
        self.weighting = weighting
        self.feature_importance = feature_importance
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        if mask_strategy == "mean":
            # Accept either a torch tensor or a NumPy array as background.
            bg = (
                background.detach().cpu().numpy()
                if hasattr(background, "detach")
                else np.asarray(background)
            )
            self._mean = bg.mean(axis=0)
        else:
            self._mean = None

    def _impute(self, X, idxs):
        r"""Apply the masking strategy to selected (t, f) indices in the input.

        - 'zero': Replace with 0.0.
        - 'mean': Use the per-position mean from the background dataset.

        :param X: Input sample of shape (T, F).
        :type X: np.ndarray
        :param idxs: List of (t, f) pairs to mask.
        :type idxs: list[tuple[int, int]]
        :return: Masked/imputed version of X.
        :rtype: np.ndarray
        """
        X_imp = X.copy()
        for t, f in idxs:
            if self.mask_strategy == "zero":
                X_imp[t, f] = 0.0
            elif self.mask_strategy == "mean":
                X_imp[t, f] = self._mean[t, f]
            else:
                raise ValueError(f"Unknown mask_strategy: {self.mask_strategy}")
        return X_imp

    def _sample_coalition(self, available, k, weights=None):
        """Sample a coalition of k positions from the available list.

        If weights are provided, sampling is weighted; otherwise, uniform.

        :param available: List of available (t, f) pairs.
        :type available: list[tuple[int, int]]
        :param k: Number of elements to sample.
        :type k: int
        :param weights: Sampling probabilities aligned with `available`.
        :type weights: Optional[np.ndarray]
        :return: List of sampled (t, f) pairs.
        :rtype: list[tuple[int, int]]
        """
        if weights is not None:
            # `weights` is already aligned with `available`; renormalize so the
            # probabilities sum exactly to one before sampling without replacement.
            weights = np.asarray(weights, dtype=float)
            weights = weights / weights.sum()
            idxs = np.random.choice(len(available), size=k, replace=False, p=weights)
        else:
            idxs = np.random.choice(len(available), size=k, replace=False)
        return [available[i] for i in idxs]

    def shap_values(self, X, check_additivity=True, random_seed=42, **kwargs):
        r"""Compute SHAP values via random coalition sampling.

        For each position (t, f), sample coalitions of other positions, compute
        marginal contributions, and average over samples. Attributions are
        normalized to satisfy:

        .. math::
            \sum_{t=1}^T \sum_{f=1}^F \phi_{t,f} \approx f(x) - f(x_{masked})

        :param X: Input array or tensor of shape (T, F) or (B, T, F).
        :type X: np.ndarray or torch.Tensor
        :param check_additivity: Whether to print an additivity check comparing
            the attribution sum to the model-output difference.
        :type check_additivity: bool
        :param random_seed: Seed for reproducibility.
        :type random_seed: int
        :return: SHAP values of shape (T, F) or (B, T, F).
        :rtype: np.ndarray
        """
        np.random.seed(random_seed)
        is_torch = hasattr(X, "detach")
        X_in = X.detach().cpu().numpy() if is_torch else np.asarray(X)
        shape = X_in.shape
        if len(shape) == 2:
            X_in = X_in[None, ...]
            single = True
        else:
            single = False
        B, T, F = X_in.shape
        shap_vals = np.zeros((B, T, F), dtype=float)
        model_diffs = np.zeros(B, dtype=float)

        def _predict(x):
            # Run the model on a single (T, F) sample and return a scalar output.
            return (
                self.model(
                    torch.tensor(x[None], dtype=torch.float32, device=self.device)
                )
                .detach()
                .cpu()
                .numpy()
                .squeeze()
            )

        for b in range(B):
            x_orig = X_in[b]
            all_pos = [(t, f) for t in range(T) for f in range(F)]
            shap_matrix = np.zeros((T, F))
            for t in range(T):
                for f in range(F):
                    mc = []
                    available = [idx for idx in all_pos if idx != (t, f)]
                    # Define weights for coalition sampling.
                    weights = None
                    if (
                        self.weighting == "importance"
                        and self.feature_importance is not None
                    ):
                        flat_imp = self.feature_importance.flatten()
                        idx_map = {idx: i for i, idx in enumerate(all_pos)}
                        weights = np.array(
                            [flat_imp[idx_map[idx]] for idx in available]
                        )
                        weights = weights / (weights.sum() + 1e-8)
                    elif self.weighting == "frequency":
                        # Currently uniform; a positional-frequency prior could be plugged in here.
                        weights = None
                    for _ in range(self.n_coalitions):
                        # Random coalition size in [1, |available|]; the empty coalition is excluded.
                        k = np.random.randint(1, len(available) + 1)
                        C_idxs = self._sample_coalition(available, k, weights)
                        x_C = self._impute(x_orig, C_idxs)
                        x_C_tf = self._impute(x_C, [(t, f)])
                        out_C = _predict(x_C)
                        out_C_tf = _predict(x_C_tf)
                        mc.append(out_C_tf - out_C)
                    shap_matrix[t, f] = np.mean(mc)
            shap_vals[b] = shap_matrix
            # Additivity normalization per sample.
            orig_pred = _predict(x_orig)
            masked_pred = _predict(self._impute(x_orig, all_pos))
            shap_sum = shap_vals[b].sum()
            model_diffs[b] = orig_pred - masked_pred
            if shap_sum != 0:
                shap_vals[b] *= model_diffs[b] / shap_sum
        shap_vals = shap_vals[0] if single else shap_vals
        if check_additivity:
            print(
                f"[ERSHAP Additivity] sum(SHAP)={shap_vals.sum():.4f} | Model diff={model_diffs.sum():.4f}"
            )
        return shap_vals
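

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the tiny model and the random
    # data below are hypothetical stand-ins, not part of the library API. Any
    # callable mapping a (B, T, F) float tensor to one output per sample works.
    T, F = 5, 3
    demo_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(T * F, 1))
    demo_background = np.random.randn(50, T, F).astype(np.float32)  # (N, T, F)
    demo_x = np.random.randn(T, F).astype(np.float32)  # a single (T, F) sample

    explainer = ERSHAPExplainer(
        demo_model,
        demo_background,
        n_coalitions=25,
        mask_strategy="mean",
        device="cpu",  # keep the demo on CPU so the untrained model needs no transfer
    )
    demo_attr = explainer.shap_values(demo_x)  # returns shape (T, F)
    print("ER-SHAP attributions:\n", demo_attr)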