shap_enhanced.explainers.SurroSHAP

SurroSHAP: Surrogate Model SHAP Explainer

Theoretical Explanation

SurroSHAP is a surrogate modeling approach to SHAP that accelerates feature attribution by training a regression model to mimic SHAP values produced by a base explainer. Once trained, the surrogate regressor can produce fast, approximate SHAP values for new inputs, bypassing the computational expense of re-running the base SHAP explainer.

This method is particularly useful for large datasets, expensive black-box models, or scenarios where near-real-time interpretability is needed.
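Here is a minimal sketch of the idea that uses only NumPy. It assumes a toy linear model f(x) = w·x + b with independent features. For that model the exact SHAP values have the closed form phi_i = w_i (x_i - E[x_i]), so this function stands in for the expensive base explainer. A least-squares linear map stands in for the surrogate regressor. None of these names come from shap_enhanced.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy linear model f(x) = w . x + b (a stand-in for an expensive black box).
w = np.array([2.0, -1.0, 0.5])
b = 0.3

def model(X):
    return X @ w + b

# Background data: reference distribution and surrogate training inputs.
background = rng.normal(size=(200, 3))
mu = background.mean(axis=0)

def exact_shap(X):
    # Linear model, independent features: phi_i = w_i * (x_i - E[x_i]).
    # This plays the role of the slow base explainer.
    return (X - mu) * w

# 1. Label the training inputs with "base" SHAP values.
targets = exact_shap(background)

# 2. Fit the surrogate: a least-squares linear map from [x, 1] to the SHAP vector.
design = np.hstack([background, np.ones((len(background), 1))])
coef, *_ = np.linalg.lstsq(design, targets, rcond=None)

def surrogate_shap(X):
    # Fast inference: one matrix multiply instead of re-running the base explainer.
    return np.hstack([X, np.ones((len(X), 1))]) @ coef

X_new = rng.normal(size=(5, 3))
approx = surrogate_shap(X_new)
```

The exact attributions here are linear in x, so the surrogate recovers them up to floating-point error. For a non-linear black box, fidelity depends on how well the chosen regressor can fit the input-to-attribution map.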

Key Concepts

  • Surrogate Regression:

    A regression model (e.g., Random Forest, Kernel Ridge, MLP) is trained to predict SHAP attributions using inputs as features and base SHAP values as targets.

  • Base SHAP Explainer:

    Any standard SHAP-style explainer (e.g., DeepExplainer, GradientExplainer, KernelExplainer) can be used to generate training labels (pseudo-ground-truth SHAP values).

  • Optional Scaling:

    Input features and/or SHAP attributions can be standardized to improve the surrogate’s learning performance.

  • Fast Inference:

    Once trained, the surrogate model can rapidly produce SHAP attributions for unseen inputs without invoking the base SHAP explainer again.
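In practice, optional scaling means z-scoring each flattened column before fitting. For the targets, the transform is then undone after prediction. The sketch below is a NumPy stand-in that mirrors the fit/transform/inverse_transform interface of sklearn.preprocessing.StandardScaler. Whether SurroSHAP uses that exact scaler is an assumption, and the Standardizer class is hypothetical.

```python
import numpy as np

class Standardizer:
    """Minimal z-score scaler with fit / transform / inverse_transform (illustrative)."""

    def fit(self, A):
        self.mean_ = A.mean(axis=0)
        std = A.std(axis=0)
        # Guard against constant columns so we never divide by zero.
        self.scale_ = np.where(std > 0, std, 1.0)
        return self

    def transform(self, A):
        return (A - self.mean_) / self.scale_

    def inverse_transform(self, Z):
        return Z * self.scale_ + self.mean_

rng = np.random.default_rng(1)
shap_targets = rng.normal(loc=5.0, scale=3.0, size=(100, 4))

y_scaler = Standardizer().fit(shap_targets)
z = y_scaler.transform(shap_targets)          # what the surrogate would be trained on
restored = y_scaler.inverse_transform(z)      # what would be returned to the user
```

Scaling the targets matters most when attributions for different features live on very different scales. Without it, a squared-error regressor would mostly fit the largest-magnitude columns.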

Algorithm

  1. Initialization:
    • Accepts the following:
      • A predictive model to explain.

      • Background dataset for base SHAP explainer.

      • SHAP-style base explainer instance (already constructed).

      • Surrogate regressor class (e.g., sklearn.ensemble.RandomForestRegressor).

      • Number of samples the base explainer uses per background point (nsamples_base).

      • Options for input/output scaling.

      • Device context (if applicable).

  2. Surrogate Training:
    • Sample training points from the background dataset.

    • For each sample:
      • Compute SHAP values using the base explainer.

      • Flatten the input and corresponding SHAP vector.

    • Optionally scale both inputs and targets.

    • Fit the surrogate regressor on the collected (input, attribution) pairs.

  3. SHAP Value Prediction:
    • For a new sample:
      • Flatten and optionally scale the input.

      • Predict SHAP attributions using the surrogate model.

      • Inverse-transform and reshape predictions to original attribution shape if needed.
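Steps 2 and 3 can be sketched end to end as follows. Two pieces are assumptions made for illustration. The base explainer is replaced by a toy function with a known answer. The surrogate is a 1-nearest-neighbour regressor written in NumPy, so the example needs no other dependencies (SurroSHAP's default is RandomForestRegressor). The sketch also labels every background sample rather than a random subset. The point is the data layout: (N, T, F) samples become an (N, T*F) design matrix paired with an (N, T*F) target matrix, and predictions are reshaped back to (B, T, F).

```python
import numpy as np

rng = np.random.default_rng(2)
N, T, F = 64, 4, 3
background = rng.normal(size=(N, T, F))
ref = background.mean(axis=0)

def base_explain(x):
    # Toy stand-in for the base explainer: maps a (T, F) sample to (T, F) attributions.
    return 2.0 * (x - ref)

# Step 2: label each background sample, flattening inputs and targets to length T * F.
X_train = background.reshape(N, -1)                                    # (N, T*F)
Y_train = np.stack([base_explain(x).reshape(-1) for x in background])  # (N, T*F)

def surrogate_predict(flat_queries):
    # 1-nearest-neighbour regression: return the stored target of the closest input.
    d2 = ((flat_queries[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=-1)  # (B, N)
    return Y_train[d2.argmin(axis=1)]

# Step 3: flatten new samples, predict, and reshape back to (B, T, F).
X_new = rng.normal(size=(5, T, F))
phi_new = surrogate_predict(X_new.reshape(len(X_new), -1)).reshape(X_new.shape)

# Sanity check: on training inputs the 1-NN surrogate reproduces the base attributions.
phi_train = surrogate_predict(X_train).reshape(background.shape)
```

A 1-NN surrogate generalises poorly. In practice, the choice of regressor determines how well attributions transfer to unseen inputs.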

Use Case

SurroSHAP is best suited for:
  • Large-scale datasets requiring rapid SHAP value generation.

  • Scenarios where base SHAP computation is slow or expensive.

  • Situations where approximate explanations are acceptable in exchange for speed.

References

  • Lundberg & Lee (2017), “A Unified Approach to Interpreting Model Predictions” [SHAP foundation—coalitional feature attribution framework]

  • Zhou, Chen & Hu (2022), “Shapley Computations Using Surrogate Model-Based Trees” [Uses surrogate tree models to compute SHAP values via conditional expectation, trading accuracy for speed]

  • ShapGAP (2024) [A metric for evaluating surrogate model fidelity by comparing SHAP explanations of surrogate vs. black-box models, ensuring surrogate explanations align in reasoning]

  • Arize MimicExplainer documentation [Describes practical use of surrogate explainability: fitting a model (e.g., RandomForest) to mimic black-box outputs and generating SHAP values from the surrogate]

  • Interpretable Machine Learning book (Molnar, 2022), SHAP chapter [Discusses surrogate/approximation strategies and trade-offs between fidelity and computational efficiency]

Functions

ensure_shap_input(x, explainer[, device])

Ensure compatibility of input format with the SHAP explainer type.

shap_values_with_nsamples(base_explainer, x, ...)

Safely compute SHAP values with optional support for nsamples argument.

Classes

SurrogateSHAPExplainer(model, background, ...)

SurrogateSHAPExplainer: Fast SHAP Approximation via Supervised Regression

class shap_enhanced.explainers.SurroSHAP.SurrogateSHAPExplainer(model, background, base_explainer, regressor_class=<class 'sklearn.ensemble._forest.RandomForestRegressor'>, regressor_kwargs=None, nsamples_base=100, scale_inputs=True, scale_outputs=False, device=None)

Bases: BaseExplainer

SurrogateSHAPExplainer: Fast SHAP Approximation via Supervised Regression

SurroSHAP accelerates SHAP attribution by training a surrogate regression model that maps input features to SHAP attributions. This is useful when repeated SHAP computation is too costly or when near-instant explanations are needed for deployment.

The surrogate model is trained on a background dataset where “true” SHAP values are first computed using a base explainer (e.g., DeepExplainer, KernelExplainer), and then used as regression targets.

Note

Any sklearn-style regressor can be used (e.g., RandomForestRegressor, KernelRidge, etc.).
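Here, "sklearn-style" means duck typing: the object needs a fit(X, Y) that accepts a 2-D target matrix and a predict(X) that returns one. The closed-form ridge regressor below is an illustration, not a class shipped with shap_enhanced. It satisfies that protocol and could in principle be passed as regressor_class.

```python
import numpy as np

class RidgeRegressor:
    """Tiny closed-form ridge regression exposing the fit/predict protocol."""

    def __init__(self, alpha=1e-6):
        self.alpha = alpha

    def fit(self, X, Y):
        X = np.asarray(X, dtype=float)
        Y = np.asarray(Y, dtype=float)
        self.x_mean_, self.y_mean_ = X.mean(axis=0), Y.mean(axis=0)
        Xc, Yc = X - self.x_mean_, Y - self.y_mean_
        # Solve (Xc^T Xc + alpha I) W = Xc^T Yc; multi-output Y is handled directly.
        A = Xc.T @ Xc + self.alpha * np.eye(X.shape[1])
        self.coef_ = np.linalg.solve(A, Xc.T @ Yc)
        return self

    def predict(self, X):
        return (np.asarray(X, dtype=float) - self.x_mean_) @ self.coef_ + self.y_mean_

rng = np.random.default_rng(3)
X = rng.normal(size=(50, 6))
W_true = rng.normal(size=(6, 6))
Y = X @ W_true + 1.0
reg = RidgeRegressor().fit(X, Y)
```

RandomForestRegressor and KernelRidge accept multi-output targets natively. Some other sklearn regressors, such as SVR, only accept a 1-D target and would need to be wrapped in sklearn.multioutput.MultiOutputRegressor first.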

Parameters:
  • model (Any) – Predictive model to be explained.

  • background (np.ndarray) – Background dataset for training surrogate and computing SHAP targets. Shape: (N, T, F).

  • base_explainer (Any) – A SHAP-style explainer instance (already constructed).

  • regressor_class (type) – Regressor class implementing fit/predict API. Defaults to RandomForestRegressor.

  • regressor_kwargs (dict) – Optional keyword arguments for the regressor.

  • nsamples_base (int) – Number of SHAP samples used for each background point.

  • scale_inputs (bool) – Whether to standardize input features during training.

  • scale_outputs (bool) – Whether to standardize SHAP values during training.

  • device (str or torch.device) – Torch device (e.g., ‘cpu’ or ‘cuda’).
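The regressor_class / regressor_kwargs pair suggests a common pattern: store a class plus keyword arguments, and instantiate the regressor when it is needed. The sketch below assumes the explainer does something equivalent to regressor_class(**(regressor_kwargs or {})). MeanRegressor and build_regressor are hypothetical names used only for illustration.

```python
import numpy as np

class MeanRegressor:
    """Hypothetical stand-in regressor: predicts the (shrunk) per-output training mean."""

    def __init__(self, shrink=1.0):
        self.shrink = shrink

    def fit(self, X, Y):
        self.mean_ = self.shrink * np.asarray(Y, dtype=float).mean(axis=0)
        return self

    def predict(self, X):
        return np.tile(self.mean_, (len(X), 1))

def build_regressor(regressor_class, regressor_kwargs=None):
    # A None default avoids the shared-mutable-default pitfall; fall back to {}.
    return regressor_class(**(regressor_kwargs or {}))

reg = build_regressor(MeanRegressor, {"shrink": 0.5})
reg.fit(np.zeros((4, 2)), np.array([[2.0, 4.0]] * 4))
pred = reg.predict(np.zeros((3, 2)))
```

With scikit-learn installed, the same call shape would be build_regressor(RandomForestRegressor, {"n_estimators": 200}).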

property expected_value

Optional property returning the expected model output on the background dataset.

Returns:

Expected value if defined by the subclass, else None.

Return type:

float or None
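The expected value is the baseline that SHAP attributions are measured from. By SHAP's local-accuracy property, E[f(X)] plus the sum of a sample's attributions equals f(x). A regression surrogate is not constrained to satisfy this identity, so for SurroSHAP outputs it will generally hold only approximately, which makes it a cheap sanity check. The sketch below computes the baseline as the mean model output over a background set. It then verifies the identity for exact SHAP values of a toy linear model, chosen as an assumption so that the values are known in closed form.

```python
import numpy as np

rng = np.random.default_rng(4)
w, b = np.array([1.5, -2.0, 0.25]), 0.7

def model(X):
    return X @ w + b

background = rng.normal(size=(100, 3))
# Expected model output on the background: the SHAP baseline.
expected_value = float(model(background).mean())

# Exact SHAP values for this linear toy model (independent features).
x = rng.normal(size=(1, 3))
phi = (x - background.mean(axis=0)) * w

# Local accuracy: baseline + sum of attributions reproduces the prediction.
gap = float(model(x)[0] - (expected_value + phi.sum()))
```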

explain(X, **kwargs)

Alias for shap_values, provided for flexibility and API compatibility.

Parameters:
  • X (Union[np.ndarray, torch.Tensor, list]) – Input samples to explain.

  • kwargs – Additional arguments.

Returns:

SHAP values.

Return type:

Union[np.ndarray, list]
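Because explain simply forwards to shap_values, the two calls are interchangeable. The sketch below shows that delegation pattern. AliasDemo is a hypothetical class, and the assumption that keyword arguments are passed through unchanged is read off the signature rather than confirmed by this page.

```python
import numpy as np

class AliasDemo:
    """Illustrative only: an explain() alias that forwards to shap_values()."""

    def shap_values(self, X, **kwargs):
        # Placeholder attribution: zeros with the same shape as the input.
        return np.zeros_like(np.asarray(X, dtype=float))

    def explain(self, X, **kwargs):
        # Pure delegation: same inputs, same outputs, kwargs passed through.
        return self.shap_values(X, **kwargs)

demo = AliasDemo()
out = demo.explain(np.ones((2, 4, 3)))
```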

shap_values(X, **kwargs)

Predict SHAP attributions for new inputs using the trained surrogate model.

The input is reshaped and (optionally) standardized to match the format used during surrogate training, and the predicted SHAP values are inverse-transformed (if scaling was applied).

Note

This bypasses SHAP computation entirely and relies on the surrogate regressor.

Parameters:

X (np.ndarray or torch.Tensor) – Input instance or batch, shape (T, F) or (B, T, F).

Returns:

Approximated SHAP attributions, same shape as input.

Return type:

np.ndarray
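The inference path described above can be sketched under the following assumptions:

- Tensors are detected by duck typing: anything with a .detach() method is treated as a torch.Tensor, so torch is never imported.
- Fitted scalers are represented as plain (mean, scale) pairs over the flattened T*F columns.
- predict is any callable that maps an (n, T*F) array to an (n, T*F) array.

The function and argument names are hypothetical.

```python
import numpy as np

def to_numpy(X):
    """Accept NumPy arrays, nested lists, or torch tensors (duck-typed)."""
    if hasattr(X, "detach"):                # torch.Tensor: detach and move to CPU first
        return X.detach().cpu().numpy()
    return np.asarray(X)

def surrogate_shap_values(X, predict, x_stats=None, y_stats=None):
    """Hypothetical SurroSHAP-style inference with optional input/output scaling."""
    X = to_numpy(X).astype(float)
    single = X.ndim == 2                    # (T, F): treat as a batch of one
    batch = X[None] if single else X
    flat = batch.reshape(len(batch), -1)    # (B, T*F)
    if x_stats is not None:                 # apply the training-time input scaling
        flat = (flat - x_stats[0]) / x_stats[1]
    out = predict(flat)
    if y_stats is not None:                 # undo output standardization
        out = out * y_stats[1] + y_stats[0]
    out = out.reshape(batch.shape)
    return out[0] if single else out

stats = (np.full(4, 0.5), np.full(4, 2.0))
phi = surrogate_shap_values([[1.0, 2.0], [3.0, 4.0]], predict=lambda A: A,
                            x_stats=stats, y_stats=stats)
batch_phi = surrogate_shap_values(np.zeros((3, 2, 2)), predict=lambda A: A)
```

Passing an identity predict with matching input and output statistics, as above, is a quick way to check that the scaling and reshaping steps are exact inverses of each other.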

shap_enhanced.explainers.SurroSHAP.ensure_shap_input(x, explainer, device='cpu')

Ensure compatibility of input format with the SHAP explainer type.

This function inspects the explainer type (e.g., DeepExplainer, KernelExplainer) and converts the input x into the appropriate format—either a NumPy array or a PyTorch tensor—based on the explainer’s requirements.

  • Deep/Gradient explainers require torch.Tensor input.

  • Kernel/Partition explainers require np.ndarray input.

Parameters:
  • x (np.ndarray or torch.Tensor) – Input sample to format, shape (T, F) or (1, T, F).

  • explainer (Any) – Instantiated SHAP explainer object.

  • device (str) – Target torch device (‘cpu’ or ‘cuda’).

Returns:

Properly formatted input for SHAP explainer.

Return type:

np.ndarray or torch.Tensor

Raises:

TypeError – If the input type is unsupported.
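Below is one plausible way to implement this dispatch. It assumes the explainer type is identified by its class name and that a (T, F) input gains a leading batch axis. The real helper's detection logic may differ. ensure_input_sketch and TORCH_EXPLAINERS are illustrative names, and torch is imported only on the Deep/Gradient branch.

```python
import numpy as np

TORCH_EXPLAINERS = {"DeepExplainer", "GradientExplainer"}

def ensure_input_sketch(x, explainer, device="cpu"):
    """Hypothetical re-implementation: choose NumPy or torch by explainer class name."""
    if hasattr(x, "detach"):                  # torch.Tensor: convert to NumPy first
        x = x.detach().cpu().numpy()
    elif not isinstance(x, np.ndarray):
        raise TypeError(f"Unsupported input type: {type(x).__name__}")
    if x.ndim == 2:                           # (T, F) -> (1, T, F)
        x = x[None]
    if type(explainer).__name__ in TORCH_EXPLAINERS:
        import torch                          # only needed for Deep/Gradient explainers
        return torch.as_tensor(x, dtype=torch.float32, device=device)
    return x

class KernelExplainer:  # dummy stand-in sharing shap.KernelExplainer's class name
    pass

formatted = ensure_input_sketch(np.zeros((4, 3)), KernelExplainer())
already_batched = ensure_input_sketch(np.zeros((1, 4, 3)), KernelExplainer())
```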

shap_enhanced.explainers.SurroSHAP.shap_values_with_nsamples(base_explainer, x, nsamples, **kwargs)

Safely compute SHAP values with optional support for nsamples argument.

This utility inspects the signature of the shap_values method and attempts to call it with nsamples, check_additivity, and any additional kwargs. It includes fallback logic for older SHAP versions that may not support these parameters.

Parameters:
  • base_explainer (Any) – SHAP explainer instance.

  • x (np.ndarray or torch.Tensor) – Input to explain (already formatted for the explainer).

  • nsamples (int) – Number of samples to use for SHAP estimation.

  • kwargs – Additional keyword arguments for shap_values.

Returns:

SHAP attributions for the input.

Return type:

np.ndarray or list
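The signature-inspection approach described above can be sketched with inspect.signature from the standard library. The sketch makes two assumptions that this page does not confirm:

- check_additivity should be disabled whenever it is supported, a common choice for approximate explainers.
- A TypeError at call time signals an incompatible signature.

call_shap_values and the two dummy explainer classes are hypothetical.

```python
import inspect

def call_shap_values(explainer, x, nsamples, **kwargs):
    """Hypothetical sketch: pass nsamples / check_additivity only if accepted."""
    method = explainer.shap_values
    params = inspect.signature(method).parameters
    accepts_var_kw = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
    call_kwargs = dict(kwargs)
    if accepts_var_kw or "nsamples" in params:
        call_kwargs["nsamples"] = nsamples
    if accepts_var_kw or "check_additivity" in params:
        call_kwargs.setdefault("check_additivity", False)
    try:
        return method(x, **call_kwargs)
    except TypeError:
        # Fallback for versions that reject the extra arguments at call time.
        return method(x)

class OldStyle:
    def shap_values(self, X):
        return ("old", X)

class NewStyle:
    def shap_values(self, X, nsamples=100, check_additivity=True):
        return ("new", nsamples, check_additivity)

r_old = call_shap_values(OldStyle(), 1, nsamples=50)
r_new = call_shap_values(NewStyle(), 1, nsamples=50)
```

A catch-all except TypeError can also hide genuine errors raised inside the explainer. A production version might narrow the fallback, for example by retrying only when the error message mentions an unexpected keyword argument.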