Proposal: Extend to general tensor einsum operations #21

@GiggleLiu

Summary

This issue proposes extending tropical-gemm from 2D matrix multiplication to general tensor einsum operations, enabling high-performance tropical tensor network contractions.

Motivation

The BPDecoderPlus project implements tropical tensor networks for MPE (Most Probable Explanation) inference. It currently uses a Python-based einsum implementation with rule-based dispatch, calling into tropical-gemm only for GEMM-like patterns and falling back to pure Python elsewhere. Extending tropical-gemm to support general einsum would:

  1. Enable SIMD/CUDA acceleration for all contraction patterns
  2. Provide a unified, high-performance backend for tropical tensor networks
  3. Support broader use cases (factor graphs, tensor networks, graphical models)

Proposed Interfaces

Core Data Structures

/// Index mapping for tensor contractions
pub struct IndexMap {
    pub input_vars: Vec<Vec<usize>>, // Variable indices for each input tensor
    pub output_vars: Vec<usize>,     // Output variable indices
    pub elim_vars: Vec<usize>,       // Variables to eliminate (max/min over)
}
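
To make the intended semantics concrete, here is a small Python sketch of how an IndexMap could be derived from einsum-style input/output indices: the eliminated variables are exactly those that appear in some input but not in the output. The function name and dict layout are illustrative, not part of the proposed Rust API.

```python
def build_index_map(input_vars, output_vars):
    # Collect variables in first-appearance order across all inputs.
    seen = []
    for ix in input_vars:
        for v in ix:
            if v not in seen:
                seen.append(v)
    # Variables to eliminate: present in some input, absent from the output.
    elim_vars = [v for v in seen if v not in output_vars]
    return {"input_vars": input_vars,
            "output_vars": list(output_vars),
            "elim_vars": elim_vars}

# Matrix product C[i,k] = max_j(A[i,j] + B[j,k]): i=0, j=1, k=2.
imap = build_index_map([[0, 1], [1, 2]], [0, 2])
```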

/// Backpointer for argmax tracking (MPE traceback)
pub struct Backpointer {
    pub elim_vars: Vec<usize>,
    pub elim_shape: Vec<usize>,
    pub out_vars: Vec<usize>,
    pub argmax_flat: Vec<i64>,  // Flattened argmax indices
}

Primary Operations

1. Binary Tropical Contraction

/// C[out_vars] = max/min_{elim_vars}(A[a_vars] + B[b_vars])
fn tropical_contract<S: Semiring>(
    a: &Tensor<S>,
    b: &Tensor<S>,
    index_map: &IndexMap,
    track_argmax: bool,
) -> (Tensor<S>, Option<Backpointer>);
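
A NumPy sketch of the proposed contraction semantics (max-plus semiring, argmax tracking omitted): align both operands to a joint axis order with size-1 axes for missing variables, add with broadcasting, then max-reduce over the eliminated axes. Function names are illustrative; a real implementation would avoid materializing the joint array.

```python
import numpy as np

def align(t, t_vars, joint):
    """Permute and reshape t so its axes follow the joint variable order,
    with size-1 axes for variables t does not carry."""
    order = sorted(range(len(t_vars)), key=lambda i: joint.index(t_vars[i]))
    t = np.transpose(t, order)
    shape = [1] * len(joint)
    for ax, v in enumerate(sorted(t_vars, key=joint.index)):
        shape[joint.index(v)] = t.shape[ax]
    return t.reshape(shape)

def tropical_contract(a, a_vars, b, b_vars, out_vars, elim_vars):
    """C[out_vars] = max_{elim_vars}(A + B) in the max-plus semiring."""
    joint = list(out_vars) + list(elim_vars)
    s = align(a, a_vars, joint) + align(b, b_vars, joint)  # broadcast add
    if not elim_vars:
        return s
    return s.max(axis=tuple(range(len(out_vars), len(joint))))

rng = np.random.default_rng(0)
A, B = rng.normal(size=(2, 3)), rng.normal(size=(3, 4))
C = tropical_contract(A, [0, 1], B, [1, 2], [0, 2], [1])
ref = (A[:, :, None] + B[None, :, :]).max(axis=1)  # brute-force check
```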

2. Unary Reduction

/// Reduce tensor over specified dimensions
fn tropical_reduce<S: Semiring>(
    tensor: &Tensor<S>,
    elim_dims: &[usize],
    track_argmax: bool,
) -> (Tensor<S>, Option<Backpointer>);
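
A sketch of the reduction with argmax tracking, in NumPy: move the eliminated axes to the back, flatten them, and record one flattened winner index per output cell (the Backpointer's argmax_flat). Per-axis winners can then be recovered with np.unravel_index against elim_shape.

```python
import numpy as np

def tropical_reduce(t, elim_dims):
    """Max-reduce over elim_dims; also return the flattened argmax
    over the eliminated axes, one index per output cell."""
    keep = [d for d in range(t.ndim) if d not in elim_dims]
    moved = np.transpose(t, keep + list(elim_dims))
    flat = moved.reshape(moved.shape[:len(keep)] + (-1,))
    return flat.max(axis=-1), flat.argmax(axis=-1)

T = np.array([[1.0, 5.0, 2.0],
              [4.0, 0.0, 3.0]])
vals, arg = tropical_reduce(T, [1])   # eliminate the column axis
```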

3. General Einsum Interface

/// Tropical einsum following OMEinsum-style design
/// Example: tropical_einsum([A, B], [(0,1), (1,2)], (0,2)) computes C[i,k] = max_j(A[i,j] + B[j,k])
fn tropical_einsum<S: Semiring>(
    tensors: &[&Tensor<S>],
    input_indices: &[&[usize]],
    output_indices: &[usize],
    track_argmax: bool,
) -> (Tensor<S>, Option<Backpointer>);
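
For testing a general einsum implementation, a brute-force oracle is useful: enumerate every assignment of all variables and keep the best max-plus score per output cell. This is exponential in the number of variables and only meant as a correctness reference, not an implementation strategy.

```python
import itertools
import numpy as np

def tropical_einsum_ref(tensors, ixs, iy, sizes):
    """Brute-force max-plus einsum: enumerate all variable assignments.
    sizes maps each variable to its dimension size."""
    allvars = sorted(sizes)
    out = np.full([sizes[v] for v in iy], -np.inf)
    for assign in itertools.product(*(range(sizes[v]) for v in allvars)):
        env = dict(zip(allvars, assign))
        # Tropical "product" is the sum of the selected entries.
        score = sum(t[tuple(env[v] for v in ix)]
                    for t, ix in zip(tensors, ixs))
        pos = tuple(env[v] for v in iy)
        out[pos] = max(out[pos], score)
    return out

rng = np.random.default_rng(1)
A, B = rng.normal(size=(2, 3)), rng.normal(size=(3, 4))
C = tropical_einsum_ref([A, B], [(0, 1), (1, 2)], (0, 2),
                        {0: 2, 1: 3, 2: 4})
```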

Supporting Operations

// Dimension manipulation
fn permute_dims<S>(tensor: &Tensor<S>, perm: &[usize]) -> Tensor<S>;
fn align_tensor<S>(tensor: &Tensor<S>, src_vars: &[usize], dst_vars: &[usize]) -> Tensor<S>;

// Diagonal operations
fn extract_diagonal<S>(tensor: &Tensor<S>, dim1: usize, dim2: usize) -> Tensor<S>;
fn tropical_trace<S: Semiring>(tensor: &Tensor<S>, dim1: usize, dim2: usize) -> Tensor<S>;
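
The intended behavior of the diagonal operations, sketched in NumPy: extract_diagonal corresponds to the repeated-index case of einsum (e.g. "ii->i"), and the tropical trace is the max over that diagonal rather than a sum.

```python
import numpy as np

T = np.arange(9.0).reshape(3, 3)
diag = np.diagonal(T, axis1=0, axis2=1)   # extract_diagonal analogue
trace = diag.max()                        # tropical (max-plus) trace
```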

// Backpointer utilities
fn argmax_trace(bp: &Backpointer, assignment: &HashMap<usize, usize>) -> HashMap<usize, usize>;
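
A sketch of the traceback step, using a plain dict in place of the Backpointer struct: given an assignment of the output variables, look up the flattened winner index and unravel it into values for the eliminated variables. The backpointer below is the one produced by reducing T[i, j] = [[1, 5, 2], [4, 0, 3]] over j.

```python
import numpy as np

def argmax_trace(bp, assignment):
    """Extend an assignment of the output variables with the winning
    values of the eliminated variables recorded in the backpointer."""
    pos = tuple(assignment[v] for v in bp["out_vars"])
    flat_idx = bp["argmax_flat"][pos]
    elim_idx = np.unravel_index(flat_idx, bp["elim_shape"])
    full = dict(assignment)
    full.update({v: int(i) for v, i in zip(bp["elim_vars"], elim_idx)})
    return full

bp = {"elim_vars": [1], "elim_shape": [3], "out_vars": [0],
      "argmax_flat": np.array([1, 0])}
full = argmax_trace(bp, {0: 0})   # row i=0 wins at j=1
```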

Python Bindings

import tropical_gemm as tg

# NumPy interface
result, bp = tg.tropical_einsum(
    [A, B],
    ixs=[(0, 1), (1, 2)],
    iy=(0, 2),
    semiring="maxplus",
    track_argmax=True
)

# PyTorch interface (with autograd support)
result, bp = tg.torch.tropical_einsum(A, B, ixs=[(0,1), (1,2)], iy=(0,2))

Implementation Strategy

Phase 1: Core Tensor Support

  • Define Tensor<S> type with arbitrary dimensions
  • Implement dimension permutation and alignment
  • Add unary reduction operations

Phase 2: Binary Contraction

  • Implement general binary contraction (extend current GEMM)
  • Support batch dimensions
  • Backpointer tracking for all patterns

Phase 3: Rule-Based Dispatch

  • Pattern matching for optimized paths (GEMM, outer product, etc.)
  • Fallback to general implementation
  • CUDA kernels for common patterns
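
The GEMM pattern check in the dispatch layer could look roughly like this (a simplified sketch; it ignores batch dimensions, where a shared variable kept in the output would select a batched-GEMM path instead):

```python
def is_gemm_like(a_vars, b_vars, out_vars):
    """A binary contraction maps directly to tropical GEMM when neither
    tensor repeats an index and every shared variable is eliminated, so
    C = max_shared(A + B) is a matrix product after permuting the free
    indices of each side together."""
    if len(set(a_vars)) != len(a_vars) or len(set(b_vars)) != len(b_vars):
        return False   # repeated index within one tensor: diagonal first
    shared = set(a_vars) & set(b_vars)
    free = (set(a_vars) | set(b_vars)) - shared
    return set(out_vars) == free
```

For example, (0,1),(1,2)->(0,2) is GEMM-shaped, while (0,1),(1,2)->(0,1,2) keeps the shared variable and must take another path.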

Phase 4: Python Integration

  • Extend Python bindings for tensor operations
  • PyTorch autograd integration
  • NumPy ufunc compatibility

Reference Implementation

The current Python implementation in BPDecoderPlus can serve as a reference:

  • tropical_in_new/src/tropical_einsum.py - Rule-based dispatch
  • tropical_in_new/src/primitives.py - Core operations and backpointers
  • tropical_in_new/src/contraction.py - Tree-based contraction

Questions

  1. Should we support n-ary contractions directly, or always decompose to binary?
  2. What CUDA kernel strategies work best for general tensor contractions?
  3. Should the tensor type own its data or support views/slices?
