A Go library for computing SHAP (SHapley Additive exPlanations) values for ML model explainability.
SHAP-Go provides a Go-native implementation of SHAP value computation for explaining machine learning model predictions. It supports:
- 📦 ONNX models via ONNX Runtime bindings
- ⚙️ Custom models via a simple function interface
- 🔀 Permutation SHAP with antithetic sampling for variance reduction
- 🎲 Sampling SHAP using Monte Carlo estimation
- 📋 JSON-serializable explanations for audit/compliance
| Status | Explainer | Model Type | Notes |
|---|---|---|---|
| ✅ | TreeSHAP | Trees | Exact & fast (O(TLD²)) for XGBoost, LightGBM, CatBoost; 40-100x faster than permutation |
| ✅ | KernelSHAP | Any | Black-box, weighted linear regression, model-agnostic baseline |
| ✅ | LinearSHAP | Linear | Exact closed-form solution for linear/logistic regression |
| ✅ | DeepSHAP | Neural Nets | Combines DeepLIFT with Shapley values, efficient for deep networks |
| ✅ | GradientSHAP | Any | Expected gradients using numerical differentiation, works with any differentiable model |
| ✅ | SamplingSHAP | Any | Monte Carlo approximation, fast, good for quick estimates |
| ✅ | PermutationSHAP | Any | Black-box, antithetic sampling for variance reduction, guarantees local accuracy |
| ✅ | ExactSHAP | Any | Brute-force exact computation, O(2^n) - only for small feature sets (≤15) |
| ✅ | PartitionSHAP | Structured | Hierarchical Owen values for feature groupings, respects domain structure |
| ✅ | AdditiveSHAP | GAMs | Exact SHAP for Generalized Additive Models, O(n×b) complexity |
- ✅ Implemented
- ⬜ Not yet implemented
| Use Case | Recommended Explainer |
|---|---|
| Tree-based models (XGBoost, LightGBM) | TreeSHAP ✅ |
| Linear/logistic regression | LinearSHAP ✅ |
| Any model, need guaranteed accuracy | PermutationSHAP ✅ |
| Any model, weighted regression baseline | KernelSHAP ✅ |
| Any model, quick estimates | SamplingSHAP ✅ |
| Small feature sets (≤15 features) | ExactSHAP ✅ |
| Deep learning models (ONNX) | DeepSHAP ✅ |
| Differentiable models, gradient-based | GradientSHAP ✅ |
| Grouped/structured features | PartitionSHAP ✅ |
| Generalized Additive Models (GAMs) | AdditiveSHAP ✅ |
```bash
go get github.com/plexusone/shap-go
```

```go
package main

import (
	"context"
	"fmt"

	"github.com/plexusone/shap-go/explainer"
	"github.com/plexusone/shap-go/explainer/permutation"
	"github.com/plexusone/shap-go/model"
)

func main() {
	// Define a simple model
	predict := func(ctx context.Context, input []float64) (float64, error) {
		return input[0] + 2*input[1], nil
	}
	m := model.NewFuncModel(predict, 2)

	// Background data for SHAP computation
	background := [][]float64{
		{0.0, 0.0},
	}

	// Create explainer
	exp, _ := permutation.New(m, background,
		explainer.WithNumSamples(100),
		explainer.WithFeatureNames([]string{"x1", "x2"}),
	)

	// Explain a prediction
	ctx := context.Background()
	explanation, _ := exp.Explain(ctx, []float64{1.0, 2.0})

	fmt.Printf("Prediction: %.2f\n", explanation.Prediction)
	fmt.Printf("Base Value: %.2f\n", explanation.BaseValue)
	for name, shap := range explanation.Values {
		fmt.Printf("SHAP(%s): %.4f\n", name, shap)
	}

	// Verify local accuracy
	result := explanation.Verify(1e-10)
	fmt.Printf("Local accuracy valid: %v\n", result.Valid)
}
```

TreeSHAP computes exact SHAP values in O(TLD²) time, where T = number of trees, L = maximum leaves per tree, and D = maximum depth. This is 40-100x faster than permutation-based methods for typical tree ensembles.
```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/plexusone/shap-go/explainer/tree"
)

func main() {
	// Load XGBoost model (saved with model.save_model("model.json"))
	ensemble, err := tree.LoadXGBoostModel("model.json")
	if err != nil {
		log.Fatal(err)
	}

	// Create TreeSHAP explainer
	explainer, err := tree.New(ensemble)
	if err != nil {
		log.Fatal(err)
	}

	// Explain a prediction
	ctx := context.Background()
	instance := []float64{0.5, 0.3, 0.8}
	explanation, err := explainer.Explain(ctx, instance)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Printf("Prediction: %.4f\n", explanation.Prediction)
	fmt.Printf("Base Value: %.4f\n", explanation.BaseValue)
	for _, feat := range explanation.TopFeatures(10) {
		fmt.Printf("  %s: %.4f\n", feat.Name, feat.Value)
	}
}
```

```go
// Load LightGBM JSON model (saved with booster.dump_model())
ensemble, err := tree.LoadLightGBMModel("model.json")
if err != nil {
	log.Fatal(err)
}

// Or load the text format (saved with booster.save_model())
ensemble, err = tree.LoadLightGBMTextModel("model.txt")

explainer, err := tree.New(ensemble)
// ... same as XGBoost
```

XGBoost:
```python
import xgboost as xgb

model = xgb.Booster()
model.load_model("model.bin")
model.save_model("model.json")  # JSON format for Go
```

LightGBM:
```python
import lightgbm as lgb
import json

model = lgb.Booster(model_file="model.txt")
with open("model.json", "w") as f:
    json.dump(model.dump_model(), f)
```

```go
// Explain multiple instances in parallel
instances := [][]float64{
	{0.1, 0.2, 0.3},
	{0.4, 0.5, 0.6},
	{0.7, 0.8, 0.9},
}
explanations, err := explainer.ExplainBatch(ctx, instances)
```

TreeSHAP can compute pairwise feature interactions, revealing how features work together:
```go
// Compute SHAP interaction values
result, err := explainer.ExplainInteractions(ctx, instance)
if err != nil {
	log.Fatal(err)
}

// Interaction matrix properties:
// - Diagonal Interactions[i][i]: main effect of feature i
// - Off-diagonal Interactions[i][j]: interaction between features i and j
// - Symmetric: Interactions[i][j] == Interactions[j][i]
// - Rows sum to SHAP values: sum(Interactions[i][:]) == SHAP[i]

// Get the interaction between two features
interaction := result.GetInteraction(0, 1)

// Get main effect (diagonal)
mainEffect := result.GetMainEffect(0)

// Get derived SHAP value (row sum)
shapValue := result.GetSHAPValue(0)

// Get top k strongest interactions
topK := result.TopInteractions(5)
for _, inter := range topK {
	fmt.Printf("%s <-> %s: %.4f\n", inter.Name1, inter.Name2, inter.Value)
}
```

Core types for SHAP explanations:
- 📊 `Explanation` - Contains prediction, base value, SHAP values, and metadata
- ✔️ `Verify()` - Checks local accuracy (sum of SHAP values = prediction - base value)
- 🔝 `TopFeatures()` - Returns features sorted by absolute SHAP value
- 📄 JSON serialization with `ToJSON()` and `FromJSON()`
Model interface for SHAP computation:
- 🔌 `Model` interface with `Predict()`, `PredictBatch()`, and `NumFeatures()`
- 🛠️ `FuncModel` - Wraps a prediction function as a `Model`
ONNX Runtime wrapper:
- 🔗 `Session` - Wraps an ONNX Runtime session
- 📦 Supports batch predictions
- 📚 Requires ONNX Runtime shared library
TreeSHAP for tree-based models:
- 🎯 Exact SHAP values (not approximations)
- ⚡ O(TLD²) complexity - 40-100x faster than permutation
- 🌲 XGBoost JSON model support
- 💡 LightGBM JSON and text format support
- 🔄 Parallel batch processing
- 🔗 Interaction values for pairwise feature interactions
LinearSHAP for linear models:
- 🎯 Exact closed-form solution: `SHAP[i] = coef[i] * (x[i] - E[X[i]])`
- ⚡ O(d) complexity, where d is the number of features
- 📈 Support for linear regression and logistic regression
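The closed form is simple enough to sketch standalone. This is only an illustration of the formula above, not the library's API; `linearSHAP` is a hypothetical helper name:

```go
package main

import "fmt"

// linearSHAP computes exact SHAP values for a linear model:
// SHAP[i] = coef[i] * (x[i] - mean[i]), where mean[i] = E[X[i]]
// is the feature mean over the background data.
func linearSHAP(coef, x, mean []float64) []float64 {
	shap := make([]float64, len(coef))
	for i := range coef {
		shap[i] = coef[i] * (x[i] - mean[i])
	}
	return shap
}

func main() {
	coef := []float64{1.0, 2.0} // model: f(x) = x1 + 2*x2
	mean := []float64{0.5, 0.5} // background feature means
	x := []float64{1.0, 2.0}    // instance to explain

	shap := linearSHAP(coef, x, mean)
	fmt.Println(shap) // [0.5 3]

	// Local accuracy: sum(SHAP) == f(x) - f(mean)
	sum := 0.0
	for _, v := range shap {
		sum += v
	}
	fmt.Printf("sum=%.2f, f(x)-f(mean)=%.2f\n",
		sum, (x[0]+2*x[1])-(mean[0]+2*mean[1])) // sum=3.50, f(x)-f(mean)=3.50
}
```

Because the model is linear, each feature's attribution depends only on its own deviation from the background mean, which is why the computation is O(d).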
KernelSHAP for model-agnostic explanations:
- 🔮 Model-agnostic black-box method
- ⚖️ Weighted linear regression on binary coalition masks
- 🧮 SHAP kernel weights: `(d-1) / (C(d,k) * k * (d-k))`
- ✅ Validated against the Python SHAP library
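The kernel weight formula can be sketched directly in plain Go. This is an illustration of the formula, not the library's internals; `binom` and `kernelWeight` are hypothetical helper names:

```go
package main

import "fmt"

// binom computes the binomial coefficient C(n, k).
func binom(n, k int) float64 {
	result := 1.0
	for i := 0; i < k; i++ {
		result = result * float64(n-i) / float64(i+1)
	}
	return result
}

// kernelWeight is the Shapley kernel weight for a coalition of
// size k drawn from d features: (d-1) / (C(d,k) * k * (d-k)).
// The empty and full coalitions (k = 0, k = d) get infinite weight
// and are handled as constraints in the regression instead.
func kernelWeight(d, k int) float64 {
	return float64(d-1) / (binom(d, k) * float64(k) * float64(d-k))
}

func main() {
	// Weights are symmetric in k and largest near the extremes,
	// where a coalition is most informative about a single feature.
	d := 5
	for k := 1; k < d; k++ {
		fmt.Printf("k=%d: weight=%.4f\n", k, kernelWeight(d, k))
	}
}
```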
ExactSHAP for brute-force exact Shapley values:
- 🎯 Mathematically exact values by enumerating all 2^n coalitions
- ⏱️ O(n * 2^n) complexity - only practical for ≤15 features
- 🔍 Useful for validating other SHAP implementations
- 📐 Reference implementation for small feature sets
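The brute-force enumeration is compact enough to sketch standalone. This is an illustration of the Shapley formula, not the library's API; `exactShapley` and its bitmask value function are hypothetical:

```go
package main

import (
	"fmt"
	"math/bits"
)

// exactShapley computes the Shapley value of each feature by
// enumerating all 2^n coalitions. v(S) is the value function on a
// coalition S encoded as a bitmask. Only practical for small n.
func exactShapley(n int, v func(mask uint) float64) []float64 {
	// Precompute factorials for the weight |S|! * (n-|S|-1)! / n!.
	fact := make([]float64, n+1)
	fact[0] = 1
	for i := 1; i <= n; i++ {
		fact[i] = fact[i-1] * float64(i)
	}
	shap := make([]float64, n)
	for i := 0; i < n; i++ {
		for mask := uint(0); mask < 1<<uint(n); mask++ {
			if mask&(1<<uint(i)) != 0 {
				continue // only coalitions that exclude feature i
			}
			s := bits.OnesCount(mask)
			w := fact[s] * fact[n-s-1] / fact[n]
			shap[i] += w * (v(mask|1<<uint(i)) - v(mask))
		}
	}
	return shap
}

func main() {
	// Toy additive model f(x) = x1 + 2*x2 with instance x = (1, 2)
	// and background 0: v(S) sums the terms present in S.
	v := func(mask uint) float64 {
		total := 0.0
		if mask&1 != 0 {
			total += 1.0 // contribution of x1
		}
		if mask&2 != 0 {
			total += 4.0 // contribution of 2*x2
		}
		return total
	}
	fmt.Println(exactShapley(2, v)) // [1 4]
}
```

The double loop makes the O(n * 2^n) cost explicit, which is why this approach is reserved for small feature sets and for validating the approximate explainers.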
DeepSHAP for neural network explanations:
- 🧠 Combines DeepLIFT with Shapley values for efficient neural network attribution
- 🔗 Works with ONNX models via `model/onnx` `ActivationSession`
- ⚡ Efficient backward propagation using the DeepLIFT rescale rule
- 📊 Supports Dense, ReLU, Sigmoid, Tanh, Softmax layers
Permutation SHAP with antithetic sampling:
- ✅ Guarantees local accuracy
- 🔄 Supports parallel computation
- 📉 Lower variance than pure Monte Carlo
Monte Carlo sampling SHAP:
- 🛠️ Simple implementation
- ⚡ Good for quick estimates
Feature masking strategies:
- 🎭 `IndependentMasker` - Marginal/independent masking using background samples
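Marginal masking can be sketched standalone. This is only an illustration of the idea; `maskIndependent` is a hypothetical name, not the library's `IndependentMasker`:

```go
package main

import "fmt"

// maskIndependent builds a hybrid input: features present in the
// coalition keep the instance's values, absent features are filled
// from a background sample (marginal/independent masking).
func maskIndependent(x, bg []float64, present []bool) []float64 {
	out := make([]float64, len(x))
	for i := range x {
		if present[i] {
			out[i] = x[i]
		} else {
			out[i] = bg[i]
		}
	}
	return out
}

func main() {
	x := []float64{1.0, 2.0, 3.0}
	bg := []float64{0.1, 0.2, 0.3}
	// Coalition {feature 0, feature 2}: feature 1 falls back to background.
	fmt.Println(maskIndependent(x, bg, []bool{true, false, true})) // [1 0.2 3]
}
```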
Background dataset management:
- 📂 Dataset loading and statistics
- 🎲 Random sampling and k-means summarization
The permutation explainer uses antithetic sampling for variance reduction:
- For each permutation sample:
  - ▶️ Forward pass: Start with the background, add features one by one
  - ◀️ Reverse pass: Start with the instance, remove features one by one
  - ⚖️ Average contributions from both passes
- Average over all permutation samples
This guarantees that SHAP values sum exactly to (prediction - base value).
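The forward/reverse scheme can be sketched for a black-box predict function. This is a standalone illustration under stated assumptions; `antitheticContrib` is a hypothetical name, not the library's API:

```go
package main

import (
	"fmt"
	"math/rand"
)

// antitheticContrib computes the forward and reverse passes for one
// permutation and averages them.
func antitheticContrib(predict func([]float64) float64, x, bg []float64, perm []int) []float64 {
	contrib := make([]float64, len(x))

	// Forward pass: start from the background and add features in
	// permutation order, crediting each feature with the change.
	cur := append([]float64(nil), bg...)
	prev := predict(cur)
	for _, i := range perm {
		cur[i] = x[i]
		next := predict(cur)
		contrib[i] += next - prev
		prev = next
	}

	// Reverse pass: start from the instance and remove features in
	// the same order, crediting each feature with the drop.
	cur = append([]float64(nil), x...)
	prev = predict(cur)
	for _, i := range perm {
		cur[i] = bg[i]
		next := predict(cur)
		contrib[i] += prev - next
		prev = next
	}

	// Average the antithetic pair; a full explainer then averages
	// over many random permutations.
	for i := range contrib {
		contrib[i] /= 2
	}
	return contrib
}

func main() {
	predict := func(in []float64) float64 { return in[0] + 2*in[1] }
	rng := rand.New(rand.NewSource(42))
	shap := antitheticContrib(predict, []float64{1, 2}, []float64{0, 0}, rng.Perm(2))
	fmt.Println(shap) // a linear model is exact for any single permutation: [1 4]
}
```

Each pass telescopes, so the contributions always sum to predict(x) - predict(bg), which is how local accuracy is guaranteed by construction.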
The sampling explainer uses simple Monte Carlo:
- 🔀 Generate random permutations
- 📊 For each permutation, compute marginal contributions
- ⚖️ Average over all samples
```go
exp, err := permutation.New(model, background,
	explainer.WithNumSamples(100),     // Number of permutation samples
	explainer.WithSeed(42),            // Random seed for reproducibility
	explainer.WithNumWorkers(4),       // Parallel workers
	explainer.WithFeatureNames(names), // Feature names
	explainer.WithModelID("my-model"), // Model identifier
)
```

```go
import "github.com/plexusone/shap-go/model/onnx"

// Initialize ONNX Runtime
onnx.InitializeRuntime("/path/to/libonnxruntime.so")
defer onnx.DestroyRuntime()

// Create session
session, err := onnx.NewSession(onnx.Config{
	ModelPath:   "model.onnx",
	InputName:   "float_input",
	OutputName:  "probabilities",
	NumFeatures: 10,
})
defer session.Close()

// Use with explainer
exp, err := permutation.New(session, background)
```

Every SHAP explanation should satisfy local accuracy:
```
sum(SHAP values) = prediction - base_value
```

You can verify this with:

```go
result := explanation.Verify(tolerance)
if !result.Valid {
	fmt.Printf("Local accuracy failed: difference = %f\n", result.Difference)
}
```

Performance benchmarks on Apple M1 Max (arm64):
| Configuration | Time/op | Allocs/op |
|---|---|---|
| 10 trees, depth 4, 10 features | 20μs | 372 |
| 100 trees, depth 4, 10 features | 194μs | 3,612 |
| 1000 trees, depth 4, 10 features | 1.9ms | 36,012 |
| Tree Depth | Time/op | Notes |
|---|---|---|
| Depth 3 | 39μs | Shallow trees |
| Depth 6 | 598μs | Typical production depth |
| Depth 10 | 13.2ms | Very deep trees |
| Method | Time/op | Type |
|---|---|---|
| TreeSHAP | 8.8μs | Exact |
| PermutationSHAP (10 samples) | 16μs | Approximate |
| PermutationSHAP (50 samples) | 77μs | Approximate |
| PermutationSHAP (100 samples) | 153μs | Approximate |
TreeSHAP is ~17x faster than PermutationSHAP with 100 samples while providing exact values.
| Model Size | Trees | Depth | Features | Time/op |
|---|---|---|---|---|
| Small | 50 | 4 | 10 | 106μs |
| Medium | 200 | 6 | 30 | 2.7ms |
| Large | 500 | 8 | 50 | 31.7ms |
| Workers | 100 instances | Speedup |
|---|---|---|
| 1 (sequential) | 10.2ms | 1.0x |
| 4 (parallel) | 8.0ms | 1.3x |
| 8 (parallel) | 8.1ms | 1.3x |
Run benchmarks with:
```bash
go test -bench=. -benchmem ./explainer/tree/...
```

The `examples/` directory contains working examples:
| Example | Description |
|---|---|
| `examples/linear` | PermutationSHAP with a simple linear model |
| `examples/linearshap` | LinearSHAP for linear/logistic regression |
| `examples/treeshap` | TreeSHAP with manually constructed tree ensembles |
| `examples/kernelshap` | KernelSHAP weighted linear regression explainer |
| `examples/sampling` | SamplingSHAP Monte Carlo approximation |
| `examples/onnx_basic` | ONNX model with KernelSHAP explanations |
| `examples/deepshap` | DeepSHAP for neural network explanations |
| `examples/batch` | Batch processing with parallel workers |
| `examples/visualization` | Generating chart data for visualizations |
| `examples/markdown_report` | Generate Markdown reports with SHAP explanations |
Run an example:
```bash
go run ./examples/linear
go run ./examples/linearshap
go run ./examples/treeshap
go run ./examples/kernelshap
go run ./examples/sampling
go run ./examples/batch
go run ./examples/visualization
go run ./examples/markdown_report

# ONNX examples (requires ONNX Runtime and model files)
cd examples/onnx_basic && python generate_model.py && go run main.go
cd examples/deepshap && python generate_model.py && go run main.go
```

MIT License