## Summary

Propose extending `tropical-gemm` from matrix multiplication (2D) to general tensor einsum operations, enabling high-performance tropical tensor network contractions.
## Motivation

The BPDecoderPlus project implements tropical tensor networks for MPE (Most Probable Explanation) inference. Currently, it uses a Python-based einsum implementation with rule-based dispatch, falling back to `tropical-gemm` only for GEMM-like patterns. Extending `tropical-gemm` to support general einsum would:
- Enable SIMD/CUDA acceleration for all contraction patterns
- Provide a unified, high-performance backend for tropical tensor networks
- Support broader use cases (factor graphs, tensor networks, graphical models)
## Proposed Interfaces

### Core Data Structures
```rust
/// Index mapping for tensor contractions
pub struct IndexMap {
    pub input_vars: Vec<Vec<usize>>, // Variable indices for each input tensor
    pub output_vars: Vec<usize>,     // Output variable indices
    pub elim_vars: Vec<usize>,       // Variables to eliminate (max/min over)
}

/// Backpointer for argmax tracking (MPE traceback)
pub struct Backpointer {
    pub elim_vars: Vec<usize>,
    pub elim_shape: Vec<usize>,
    pub out_vars: Vec<usize>,
    pub argmax_flat: Vec<i64>, // Flattened argmax indices
}
```
### Primary Operations

#### 1. Binary Tropical Contraction
```rust
/// C[out_vars] = max/min_{elim_vars}(A[a_vars] + B[b_vars])
fn tropical_contract<S: Semiring>(
    a: &Tensor<S>,
    b: &Tensor<S>,
    index_map: &IndexMap,
    track_argmax: bool,
) -> (Tensor<S>, Option<Backpointer>);
```
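For reference, the intended max-plus semantics of a binary contraction with a single eliminated variable can be sketched in NumPy. This is a naive specification of the math (including the argmax backpointer), not the proposed SIMD/CUDA implementation; `maxplus_contract` is a hypothetical helper name.

```python
import numpy as np

# Reference semantics: C[i, k] = max_j (A[i, j] + B[j, k]),
# optionally recording argmax over the eliminated variable j.
def maxplus_contract(a, b, track_argmax=False):
    # Broadcast to shape (i, j, k); tropical "product" is ordinary +.
    s = a[:, :, None] + b[None, :, :]
    c = s.max(axis=1)  # tropical "sum" = max over j
    bp = s.argmax(axis=1) if track_argmax else None  # backpointer over j
    return c, bp

A = np.array([[0.0, -1.0], [2.0, 0.5]])
B = np.array([[1.0, 0.0], [0.0, 3.0]])
C, bp = maxplus_contract(A, B, track_argmax=True)
```

Here `C[0, 1] = max(0 + 0, -1 + 3) = 2`, achieved at `j = 1`, which is what `bp[0, 1]` records.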
#### 2. Unary Reduction
```rust
/// Reduce tensor over specified dimensions
fn tropical_reduce<S: Semiring>(
    tensor: &Tensor<S>,
    elim_dims: &[usize],
    track_argmax: bool,
) -> (Tensor<S>, Option<Backpointer>);
```
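The reduction semantics, including the flat argmax index that would populate `Backpointer.argmax_flat`, can be sketched the same way (naive NumPy reference, hypothetical helper name):

```python
import numpy as np

# Reference semantics for a max-plus reduce: eliminate the given
# dimensions by max, optionally recording a flat argmax index over the
# flattened eliminated block (mirroring argmax_flat).
def maxplus_reduce(t, elim_dims, track_argmax=False):
    keep = [d for d in range(t.ndim) if d not in elim_dims]
    # Move eliminated dims to the back and flatten them into one axis.
    moved = np.transpose(t, keep + list(elim_dims))
    flat = moved.reshape(moved.shape[:len(keep)] + (-1,))
    r = flat.max(axis=-1)
    bp = flat.argmax(axis=-1) if track_argmax else None
    return r, bp

T = np.arange(24, dtype=float).reshape(2, 3, 4)
R, bp = maxplus_reduce(T, [1, 2], track_argmax=True)
# For arange data the max of each T[i] sits at (j, k) = (2, 3),
# i.e. flat index 2 * 4 + 3 = 11.
```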
#### 3. General Einsum Interface
```rust
/// Tropical einsum following OMEinsum-style design.
/// Example: tropical_einsum([A, B], [(0,1), (1,2)], (0,2))
/// computes C[i,k] = max_j(A[i,j] + B[j,k]).
fn tropical_einsum<S: Semiring>(
    tensors: &[&Tensor<S>],
    input_indices: &[&[usize]],
    output_indices: &[usize],
    track_argmax: bool,
) -> (Tensor<S>, Option<Backpointer>);
```
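A naive NumPy reference for the general interface: broadcast every input into the joint space of all variables, add (tropical product), then max over eliminated variables. This sketch assumes distinct variables per tensor (no diagonals) and omits argmax tracking; it is a semantic specification only, not the proposed kernel strategy.

```python
import numpy as np

def maxplus_einsum(tensors, input_indices, output_indices):
    all_vars = sorted({v for ix in input_indices for v in ix})
    sizes = {v: t.shape[ix.index(v)]
             for t, ix in zip(tensors, input_indices) for v in ix}

    def align(t, ix):
        # Reorder t's axes by global variable order, then insert size-1
        # axes for variables this tensor does not carry (broadcasting).
        order = sorted(range(len(ix)), key=lambda d: all_vars.index(ix[d]))
        shape = [sizes[v] if v in ix else 1 for v in all_vars]
        return np.transpose(t, order).reshape(shape)

    joint = align(tensors[0], input_indices[0])
    for t, ix in zip(tensors[1:], input_indices[1:]):
        joint = joint + align(t, ix)      # tropical product = ordinary +

    elim = tuple(i for i, v in enumerate(all_vars) if v not in output_indices)
    out = joint.max(axis=elim)            # tropical sum = max
    kept = [v for v in all_vars if v in output_indices]
    return np.transpose(out, [kept.index(v) for v in output_indices])

# The doc's example: C[i, k] = max_j (A[i, j] + B[j, k]).
A = np.array([[1.0, 0.0, -1.0], [0.0, 2.0, 1.0]])
B = np.array([[0.0, 1.0], [3.0, 0.0], [2.0, 2.0]])
C = maxplus_einsum([A, B], [(0, 1), (1, 2)], (0, 2))
```

Materializing the full joint space is exponential in the number of variables, which is exactly why the proposal dispatches to specialized kernels instead.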
### Supporting Operations
```rust
use std::collections::HashMap;

// Dimension manipulation
fn permute_dims<S>(tensor: &Tensor<S>, perm: &[usize]) -> Tensor<S>;
fn align_tensor<S>(tensor: &Tensor<S>, src_vars: &[usize], dst_vars: &[usize]) -> Tensor<S>;

// Diagonal operations
fn extract_diagonal<S>(tensor: &Tensor<S>, dim1: usize, dim2: usize) -> Tensor<S>;
fn tropical_trace<S: Semiring>(tensor: &Tensor<S>, dim1: usize, dim2: usize) -> Tensor<S>;

// Backpointer utilities
fn argmax_trace(bp: &Backpointer, assignment: &HashMap<usize, usize>) -> HashMap<usize, usize>;
```
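How `argmax_trace` could work during MPE traceback can be sketched in Python: given an assignment of the output variables, look up the recorded flat argmax index and unravel it into an assignment of the eliminated variables. `trace_backpointer` and `bp_table` (the `argmax_flat` vector reshaped to the output shape) are hypothetical names for illustration.

```python
import numpy as np

def trace_backpointer(bp_table, elim_vars, elim_shape, out_vars, assignment):
    # Index the stored argmax table by the current output assignment.
    out_idx = tuple(assignment[v] for v in out_vars)
    flat = int(bp_table[out_idx])
    # Unravel the flat index into per-variable coordinates.
    coords = np.unravel_index(flat, elim_shape)
    return {v: int(c) for v, c in zip(elim_vars, coords)}

# Example: eliminated vars 7 and 8 with shape (2, 3); output vars 0 and 2.
bp_table = np.array([[5, 2], [0, 3]])
best = trace_backpointer(bp_table, [7, 8], (2, 3), [0, 2], {0: 0, 2: 0})
# flat index 5 in shape (2, 3) unravels to (1, 2), so var 7 -> 1, var 8 -> 2.
```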
### Python Bindings
```python
import tropical_gemm as tg

# NumPy interface
result, bp = tg.tropical_einsum(
    [A, B],
    ixs=[(0, 1), (1, 2)],
    iy=(0, 2),
    semiring="maxplus",
    track_argmax=True,
)

# PyTorch interface (with autograd support)
result, bp = tg.torch.tropical_einsum(A, B, ixs=[(0, 1), (1, 2)], iy=(0, 2))
```
## Implementation Strategy

- **Phase 1: Core Tensor Support**: `Tensor<S>` type with arbitrary dimensions
- **Phase 2: Binary Contraction**
- **Phase 3: Rule-Based Dispatch**
- **Phase 4: Python Integration**
## Reference Implementation

The current Python implementation in BPDecoderPlus can serve as a reference:

- `tropical_in_new/src/tropical_einsum.py` - Rule-based dispatch
- `tropical_in_new/src/primitives.py` - Core operations and backpointers
- `tropical_in_new/src/contraction.py` - Tree-based contraction
## Questions
- Should we support n-ary contractions directly, or always decompose to binary?
- What CUDA kernel strategies work best for general tensor contractions?
- Should the tensor type own its data or support views/slices?
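As context for the first question, any n-ary tropical contraction can be decomposed into binary steps because addition distributes over max; the trade-off is the size of the intermediates. A small max-plus illustration (naive sketch, hypothetical helper name):

```python
import numpy as np

# C[i, l] = max over (j, k) of (A[i, j] + B[j, k] + D[k, l]),
# computed as one n-ary reduction and as two binary contractions.
def maxplus_mm(x, y):
    # Pairwise max-plus "matrix multiply": max_j (x[i, j] + y[j, k]).
    return (x[:, :, None] + y[None, :, :]).max(axis=1)

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3))
B = rng.normal(size=(3, 4))
D = rng.normal(size=(4, 2))

# n-ary: materialize the full (i, j, k, l) joint, reduce once.
c_nary = (A[:, :, None, None] + B[None, :, :, None]
          + D[None, None, :, :]).max(axis=(1, 2))

# binary: two pairwise contractions with a small intermediate.
c_binary = maxplus_mm(maxplus_mm(A, B), D)
```

Both routes agree, but the binary route never materializes anything larger than a matrix, which is the usual argument for decomposing to binary and reusing GEMM-shaped kernels.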
## Related Work