Source code for shap_enhanced.explainers.CASHAP

"""
CASHAP: Coalition-Aware SHAP Explainer
======================================

Theoretical Explanation
-----------------------

CASHAP (Coalition-Aware SHAP) is a Shapley value estimation framework tailored for models that process sequential or structured inputs, such as LSTMs.
Unlike classical SHAP methods that treat features independently, CASHAP attributes to **feature-time pairs**, enabling attribution along both the feature and time dimensions.

By explicitly sampling coalitions (subsets) of feature-time pairs and measuring marginal contributions, CASHAP provides granular, context-aware explanations.
It also supports multiple imputation strategies to ensure the perturbed inputs remain valid and interpretable.

Key Concepts
^^^^^^^^^^^^

- **Coalition Sampling**: For every feature-time pair \\((t, f)\\), random subsets of all other positions are sampled.
    The contribution of \\((t, f)\\) is assessed by adding it to each coalition and measuring the change in model output.
- **Masking/Imputation Strategies**:
    - **Zero masking**: Replace masked values with zero.
    - **Mean imputation**: Use feature-wise means from background data.
    - **Custom imputers**: Support for user-defined imputation functions (see the sketch after this list).
- **Model-Agnostic & Domain-General**: While ideal for time-series and sequential models, CASHAP can also be applied to tabular data
    wherever structured coalition masking is appropriate.
- **Additivity Normalization**: Attribution scores are scaled such that their total sum equals the difference in model output
    between the original input and a fully-masked version.
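
A custom imputer is any callable that takes the input and a list of ``(t, f)``
positions and returns an imputed copy of the same shape. A minimal sketch
(the ``noise_imputer`` name and its Gaussian-noise behavior are illustrative,
not part of the package):

.. code-block:: python

    import numpy as np

    def noise_imputer(X, idxs):
        # Replace each masked (t, f) position with a standard-normal draw.
        X_imp = X.copy()
        for t, f in idxs:
            X_imp[t, f] = np.random.normal(0.0, 1.0)
        return X_imp

    # explainer = CoalitionAwareSHAPExplainer(model, mask_strategy="custom",
    #                                         imputer=noise_imputer)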

Algorithm
---------

1. **Initialization**:
    - Accepts a model, background data for imputation, masking strategy, optional custom imputer, and device context.

2. **Coalition Sampling**:
    - For each feature-time pair \\((t, f)\\):
        - Sample coalitions \\( C \\subseteq (T \\times F) \\setminus \\{(t, f)\\} \\).
        - For each coalition \\( C \\):
            - Impute features in \\( C \\) using the chosen strategy.
            - Impute features in \\( C \\cup \\{(t, f)\\} \\).
            - Compute and record the model output difference.

3. **Attribution Estimation**:
    - Average the output differences across coalitions to estimate the marginal contribution of \\((t, f)\\).

4. **Normalization**:
    - Normalize attributions so that their total matches the difference between the model's prediction
        on the original and the fully-masked input.
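
The sampling loop in steps 2-3 can be sketched as follows (a simplified,
zero-masking version for a single pair; ``model_fn`` and
``marginal_contribution`` are illustrative names, and ``model_fn`` is assumed
to map a ``(1, T, F)`` array to a length-1 output):

.. code-block:: python

    import numpy as np

    def marginal_contribution(model_fn, x, t, f, nsamples=100):
        # Monte Carlo estimate of the contribution of (t, f) with zero masking.
        T, F = x.shape
        positions = [(i, j) for i in range(T) for j in range(F) if (i, j) != (t, f)]
        contribs = []
        for _ in range(nsamples):
            k = np.random.randint(1, len(positions) + 1)
            chosen = np.random.choice(len(positions), size=k, replace=False)
            x_C = x.copy()
            for idx in chosen:          # mask the coalition C
                x_C[positions[idx]] = 0.0
            x_C_tf = x_C.copy()
            x_C_tf[t, f] = 0.0          # additionally mask (t, f)
            # Output with (t, f) present minus output with it masked.
            contribs.append(model_fn(x_C[None])[0] - model_fn(x_C_tf[None])[0])
        return float(np.mean(contribs))

    # Step 4 (normalization) then rescales all estimates so they sum to
    # model_fn(x[None])[0] - model_fn(fully_masked_x[None])[0].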

References
----------

- **Lundberg & Lee (2017), “A Unified Approach to Interpreting Model Predictions”**
  [SHAP foundation: the coalitional feature-attribution framework]

- **Jutte et al. (2025), “C-SHAP for time series: An approach to high-level temporal explanations”**
  [Applies concept-based SHAP to structured temporal data; treats temporal segments or concepts as coalition elements]

- **Schlegel et al. (2019), “Towards a Rigorous Evaluation of XAI Methods on Time Series”**
  [Evaluates how SHAP and other methods behave for sequential/time-series models, highlighting temporal-structure challenges]

- **Franco de la Peña et al. (2025), “ShaTS: A Shapley-based Explainability Method for Time Series Models”**
  [Proposes temporally aware grouping for Shapley attribution in sequential IoT data, preserving temporal dependencies]

- **Molnar, “Interpretable Machine Learning” (2022), SHAP chapter**
  [Describes masking and coalition-sampling strategies, including handling structured or dependent features]
"""

from collections.abc import Callable
from typing import Any

import numpy as np
import torch

from shap_enhanced.base_explainer import BaseExplainer


class CoalitionAwareSHAPExplainer(BaseExplainer):
    """Coalition-Aware SHAP (CASHAP) Explainer.

    Estimates Shapley values for models processing structured inputs
    (e.g., time-series, sequences) by sampling coalitions of feature-time
    pairs and computing their marginal contributions using various
    imputation strategies.

    :param model: Model to be explained.
    :type model: Any
    :param background: Background data used for the mean imputation strategy.
    :type background: Optional[np.ndarray or torch.Tensor]
    :param str mask_strategy: Strategy for imputing/masking feature-time pairs.
        Options: 'zero', 'mean', or 'custom'.
    :param imputer: Custom callable for imputation. Required if `mask_strategy`
        is 'custom'.
    :type imputer: Optional[Callable]
    :param device: Device on which computation runs. Defaults to 'cuda' if available.
    :type device: Optional[str]
    """

    def __init__(
        self,
        model: Any,
        background: np.ndarray | torch.Tensor | None = None,
        mask_strategy: str = "zero",
        imputer: Callable | None = None,
        device: str | None = None,
    ):
        super().__init__(model, background)
        self.mask_strategy = mask_strategy
        self.imputer = imputer
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Precompute feature-wise means if needed.
        if mask_strategy == "mean":
            if background is None:
                raise ValueError("Mean imputation requires background data.")
            self._mean = (
                background.mean(axis=0)
                if isinstance(background, np.ndarray)
                else background.float().mean(dim=0)
            )
        else:
            self._mean = None

    def _mask(self, X, idxs, value=None):
        """Mask specified feature-time positions in the input.

        :param X: Input array of shape (T, F), or a batched tensor (B, T, F).
        :type X: np.ndarray or torch.Tensor
        :param idxs: List of (t, f) index pairs to mask.
        :type idxs: list[tuple[int, int]]
        :param value: Value to place at masked positions. Defaults to 0.0.
        :return: Masked copy of the input.
        :rtype: Same as input type
        """
        X_masked = X.copy() if isinstance(X, np.ndarray) else X.clone()
        for t, f in idxs:
            if isinstance(X_masked, np.ndarray):
                # NumPy path expects an unbatched (T, F) array.
                X_masked[t, f] = value if value is not None else 0.0
            else:
                # Torch path expects a batched (B, T, F) tensor.
                X_masked[:, t, f] = value if value is not None else 0.0
        return X_masked

    def _impute(self, X, idxs):
        """Apply the imputation strategy to the specified positions.

        The imputation method depends on the selected `mask_strategy`:

        - 'zero': Set masked values to 0.
        - 'mean': Use mean values computed from background data.
        - 'custom': Use the user-defined callable.

        :param X: Input data (T, F).
        :type X: np.ndarray or torch.Tensor
        :param idxs: Positions to impute, as (t, f) tuples.
        :type idxs: list[tuple[int, int]]
        :return: Imputed input.
        :rtype: Same as input type
        """
        if self.mask_strategy == "zero":
            return self._mask(X, idxs, value=0.0)
        elif self.mask_strategy == "mean":
            X_imp = X.copy() if isinstance(X, np.ndarray) else X.clone()
            for t, f in idxs:
                if isinstance(X_imp, np.ndarray):
                    X_imp[t, f] = self._mean[t, f]
                else:
                    X_imp[:, t, f] = self._mean[t, f]
            return X_imp
        elif self.mask_strategy == "custom":
            assert self.imputer is not None, "Custom imputer must be provided."
            return self.imputer(X, idxs)
        else:
            raise ValueError(f"Unknown mask_strategy: {self.mask_strategy}")

    def _get_model_output(self, X):
        """Run the model, ensuring the input is a torch.Tensor on the correct device.

        Accepts (T, F) or (B, T, F) input and always returns a NumPy array.
        """
        if isinstance(X, np.ndarray):
            X = torch.tensor(X, dtype=torch.float32, device=self.device)
        elif isinstance(X, torch.Tensor):
            X = X.to(self.device)
        else:
            raise ValueError("Input must be np.ndarray or torch.Tensor.")
        with torch.no_grad():
            out = self.model(X)
        # Output can be (B,), (B, 1), or scalar; always return NumPy.
        return out.cpu().numpy() if hasattr(out, "cpu") else np.asarray(out)
    def shap_values(
        self,
        X: np.ndarray | torch.Tensor,
        nsamples: int = 100,
        coalition_size: int | None = None,
        mask_strategy: str | None = None,
        check_additivity: bool = True,
        random_seed: int = 42,
        **kwargs,
    ) -> np.ndarray:
        """Compute CASHAP Shapley values for structured inputs via coalition-aware sampling.

        For each feature-time pair \\((t, f)\\), randomly sample coalitions excluding
        \\((t, f)\\), compute model outputs with the pair present and masked, and average
        the marginal contributions. Attribution values are normalized so their total
        matches the model output difference between the original and fully-masked input.

        .. math::
            \\phi_{t,f} \\approx \\mathbb{E}_{C \\subseteq (T \\times F) \\setminus \\{(t,f)\\}}
            \\left[ f(C \\cup \\{(t,f)\\}) - f(C) \\right]

        .. note::
            Normalization ensures:

            .. math::
                \\sum_{t=1}^T \\sum_{f=1}^F \\phi_{t,f} \\approx f(x) - f(x_{\\text{masked}})

        :param X: Input sample of shape (T, F) or batch (B, T, F).
        :type X: np.ndarray or torch.Tensor
        :param nsamples: Number of coalitions sampled per (t, f).
        :type nsamples: int
        :param coalition_size: Fixed size of sampled coalitions. If None, sizes vary randomly.
        :type coalition_size: Optional[int]
        :param mask_strategy: Override the default masking strategy for this call.
        :type mask_strategy: Optional[str]
        :param check_additivity: Print a diagnostic comparing the SHAP sum to the model delta.
        :type check_additivity: bool
        :param random_seed: Seed for reproducibility.
        :type random_seed: int
        :return: SHAP values of shape (T, F) or (B, T, F).
        :rtype: np.ndarray
        """
        np.random.seed(random_seed)
        # _impute reads self.mask_strategy, so apply any per-call override
        # by setting it temporarily and restoring it afterwards.
        prev_strategy = self.mask_strategy
        if mask_strategy is not None:
            self.mask_strategy = mask_strategy

        is_torch = isinstance(X, torch.Tensor)
        X_in = X.detach().cpu().numpy() if is_torch else np.asarray(X)
        shape = X_in.shape
        if len(shape) == 2:  # (T, F)
            X_in = X_in[None, ...]  # add batch dim
        B, T, F = X_in.shape
        shap_vals = np.zeros((B, T, F), dtype=float)

        try:
            for b in range(B):
                x_orig = X_in[b]
                for t in range(T):
                    for f in range(F):
                        contribs = []
                        all_pos = [
                            (i, j)
                            for i in range(T)
                            for j in range(F)
                            if (i, j) != (t, f)
                        ]
                        for _ in range(nsamples):
                            # Coalition size: fixed if given, otherwise random.
                            if coalition_size is not None:
                                k = coalition_size
                            else:
                                k = np.random.randint(1, len(all_pos) + 1)
                            C_idxs = list(
                                np.random.choice(len(all_pos), size=k, replace=False)
                            )
                            C_idxs = [all_pos[idx] for idx in C_idxs]
                            # Mask the coalition C only; (t, f) stays present.
                            x_C = self._impute(x_orig, C_idxs)
                            # Additionally mask (t, f).
                            x_C_tf = self._impute(x_C, [(t, f)])
                            out_C = self._get_model_output(x_C[None])[0]
                            out_C_tf = self._get_model_output(x_C_tf[None])[0]
                            # Marginal contribution of keeping (t, f) present:
                            # output with (t, f) present minus output with it masked.
                            contribs.append(out_C - out_C_tf)
                        shap_vals[b, t, f] = np.mean(contribs)
                # Additivity normalization per sample.
                orig_pred = self._get_model_output(x_orig[None])[0]
                x_all_masked = self._impute(
                    x_orig, [(ti, fi) for ti in range(T) for fi in range(F)]
                )
                masked_pred = self._get_model_output(x_all_masked[None])[0]
                shap_sum = shap_vals[b].sum()
                model_diff = orig_pred - masked_pred
                if shap_sum != 0:
                    shap_vals[b] *= model_diff / shap_sum
        finally:
            self.mask_strategy = prev_strategy

        shap_vals = shap_vals[0] if len(shape) == 2 else shap_vals
        if check_additivity:
            # Note: the model diff reported here refers to the last sample in the batch.
            print(
                f"[CASHAP Additivity] sum(SHAP)={shap_vals.sum():.4f} | "
                f"Model diff={float(orig_pred - masked_pred):.4f}"
            )
        return shap_vals
if __name__ == "__main__":
    import torch.nn as nn

    # --- Dummy LSTM model for the demo ---
    class DummyLSTM(nn.Module):
        def __init__(self, input_dim=3, hidden_dim=8, output_dim=1):
            super().__init__()
            self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
            self.fc = nn.Linear(hidden_dim, output_dim)

        def forward(self, x):
            # Ensure the input is a float tensor.
            if not torch.is_tensor(x):
                x = torch.tensor(x, dtype=torch.float32)
            x = x.float()
            # x: (B, T, F)
            out, _ = self.lstm(x)
            # Use the last time step's output.
            out = self.fc(out[:, -1, :])
            return out.squeeze(-1)  # (B,)

    # --- Generate synthetic data ---
    np.random.seed(0)
    torch.manual_seed(0)
    B, T, F = 2, 5, 3
    train_X = np.random.normal(0, 1, (20, T, F)).astype(np.float32)
    test_X = np.random.normal(0, 1, (B, T, F)).astype(np.float32)

    # --- Initialize model and explainer ---
    model = DummyLSTM(input_dim=F, hidden_dim=8, output_dim=1)
    model.eval()
    explainer = CoalitionAwareSHAPExplainer(
        model=model, background=train_X, mask_strategy="mean"
    )

    # --- Compute SHAP values ---
    shap_vals = explainer.shap_values(
        test_X,  # (B, T, F)
        nsamples=10,  # small for the demo; increase for quality
        coalition_size=4,  # mask 4 pairs at a time
        check_additivity=True,
    )
    print("SHAP values shape:", shap_vals.shape)
    print("First sample SHAP values:\n", shap_vals[0])