Merged
20 changes: 10 additions & 10 deletions .agent/skills/bae-compute-graph/SKILL.md
@@ -1,13 +1,13 @@
---
name: bae-compute-graph
description: Use when defining or modifying BAE compute graphs, sparse Jacobian traces, bundle adjustment or pose graph optimization problems, or any code using `TrackingTensor`, `pypose.Parameter`, `map_transform`, or `bae.autograd.graph.jacobian`.
description: Use when defining or modifying BAE compute graphs, sparse Jacobian traces, bundle adjustment or pose graph optimization problems, or any code using `pp.Parameter`, `psjac`, or `bae.autograd.graph.jacobian`.
---

# BAE Compute Graph

## Core mental model
- The forward pass records a lightweight operation trace on tensors.
- `TrackingTensor` preserves PyPose `LieTensor` type information, so tracked `pp.SE3` values stay LieTensor-aware through `nn.Parameter(...)`, tensor indexing, LieTensor operations, and concatenation, `torch.cat(..., dim=0)`.
- The forward pass records a lightweight operation trace on tensors; the trace is enabled by wrapping an optimizable parameter as `pp.Parameter(..., sjac=True)`.
- `pp.Parameter(..., sjac=True)` preserves PyPose `LieTensor` type information, so tracked `pp.SE3` values stay LieTensor-aware through tensor indexing, LieTensor operations, and concatenation with `torch.cat(..., dim=0)`.
- The sparse autograd logic classifies operations mainly by their effect on the Jacobian:
- `index`: determines sparse block-column layout.
- `map`: computes Jacobian block values.
@@ -19,12 +19,12 @@ description: Use when defining or modifying BAE compute graphs, sparse Jacobian
- Internally, intermediate Jacobians may be stored as `(indices, values)` before they are materialized as sparse BSR tensors at the leaves. `indices=None` means the current trace still has identity column layout and only carries block values.
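As a rough illustration of how an `index` op fixes the block-column layout — a conceptual sketch in plain Python, not the actual bae data structure:

```python
# Indexing a parameter table with per-observation indices means residual
# block i depends only on parameter block obs_index[i], so the Jacobian is
# block-sparse with exactly one block per block-row.

def block_layout(obs_index):
    """Map each residual block-row to the parameter block-column it touches."""
    return [(row, col) for row, col in enumerate(obs_index)]

# 5 observations drawing on a table of 3 parameter blocks
layout = block_layout([0, 2, 1, 0, 2])
print(layout)  # [(0, 0), (1, 2), (2, 1), (3, 0), (4, 2)]
```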

## Authoring recipe
1. Wrap each optimizable state as `nn.Parameter(TrackingTensor(data))`.
2. If `data` is already a true PyPose `LieTensor` such as `pp.SE3(nodes)`, keep it that way. The tracked parameter wrapped by `TrackingTensor` will stay LieTensor-aware, and its optimizer step shape is inferred automatically from `parameter_update_shape(...)`.
1. Wrap each optimizable state as `pp.Parameter(data, sjac=True)`.
2. If `data` is already a true PyPose `LieTensor` such as `pp.SE3(nodes)`, keep it that way. The `pp.Parameter` with `sjac=True` will stay LieTensor-aware, and its optimizer step shape is inferred automatically from `parameter_update_shape(...)`.
3. Using `param.trim_SE3_grad = True` is not recommended. It exists only for mixed ambient tensor layouts, such as a stored 7D quaternion pose or a pose-plus-extra-parameters tensor whose SE(3) portion should optimize on a 6D tangent space. Treat it as an escape hatch for legacy code or special cases, not a general pattern. When using `trim_SE3_grad`, the first 7 entries of the parameter tensor must encode SE(3) for compatibility.
4. Define each custom per-factor residual block with `@map_transform`.
4. Define each custom per-factor residual block with `@psjac` (imported from `pypose.autograd.function`).
5. In `forward()`, gather participating states by tensor indexing such as `self.pose[camera_idx]` or `self.nodes[edges[..., 0]]`. Indexed tracked LieTensor values preserve their LieTensor behavior.
6. Combine factor groups or rebuilt state tables with `torch.cat(..., dim=0)` if needed. Other concatenation mode is only supported inside `@map_transform`.
6. Combine factor groups or rebuilt state tables with `torch.cat(..., dim=0)` if needed. Other concatenation modes are only supported inside `@psjac`.
7. Return the residual tensor. `LM.step()` will call `bae.autograd.graph.jacobian(...)` on it to automatically derive the sparse Jacobian.

## What each tracked op means
@@ -37,7 +37,7 @@ description: Use when defining or modifying BAE compute graphs, sparse Jacobian
- When the indexed source is a tracked PyPose `LieTensor`, the indexed result remains LieTensor-aware, so downstream code can keep using native LieTensor methods such as `.Inv()`, `.Log()`, or `.Act(...)`.

### `map`
- Use `@map_transform` for a vectorized residual function that maps indexed inputs to per-factor residuals.
- Use `@psjac` for a vectorized residual function that maps indexed inputs to per-factor residuals.
- Simple tracked arithmetic such as `+`, `-`, and `*` is also recorded as a `map` op through `WHITELISTED_MAPS`, so expressions like `pred - obs` can stay inline.
- The backward pass computes local Jacobian blocks with `torch.vmap(jacrev(func, argnums=...))`.
- Those local blocks are then chained with any upstream Jacobian already attached to the output trace.
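The `vmap(jacrev(...))` mechanism can be seen in isolation with plain PyTorch — this is a sketch of the underlying idea only, with a made-up projection residual, not the bae implementation:

```python
import torch
from torch.func import jacrev, vmap

# One factor's residual: project a 3D point and compare to a 2D observation.
def residual(point, obs):
    return point[:2] / point[2] - obs

# Five factors, fixed data so the Jacobian blocks are deterministic.
points = torch.arange(1.0, 16.0, dtype=torch.float64).reshape(5, 3)
obs = torch.zeros(5, 2, dtype=torch.float64)

# jacrev differentiates one factor w.r.t. its point; vmap batches it over all
# factors, yielding one (2, 3) local Jacobian block per factor, stacked (5, 2, 3).
blocks = vmap(jacrev(residual, argnums=0))(points, obs)
print(blocks.shape)  # torch.Size([5, 2, 3])
```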
@@ -54,8 +54,8 @@ description: Use when defining or modifying BAE compute graphs, sparse Jacobian

## Hard constraints and gotchas
- The final residual trace must end in one of: `map`, `index`, or `cat(dim=0)`.
- Automatic indexing trace capture only happens when `TrackingTensor.__getitem__` receives a tensor index through PyTorch dispatch. Plain Python slicing is not the main supported sparse-layout path.
- `map_transform` functions must be compatible with `jacrev` and effectively batch-vectorized for `vmap`.
- Automatic indexing trace capture happens when a `pp.Parameter(..., sjac=True)` is indexed with a tensor index. Plain Python slicing is not the main supported sparse-layout path.
- `psjac` functions must be compatible with `jacrev` and effectively batch-vectorized for `vmap`.
- Only `torch.cat(..., dim=0)` is supported.
- If a parameter never appears in observations, its block-columns will be empty. The authors explicitly treat this as a structural error: the solver will fail on empty block-columns.
- Jacobian column counts and optimizer step views follow `parameter_update_shape(param)`:
29 changes: 16 additions & 13 deletions .agent/skills/bae-compute-graph/references/bal.md
@@ -10,11 +10,14 @@
### Parameter setup

```python
import pypose as pp
from pypose.autograd.function import psjac

class Reproj(nn.Module):
def __init__(self, camera_params, points_3d):
super().__init__()
self.pose = nn.Parameter(TrackingTensor(camera_params))
self.points_3d = nn.Parameter(TrackingTensor(points_3d))
self.pose = pp.Parameter(camera_params, sjac=True)
self.points_3d = pp.Parameter(points_3d, sjac=True)
self.pose.trim_SE3_grad = True
```

@@ -25,9 +28,9 @@ class Reproj(nn.Module):
### Projection map

```python
@map_transform
@psjac
def project(points, camera_params):
points_proj = rotate_quat(points, camera_params[..., :7])
points_proj = pp.SE3(camera_params[..., :7]).Act(points)
points_proj = -points_proj[..., :2] / points_proj[..., 2].unsqueeze(-1)

f = camera_params[..., -3].unsqueeze(-1)
@@ -71,9 +74,9 @@ Use this when the first camera pose is fixed and should not appear in the optimi
class ReprojFixedFirstCamera(nn.Module):
def __init__(self, camera_se3_rest, camera_intrinsics, points_3d):
super().__init__()
self.pose_rest = nn.Parameter(TrackingTensor(camera_se3_rest))
self.intrinsics = nn.Parameter(TrackingTensor(camera_intrinsics))
self.points_3d = nn.Parameter(TrackingTensor(points_3d))
self.pose_rest = pp.Parameter(camera_se3_rest, sjac=True)
self.intrinsics = pp.Parameter(camera_intrinsics, sjac=True)
self.points_3d = pp.Parameter(points_3d, sjac=True)
self.pose_rest.trim_SE3_grad = True
```

@@ -84,7 +87,7 @@ class ReprojFixedFirstCamera(nn.Module):
### Projection map with split pose/intrinsics

```python
@map_transform
@psjac
def project_with_se3_and_intrinsics(points, camera_se3, intrinsics):
points_proj = pp.SE3(camera_se3).Act(points)
points_proj = -points_proj[..., :2] / points_proj[..., 2].unsqueeze(-1)
@@ -123,7 +126,7 @@ Use this when one point subset is optimized directly and another subset is produ
### Extra map

```python
@map_transform
@psjac
def transform_points(points, se3_params):
return pp.SE3(se3_params).Act(points)
```
@@ -134,10 +137,10 @@ def transform_points(points, se3_params):
class ReprojCat(nn.Module):
def __init__(self, camera_params, points_b, points_c, se3_c):
super().__init__()
self.pose = nn.Parameter(TrackingTensor(camera_params))
self.points_b = nn.Parameter(TrackingTensor(points_b))
self.points_c = nn.Parameter(TrackingTensor(points_c))
self.se3_c = nn.Parameter(TrackingTensor(se3_c))
self.pose = pp.Parameter(camera_params, sjac=True)
self.points_b = pp.Parameter(points_b, sjac=True)
self.points_c = pp.Parameter(points_c, sjac=True)
self.se3_c = pp.Parameter(se3_c, sjac=True)
self.pose.trim_SE3_grad = True
self.se3_c.trim_SE3_grad = True

10 changes: 6 additions & 4 deletions .agent/skills/bae-compute-graph/references/pgo.md
@@ -13,20 +13,22 @@ This file is intentionally self-contained. Use it as the canonical recipe for a
## Parameter setup

```python
import pypose as pp
from pypose.autograd.function import psjac

class PoseGraph(nn.Module):
def __init__(self, nodes):
super().__init__()
self.nodes = nn.Parameter(TrackingTensor(nodes))
self.nodes.trim_SE3_grad = True
self.nodes = pp.Parameter(nodes, sjac=True)
Contributor comment (medium):

The example should demonstrate initializing the parameter with a `pp.LieTensor` (like `pp.SE3`) to ensure the 6D optimization behavior described in the notes below is actually triggered.

Suggested change:
- self.nodes = pp.Parameter(nodes, sjac=True)
+ self.nodes = pp.Parameter(pp.SE3(nodes), sjac=True)

```

- `self.nodes` is typically shape `(num_nodes, 7)` in quaternion SE(3) storage.
- `trim_SE3_grad = True` converts each stored pose block into a 6D optimized tangent-space block.
- `pp.Parameter(..., sjac=True)` notifies `bae` to produce a sparse Jacobian. Wrap the original batched tensor before performing any operation; if you pass a regular tensor or LieTensor instead, the sparse backend cannot recover the Jacobian for that tensor.

## Edge residual map

```python
@map_transform
@psjac
def edge_residual(poses, node1, node2, infos):
residual = (pp.SE3(poses).Inv() @ pp.SE3(node1).Inv() @ pp.SE3(node2)).Log().tensor()
residual = infos @ residual[..., None]
39 changes: 32 additions & 7 deletions README.md
@@ -153,28 +153,53 @@ Bundle Adjustment optimizes camera poses and 3D point positions to minimize repr
```python
import torch
import pypose as pp
from pypose.autograd.function import psjac
from datapipes.bal_loader import get_problem
from ba_helpers import ReprojNonBatched, least_square_error
from bae.sparse.py_ops import *
from bae.sparse.solve import *
from bae.optim import LM
from bae.utils.pysolvers import PCG


class Reproj(torch.nn.Module):
def __init__(self, camera_params, points):
super().__init__()
self.pose = pp.Parameter(camera_params, sjac=True)
self.points = pp.Parameter(points, sjac=True)
self.pose.trim_SE3_grad = True

# Define the projection residual with structured Jacobian support
@psjac
def project(points, camera_params):
projection = pp.SE3(camera_params[..., :7]).Act(points)
projection = -projection[..., :2] / projection[..., [2]]

f = camera_params[..., [-3]]
k1 = camera_params[..., [-2]]
k2 = camera_params[..., [-1]]

n = torch.sum(projection**2, axis=-1, keepdim=True)
r = 1 + k1 * n + k2 * n**2
return projection * r * f

def forward(self, observes, cidx, pidx):
points_proj = Reproj.project(self.points[pidx], self.pose[cidx])
return points_proj - observes


# Load a problem from the BAL dataset
dataset = get_problem("problem-49-7776-pre", "ladybug", use_quat=True)
dataset = {k: v.to('cuda') for k, v in dataset.items() if isinstance(v, torch.Tensor)}

# Prepare input for the optimization
input = {
"points_2d": dataset['points_2d'],
"camera_indices": dataset['camera_index_of_observations'],
"point_indices": dataset['point_index_of_observations']
"observes": dataset['points_2d'],
"cidx": dataset['camera_index_of_observations'],
"pidx": dataset['point_index_of_observations'],
}

# Initialize model with camera parameters and 3D points
model = Reproj(
dataset['camera_params'].clone(),
dataset['points_3d'].clone()
dataset['points_3d'].clone(),
).to('cuda')

# Configure optimizer
11 changes: 5 additions & 6 deletions ba_example.py
@@ -3,11 +3,10 @@
import pypose as pp
import torch
import torch.nn as nn
from pypose.autograd.function import psjac

from datapipes.bal_loader import get_problem
from bae.autograd.function import TrackingTensor, map_transform
from bae.optim import LM
from bae.utils.ba import rotate_quat
from bae.utils.pysolvers import PCG

TARGET_DATASET = "trafalgar"
@@ -23,9 +22,9 @@
NUM_CAMERA_PARAMS = 10 if OPTIMIZE_INTRINSICS else 7


@map_transform
@psjac
def project(points, camera_params):
projection = rotate_quat(points, camera_params[..., :7])
projection = pp.SE3(camera_params[..., :7]).Act(points)
projection = -projection[..., :2] / projection[..., [2]]

f = camera_params[..., [-3]]
@@ -40,8 +39,8 @@ def project(points, camera_params):
class Residual(nn.Module):
def __init__(self, camera_params, points):
super().__init__()
self.pose = nn.Parameter(TrackingTensor(camera_params))
self.points = nn.Parameter(TrackingTensor(points))
self.pose = pp.Parameter(camera_params, sjac=True)
self.points = pp.Parameter(points, sjac=True)
self.pose.trim_SE3_grad = True

def forward(self, observes, cidx, pidx):
4 changes: 0 additions & 4 deletions bae/utils/ba.py
@@ -14,10 +14,6 @@ def rotate_euler(points, rot_vecs):
sin_theta = torch.sin(theta)
return cos_theta * points + sin_theta * torch.cross(v, points, dim=-1) + dot * (1 - cos_theta) * v

def rotate_quat(points, rot_vecs):
rot_vecs = pp.SE3(rot_vecs)
return rot_vecs.Act(points)

# inverse quat
def openGL2gtsam(pose):
R = pose.rotation()
53 changes: 5 additions & 48 deletions pgo.py
@@ -6,7 +6,7 @@
import argparse
import pypose as pp
from torch import nn
from bae.autograd.function import TrackingTensor, map_transform
from pypose.autograd.function import psjac

from bae.utils.pgo_dataset import G2OPGO
from bae.utils.pgo import plot_and_save, render_frame, save_gif
@@ -25,47 +25,9 @@

torch.set_printoptions(precision=6)

def diff(residual=None, jacobian=None):
num_factors = residual.shape[0] if residual is not None else jacobian.shape[0]
import numpy as np
with open('data.s', 'r') as f:
ceres_residuals = []
ceres_jacobians = []
for i in range(num_factors):
# read from 'data.s'
data = f.readline()
discard_left = data.split('[')[1:]
discard_right = [x.split(']')[0] for x in discard_left]
discard_semi = [x.split(';') for x in discard_right]
# convert to float
ceres_residual = [float(y[0]) for y in discard_semi]
ceres_residuals.append(ceres_residual)

ceres_jacobian = [np.fromstring(y[1], sep=',') for y in discard_semi]
ceres_jacobians.append(ceres_jacobian)
ceres_residuals = torch.tensor(ceres_residuals)
ceres_jacobians = torch.tensor(ceres_jacobians)
if residual is not None:
ceres_residuals = ceres_residuals - residual
# absolute difference
print(ceres_residuals.norm(dim=-1).mean())
# relative difference
print(((ceres_residuals.norm(dim=-1) / residual.norm(dim=-1)))[1:].mean())
if jacobian is not None:
ceres_jacobians = ceres_jacobians - jacobian
# absolute difference

def write_ceres_txt(nodes, filename='data.s'):
with open(filename, 'w') as f:
# ID x y z q_x q_y q_z q_w
for i in range(nodes.shape[0]):
node = nodes[i]
f.write(f'{i} {node[0].item()} {node[1].item()} {node[2].item()} {node[3].item()} {node[4].item()} {node[5].item()} {node[6].item()}\n')

def _pose_graph_residual(poses, node1, node2, infos):
if isinstance(infos, TrackingTensor):
infos = infos.tensor()

@psjac
def pose_graph_residual(poses, node1, node2, infos):
pose_ab_est = node1.Inv() @ node2
r_p = pose_ab_est.translation() - poses.translation()
# Match Ceres pose_graph_3d: 2 * vec(q_meas * q_est^{-1}).
@@ -77,15 +39,11 @@ def _pose_graph_residual(poses, node1, node2, infos):
return residual[..., 0]
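The rotation part of that Ceres-style residual, `2 * vec(q_meas * q_est^{-1})`, can be checked in isolation; a minimal pure-Python sketch (assuming `(x, y, z, w)` quaternion storage, as PyPose uses), independent of `pypose`:

```python
# Hamilton product for unit quaternions stored as (x, y, z, w).
def quat_mul(a, b):
    ax, ay, az, aw = a
    bx, by, bz, bw = b
    return (aw * bx + ax * bw + ay * bz - az * by,
            aw * by - ax * bz + ay * bw + az * bx,
            aw * bz + ax * by - ay * bx + az * bw,
            aw * bw - ax * bx - ay * by - az * bz)

# For unit quaternions the conjugate is the inverse.
def quat_conj(q):
    x, y, z, w = q
    return (-x, -y, -z, w)

# Rotation residual: twice the vector part of q_meas * q_est^{-1}.
def rot_residual(q_meas, q_est):
    x, y, z, _ = quat_mul(q_meas, quat_conj(q_est))
    return (2 * x, 2 * y, 2 * z)

# Identical rotations give a zero residual.
print(rot_residual((0.0, 0.0, 0.0, 1.0), (0.0, 0.0, 0.0, 1.0)))  # (0.0, 0.0, 0.0)
```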


@map_transform
def pose_graph_residual(poses, node1, node2, infos):
return _pose_graph_residual(poses, node1, node2, infos)

class PoseGraph(nn.Module):

def __init__(self, nodes):
super().__init__()
self.nodes = nn.Parameter(TrackingTensor(nodes))
self.nodes = pp.Parameter(nodes, sjac=True)
Contributor comment (high):

Initializing `pp.Parameter` with a plain 7D tensor will result in 7D optimization (standard Euclidean gradient for quaternions) instead of the intended 6D tangent-space optimization. PyPose parameters only trigger manifold-aware optimization automatically when the underlying data is a `pp.LieTensor`.

Suggested change:
- self.nodes = pp.Parameter(nodes, sjac=True)
+ self.nodes = pp.Parameter(pp.SE3(nodes), sjac=True)


def forward(self, edges, poses, infos):
node1 = self.nodes[edges[..., 0]]
@@ -96,7 +54,7 @@ def forward(self, edges, poses, infos):
class PoseGraphFixedFirst(nn.Module):
def __init__(self, nodes_rest):
super().__init__()
self.nodes_rest = nn.Parameter(TrackingTensor(nodes_rest))
self.nodes_rest = pp.Parameter(nodes_rest, sjac=True)
Contributor comment (high):

Similar to the `PoseGraph` class, `self.nodes_rest` should be initialized with a `pp.LieTensor` to ensure 6D tangent-space optimization is used during the solve.

Suggested change:
- self.nodes_rest = pp.Parameter(nodes_rest, sjac=True)
+ self.nodes_rest = pp.Parameter(pp.SE3(nodes_rest), sjac=True)


def nodes_all(self, node_fixed):
return torch.cat([node_fixed, self.nodes_rest], dim=0)
@@ -210,7 +168,6 @@ def forward(self, edges, poses, infos, node_fixed):
nodes_current = graph.nodes
plot_and_save(nodes_current.translation(), name+'.png', title)
torch.save(graph.state_dict(), name+'.pt')
write_ceres_txt(nodes_current.tensor(), name+'.txt')
if args.gif:
save_gif(gif_frames, sample_prefix + '.gif', duration=args.gif_duration)
