"""
ER-SHAP: Ensemble of Random SHAP Explainer
==========================================
Theoretical Explanation
-----------------------
ER-SHAP is a computationally efficient, ensemble-based approximation of Shapley values, designed for
both sequential and tabular models. Instead of exhaustively enumerating all possible coalitions,
ER-SHAP repeatedly samples random subsets of feature–timestep positions and estimates their
marginal contributions to model output.
This stochastic approach significantly accelerates SHAP estimation while maintaining interpretability,
especially in high-dimensional or temporal settings. ER-SHAP also allows prior knowledge (e.g., feature importance)
to guide coalition sampling through weighted schemes.
Key Concepts
^^^^^^^^^^^^
- **Random Coalition Sampling**:
For each position \\((t, f)\\), sample coalitions \\( C \\subseteq (T \\times F) \\setminus \\{(t, f)\\} \\)
and estimate the marginal contribution of \\((t, f)\\) by measuring its effect on the model output
(made precise just after this list).
- **Weighted Sampling**:
Coalition sampling can be uniform or weighted based on prior feature importance scores
or positional frequency, allowing informed, efficient sampling.
- **Flexible Masking**:
Masked features are imputed using:
- Zeros (hard masking).
- Feature-wise means from the background dataset (soft masking).
- **Additivity Normalization**:
Final attributions are scaled so that their sum matches the model output difference
between the original and fully-masked input.
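Putting the pieces together: with \\(K\\) sampled coalitions \\(C_1, \\dots, C_K\\) per position, and
writing \\(f(x_{\\setminus S})\\) for the model output when the positions in \\(S\\) are masked
(a shorthand introduced here for readability), the estimator and its normalization read

.. math::
    \\hat{\\phi}_{t,f} = \\frac{1}{K} \\sum_{k=1}^{K}
        \\Big[ f\\big(x_{\\setminus C_k}\\big) - f\\big(x_{\\setminus (C_k \\cup \\{(t,f)\\})}\\big) \\Big],
    \\qquad
    \\phi_{t,f} = \\hat{\\phi}_{t,f} \\cdot
        \\frac{f(x) - f\\big(x_{\\setminus (T \\times F)}\\big)}{\\sum_{t',f'} \\hat{\\phi}_{t',f'}}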
Algorithm
---------
1. **Initialization**:
- Accepts a model, background dataset for imputation, number of sampled coalitions,
masking strategy (`'zero'` or `'mean'`), weighting scheme, optional feature importance, and device context.
2. **Coalition Sampling**:
- For each feature–timestep pair \\((t, f)\\):
- Sample coalitions \\( C \\subseteq (T \\times F) \\setminus \\{(t, f)\\} \\), either uniformly or using weights.
- For each sampled coalition:
- Impute (mask) the positions in \\( C \\) in the input.
- Impute the positions in \\( C \\cup \\{(t, f)\\} \\).
- Take the model output difference between the two (with \\((t, f)\\) present minus with it masked).
- Average these differences to estimate the marginal contribution of \\((t, f)\\).
3. **Normalization**:
- Scale the final attributions so that their total equals the difference in model output
between the original input and a fully-masked baseline.
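Example Usage
-------------
A minimal, illustrative sketch of how the explainer might be called. The toy model, shapes, and
random data below are assumptions for demonstration only; any PyTorch-compatible model that maps a
``(B, T, F)`` float tensor to one output per sample should work the same way.

.. code-block:: python

    import numpy as np
    import torch.nn as nn

    T, F = 6, 3
    model = nn.Sequential(nn.Flatten(), nn.Linear(T * F, 1))  # toy stand-in model
    background = np.random.randn(32, T, F)                    # background set for mean imputation
    x = np.random.randn(T, F)                                 # single sample to explain

    explainer = ERSHAPExplainer(
        model,
        background,
        n_coalitions=25,
        mask_strategy="mean",
        weighting="uniform",
        device="cpu",
    )
    phi = explainer.shap_values(x)  # attributions with shape (T, F)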
References
----------
- **Lundberg & Lee (2017), “A Unified Approach to Interpreting Model Predictions”**
[SHAP foundation—coalitional feature attribution framework]
- **Castro et al. (2009) and Mann & Shapley (1960), Monte Carlo sampling for Shapley values**
[Introduces simple uniform random sampling of permutations/coalitions for Shapley estimation]
- **Okhrati & Lipani (2020), “A Multilinear Sampling Algorithm to Estimate Shapley Values”**
[Proposes variance-reduced sampling for Shapley value estimation via multilinear extensions]
- **Witter et al. (2025), “Regression-adjusted Monte Carlo Estimators for Shapley Values and Probabilistic Values”**
[Combines Monte Carlo with regression adjustments to achieve more efficient, low-variance Shapley approximations]
- **Utkin & Konstantinov (2022), “Ensembles of Random SHAPs” (ER-SHAP)**
[Directly describes ER-SHAP: building ensembles of SHAPs over random subsets and averaging; also includes weighted sampling via preliminary importance]
- **Maleki et al. (2013), “Bounding the Estimation Error of Sampling-based Shapley Value Approximation”**
[Provides theoretical error bounds for Monte Carlo approximation and discusses stratified sampling for variance reduction]
"""
import numpy as np
import torch
from shap_enhanced.base_explainer import BaseExplainer
class ERSHAPExplainer(BaseExplainer):
"""
ER-SHAP: Ensemble of Random SHAP Explainer
An efficient approximation of Shapley values using random coalition sampling over
time-feature positions. Supports uniform and weighted sampling strategies and flexible
masking (zero or mean) to generate perturbed inputs.
:param model: Model to explain, compatible with PyTorch tensors.
:type model: Any
:param background: Background dataset for mean imputation; shape (N, T, F).
:type background: np.ndarray or torch.Tensor
:param n_coalitions: Number of coalitions to sample per (t, f) position.
:type n_coalitions: int
:param mask_strategy: Masking method: 'zero' or 'mean'.
:type mask_strategy: str
:param weighting: Sampling scheme: 'uniform', 'frequency', or 'importance'. ('frequency' currently falls back to uniform sampling.)
:type weighting: str
:param feature_importance: Prior feature importances for weighted sampling; shape (T, F).
:type feature_importance: Optional[np.ndarray]
:param device: Device identifier, 'cpu' or 'cuda'.
:type device: str
"""
def __init__(
self,
model,
background,
n_coalitions=100,
mask_strategy="mean",
weighting="uniform",
feature_importance=None,
device=None,
):
super().__init__(model, background)
self.n_coalitions = n_coalitions
self.mask_strategy = mask_strategy
self.weighting = weighting
self.feature_importance = feature_importance
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
if mask_strategy == "mean":
# Background may be a torch.Tensor; reduce it to a NumPy (T, F) mean for imputation
bg = background.detach().cpu().numpy() if hasattr(background, "detach") else np.asarray(background)
self._mean = bg.mean(axis=0)
else:
self._mean = None
def _impute(self, X, idxs):
r"""
Apply masking strategy to selected (t, f) indices in input.
- 'zero': Replace with 0.0.
- 'mean': Use mean value from background dataset.
:param X: Input sample of shape (T, F).
:type X: np.ndarray
:param idxs: List of (t, f) pairs to mask.
:type idxs: list[tuple[int, int]]
:return: Masked/imputed version of X.
:rtype: np.ndarray
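
A small illustrative sketch (the ``explainer`` instance here is hypothetical and assumed to have
been constructed with ``mask_strategy='zero'``)::

    X = np.ones((3, 2), dtype=float)
    X_masked = explainer._impute(X, [(0, 1), (2, 0)])
    # X_masked[0, 1] == 0.0 and X_masked[2, 0] == 0.0; every other entry remains 1.0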
"""
X_imp = X.copy()
for t, f in idxs:
if self.mask_strategy == "zero":
X_imp[t, f] = 0.0
elif self.mask_strategy == "mean":
X_imp[t, f] = self._mean[t, f]
else:
raise ValueError(f"Unknown mask_strategy: {self.mask_strategy}")
return X_imp
def _sample_coalition(self, available, k, weights=None):
"""
Sample a coalition of k positions from the available list.
If weights are provided, sampling is weighted; otherwise, uniform.
:param available: List of available (t, f) pairs.
:type available: list[tuple[int, int]]
:param k: Number of elements to sample.
:type k: int
:param weights: Sampling probabilities aligned with `available`.
:type weights: Optional[np.ndarray]
:return: List of sampled (t, f) pairs.
:rtype: list[tuple[int, int]]
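
A small illustrative sketch (the ``explainer`` instance is hypothetical; ``weights`` must be
aligned element-wise with ``available``)::

    available = [(0, 0), (0, 1), (1, 0), (1, 1)]
    weights = np.array([0.4, 0.3, 0.2, 0.1])
    coalition = explainer._sample_coalition(available, k=2, weights=weights)
    # `coalition` is a list of two distinct (t, f) pairs drawn with the given probabilities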
"""
if weights is not None:
# `weights` is expected to be aligned element-wise with `available`; normalize to probabilities
weights = np.asarray(weights, dtype=float)
total = weights.sum()
weights = weights / total if total > 0 else None
if weights is not None:
idxs = np.random.choice(len(available), size=k, replace=False, p=weights)
else:
idxs = np.random.choice(len(available), size=k, replace=False)
return [available[i] for i in idxs]
def shap_values(self, X, check_additivity=True, random_seed=42, **kwargs):
r"""
Compute SHAP values via random coalition sampling.
For each position (t, f), sample coalitions of other positions,
compute marginal contributions, and average over samples.
Attributions are normalized to satisfy:
.. math::
\sum_{t=1}^T \sum_{f=1}^F \phi_{t,f} \approx f(x) - f(x_{masked})
:param X: Input array or tensor of shape (T, F) or (B, T, F).
:type X: np.ndarray or torch.Tensor
:param check_additivity: Whether to apply normalization for additivity.
:type check_additivity: bool
:param random_seed: Seed for reproducibility.
:type random_seed: int
:return: SHAP values of shape (T, F) or (B, T, F).
:rtype: np.ndarray
"""
np.random.seed(random_seed)
is_torch = hasattr(X, "detach")
X_in = X.detach().cpu().numpy() if is_torch else np.asarray(X)
shape = X_in.shape
if len(shape) == 2:
X_in = X_in[None, ...]
single = True
else:
single = False
B, T, F = X_in.shape
shap_vals = np.zeros((B, T, F), dtype=float)
for b in range(B):
x_orig = X_in[b]
all_pos = [(t, f) for t in range(T) for f in range(F)]
shap_matrix = np.zeros((T, F))
for t in range(T):
for f in range(F):
mc = []
available = [idx for idx in all_pos if idx != (t, f)]
# Define weights for coalition sampling
weights = None
if (
self.weighting == "importance"
and self.feature_importance is not None
):
flat_imp = self.feature_importance.flatten()
idx_map = {idx: i for i, idx in enumerate(all_pos)}
weights = np.array(
[flat_imp[idx_map[idx]] for idx in available]
)
weights = weights / (weights.sum() + 1e-8)
elif self.weighting == "frequency":
weights = None # Implemented as uniform, could use prior freq
for _ in range(self.n_coalitions):
# Random coalition size in [1, |available|]; the empty coalition is never sampled
k = np.random.randint(1, len(available) + 1)
C_idxs = self._sample_coalition(available, k, weights)
x_C = self._impute(x_orig, C_idxs)
x_C_tf = self._impute(x_C, [(t, f)])
out_C = (
self.model(
torch.tensor(
x_C[None], dtype=torch.float32, device=self.device
)
)
.detach()
.cpu()
.numpy()
.squeeze()
)
out_C_tf = (
self.model(
torch.tensor(
x_C_tf[None],
dtype=torch.float32,
device=self.device,
)
)
.detach()
.cpu()
.numpy()
.squeeze()
)
# Marginal contribution of (t, f): output with it present minus output with it masked
mc.append(out_C - out_C_tf)
shap_matrix[t, f] = np.mean(mc)
shap_vals[b] = shap_matrix
# Additivity normalization per sample
orig_pred = (
self.model(
torch.tensor(x_orig[None], dtype=torch.float32, device=self.device)
)
.detach()
.cpu()
.numpy()
.squeeze()
)
x_all_masked = self._impute(x_orig, all_pos)
masked_pred = (
self.model(
torch.tensor(
x_all_masked[None], dtype=torch.float32, device=self.device
)
)
.detach()
.cpu()
.numpy()
.squeeze()
)
shap_sum = shap_vals[b].sum()
model_diff = orig_pred - masked_pred
if shap_sum != 0:
shap_vals[b] *= model_diff / shap_sum
result = shap_vals[0] if single else shap_vals
if check_additivity:
# Report the additivity check for the last processed sample
print(
f"[ERSHAP Additivity] sum(SHAP)={shap_vals[-1].sum():.4f} | Model diff={float(model_diff):.4f}"
)
return result