24 commits
9a7c6f2
Add CLAUDE.md for Claude Code context
cmoyates Mar 7, 2026
34d4677
Add Claude Code hooks + CLAUDE.md refinements
cmoyates Mar 7, 2026
4c523d3
Add VRAM & perf optimization plan doc
cmoyates Mar 7, 2026
75f543a
Add Phase 0 benchmarking infrastructure to optimization plan
cmoyates Mar 7, 2026
82baa3b
Add Phase 0 benchmark script + quality gate tests
cmoyates Mar 7, 2026
147ad3a
Fix MPS API name + Lanczos overshoot tolerance, mark Phase 0 complete
cmoyates Mar 7, 2026
d400e29
Add benchmark results doc w/ Phase 0 baseline measurements
cmoyates Mar 7, 2026
cfe1239
Phase 1: FP16 weight casting — 7.2GB mem savings, 27% faster
cmoyates Mar 7, 2026
272a4f2
Add Phase 1 benchmark results to RESULTS.md
cmoyates Mar 7, 2026
679fdc0
Phase 2: GPU color math + asset caching
cmoyates Mar 7, 2026
d416402
Add Phase 2 benchmark results to RESULTS.md
cmoyates Mar 7, 2026
ea30174
Phase 3: decouple backbone/refiner resolutions
cmoyates Mar 7, 2026
4637f85
Add Phase 3 benchmark results to RESULTS.md
cmoyates Mar 7, 2026
938914b
Phase 4: tiled CNN refiner w/ tent blending
cmoyates Mar 7, 2026
773bcf5
Add Phase 5: CLI feature flags for all optimizations
cmoyates Mar 7, 2026
94e4fde
Phase 4 results: 96px overlap default, -52% VRAM
cmoyates Mar 7, 2026
b8b8b0e
Phase 5: CLI feature flags for all optimizations
cmoyates Mar 7, 2026
b38a4a6
Phase 5 benchmark matrix results documented
cmoyates Mar 7, 2026
f3da7b7
Add CUDA testing guide for optimization branch
cmoyates Mar 8, 2026
0353bef
Fix PR #54 review concerns: dedup CLI args, fix _clamp, remove per-ti…
cmoyates Mar 8, 2026
ade2cb8
Add visual preset comparison script
cmoyates Mar 8, 2026
75e1567
Fix FP16: drop redundant .half(), use autocast exclusively
cmoyates Mar 8, 2026
e1457e5
Switch EXR compression PXR24 → ZIP (Windows deadlock fix)
cmoyates Mar 8, 2026
cc7a92c
Add platform markers for torch index (macOS falls back to PyPI)
cmoyates Mar 8, 2026
25 changes: 25 additions & 0 deletions .claude/settings.json
@@ -0,0 +1,25 @@
{
"hooks": {
"PostToolUse": [
{
"matcher": "Write|Edit|MultiEdit",
"hooks": [
{
"type": "command",
"command": "file=$(echo \"$CLAUDE_TOOL_INPUT\" | jq -r '.file_path // empty'); if [ -n \"$file\" ] && echo \"$file\" | grep -qE '\\.py$'; then uv run ruff format \"$file\" 2>/dev/null; fi"
}
]
}
],
"Stop": [
{
"hooks": [
{
"type": "command",
"command": "uv run ruff check --fix 2>&1 | tail -20"
}
]
}
]
}
}
4 changes: 4 additions & 0 deletions .gitignore
@@ -37,6 +37,10 @@ CorridorKey_remote.bat
*.onnx

# Checkpoint Directories
# Benchmark baseline outputs (large .npy files)
benchmarks/baseline/
benchmarks/**/diffs/

CorridorKeyModule/checkpoints/*
!CorridorKeyModule/checkpoints/.gitkeep
CorridorKeyModule/IgnoredCheckpoints/*
90 changes: 90 additions & 0 deletions CLAUDE.md
@@ -0,0 +1,90 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Summary

CorridorKey is a neural-network-based green screen keyer for professional VFX. It takes an RGB image + coarse alpha hint and produces physically unmixed straight foreground color + linear alpha channel. Native inference at 2048x2048 (Hiera backbone). Requires ~22.7GB VRAM (CUDA), also supports MPS (Apple Silicon) and an experimental MLX backend.

## Important

Be extremely concise in all interactions and commit messages. Sacrifice grammar for the sake of being concise.

## Commands

```bash
# Setup (uses uv, not pip)
uv sync --group dev

# Tests
uv run pytest # all tests
uv run pytest -v # verbose
uv run pytest -m "not gpu" # skip GPU tests (what CI runs)
uv run pytest tests/test_color_utils.py # single file
uv run pytest -k "test_name" # single test by name
uv run pytest --cov # with coverage

# Lint & format
uv run ruff check # lint
uv run ruff format --check # format check
uv run ruff format # auto-format

# Run the keyer
uv run python corridorkey_cli.py --action wizard --win_path <path>
# Or drag files onto CorridorKey_DRAG_CLIPS_HERE_local.sh / .bat
```

Ruff config: line-length 120, rules `E,F,W,I,B`. `gvm_core/` and `VideoMaMaInferenceModule/` are excluded from linting. CI runs lint + tests (Python 3.10, 3.13) on every PR to `main`.

## Architecture

### Entry Points
- `corridorkey_cli.py` — CLI arg parsing, env setup, interactive wizard. Imports pipeline logic from `clip_manager.py`.
- `clip_manager.py` — Core pipeline: scans directories for `Input/` (RGB) + `AlphaHint/` (BW masks), prompts for config (gamma, despill, despeckle, refiner), loops frame-by-frame through the engine.
- Launcher scripts: `.sh` / `.bat` files that invoke `corridorkey_cli.py`.

### CorridorKeyModule (inference engine)
- `inference_engine.py` — `CorridorKeyEngine`: loads model, resizes to/from 2048x2048 (Lanczos4), normalizes inputs (uint8->float), packs output passes.
- `backend.py` — Backend factory: selects Torch or MLX engine. Configured via `CORRIDORKEY_BACKEND` env var or CLI flag (`auto`/`torch`/`mlx`).
- `core/model_transformer.py` — `GreenFormer`: Hiera backbone (timm, 4-channel: RGB + alpha hint), multiscale decoders, CNN refiner head (`CNNRefinerModule`) with additive delta logits.
- `core/color_utils.py` — Compositing math: piecewise sRGB transfer functions, premultiply, luminance-preserving despill, morphological matte cleanup.

### backend/ (service layer)
Higher-level service abstraction (`CorridorKeyService`), project/clip state management, GPU job queue, ffmpeg tools, frame I/O. This layer sits above the raw inference engine.

### Optional Alpha Hint Generators
- `gvm_core/` — GVM (Generative Video Matting). Automatic, no user mask needed. ~80GB VRAM.
- `VideoMaMaInferenceModule/` — VideoMaMa. Requires user-provided `VideoMamaMaskHint/`. ~80GB VRAM.

Both invoked through `clip_manager.py` (`--action generate_alphas`). These are third-party research repos kept close to upstream — excluded from ruff enforcement.

### Device Selection
`device_utils.py` — Centralized device resolution: CLI flag > `CORRIDORKEY_DEVICE` env var > auto-detect (CUDA > MPS > CPU).

### Output Structure (per shot)
- `/Matte` — Linear alpha (half-float EXR)
- `/FG` — Straight foreground color, sRGB gamut (half-float EXR)
- `/Processed` — Linear premultiplied RGBA (half-float EXR)
- `/Comp` — Checkerboard preview (8-bit PNG)
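For illustration, a hedged numpy sketch of how the straight `/FG` and `/Matte` passes combine into the premultiplied `/Processed` pass (function name hypothetical; assumes FG has already been converted to linear, since the model emits FG in sRGB):

```python
import numpy as np


def premultiply(fg_linear: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Straight linear FG [H, W, 3] + linear alpha [H, W] -> premultiplied RGBA [H, W, 4]."""
    a = alpha[..., None]  # broadcast alpha over the color channels
    return np.concatenate([fg_linear * a, a], axis=-1)
```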

## Philosophy

This codebase will outlive you. Every shortcut you take becomes
someone else's burden. Every hack compounds into technical debt
that slows the whole team down.

You are not just writing code. You are shaping the future of this
project. The patterns you establish will be copied. The corners
you cut will be cut again.

Fight entropy. Leave the codebase better than you found it.


## Critical Rules

1. **Color math is sacred.** Model outputs: FG is sRGB, alpha is linear. Use piecewise sRGB transfer functions from `color_utils.py`, never `pow(x, 2.2)`. "Crushed shadows" or "dark fringes" = check sRGB-to-linear conversion ordering.
2. **Model I/O is `[0.0, 1.0]` float tensors.** Always.
3. **Performance matters.** 4K video frame-by-frame. Minimize `.numpy()` transfers and unnecessary `cv2.resize` in hot loops.
4. **OpenEXR requires** `os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"` before importing cv2.
5. **Folder structure is meaningful.** The wizard expects `Input/`, `AlphaHint/`, and optionally `VideoMamaMaskHint/` subdirectories per shot.
6. **Model weights are gitignored.** `.pth`, `.safetensors`, `.ckpt`, `.bin` are all in `.gitignore`. Most tests don't need them.
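A sketch of the piecewise transfer functions rule 1 refers to (standard IEC 61966-2-1 sRGB math; illustrative only, not the project's `color_utils.py` implementation):

```python
import numpy as np


def srgb_to_linear(x: np.ndarray) -> np.ndarray:
    """Piecewise sRGB EOTF (IEC 61966-2-1) -- NOT pow(x, 2.2)."""
    return np.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)


def linear_to_srgb(x: np.ndarray) -> np.ndarray:
    """Inverse of srgb_to_linear; thresholds match (0.04045 / 12.92 == 0.0031308)."""
    return np.where(x <= 0.0031308, x * 12.92, 1.055 * np.clip(x, 0, None) ** (1 / 2.4) - 0.055)
```

The pure power curve diverges from the piecewise form most near black, which is exactly where "crushed shadows" and "dark fringes" show up.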
2 changes: 1 addition & 1 deletion CorridorKeyModule/README.md
@@ -68,7 +68,7 @@ result = engine.process_frame(
proc_rgba = result['processed']
proc_bgra = cv2.cvtColor(proc_rgba, cv2.COLOR_RGBA2BGRA)

exr_flags = [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF, cv2.IMWRITE_EXR_COMPRESSION, cv2.IMWRITE_EXR_COMPRESSION_PXR24]
exr_flags = [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF, cv2.IMWRITE_EXR_COMPRESSION, cv2.IMWRITE_EXR_COMPRESSION_ZIP]
cv2.imwrite("output_processed.exr", proc_bgra, exr_flags)
```

16 changes: 15 additions & 1 deletion CorridorKeyModule/backend.py
@@ -206,6 +206,11 @@ def create_engine(
backend: str | None = None,
device: str | None = None,
img_size: int = DEFAULT_IMG_SIZE,
backbone_size: int | None = None,
refiner_tile_size: int | None = 512,
refiner_tile_overlap: int = 96,
fp16: bool = True,
gpu_postprocess: bool = True,
):
"""Factory: returns an engine with process_frame() matching the Torch contract."""
backend = resolve_backend(backend)
@@ -222,4 +227,13 @@
from CorridorKeyModule.inference_engine import CorridorKeyEngine

logger.info("Torch engine loaded: %s (device=%s)", ckpt.name, device)
return CorridorKeyEngine(checkpoint_path=str(ckpt), device=device or "cpu", img_size=img_size)
return CorridorKeyEngine(
checkpoint_path=str(ckpt),
device=device or "cpu",
img_size=img_size,
backbone_size=backbone_size,
refiner_tile_size=refiner_tile_size,
refiner_tile_overlap=refiner_tile_overlap,
fp16=fp16,
gpu_postprocess=gpu_postprocess,
)
16 changes: 12 additions & 4 deletions CorridorKeyModule/core/color_utils.py
@@ -37,9 +37,9 @@ def _clamp(x: np.ndarray | torch.Tensor, min: float) -> np.ndarray | torch.Tensor:
Clamp function that supports both Numpy arrays and PyTorch tensors.
"""
if _is_tensor(x):
return x.clamp(min=0.0)
return x.clamp(min=min)
else:
return np.clip(x, 0.0, None)
return np.clip(x, min, None)


_torch_stack = functools.partial(torch.stack, dim=-1)
@@ -247,6 +247,15 @@ def despill(
return despilled


_kernel_cache: dict[int, np.ndarray] = {}


def _get_ellipse_kernel(kernel_size: int) -> np.ndarray:
if kernel_size not in _kernel_cache:
_kernel_cache[kernel_size] = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
return _kernel_cache[kernel_size]


def clean_matte(alpha_np: np.ndarray, area_threshold: int = 300, dilation: int = 15, blur_size: int = 5) -> np.ndarray:
"""
Cleans up small disconnected components (like tracking markers) from a predicted alpha matte.
@@ -275,8 +284,7 @@ def clean_matte(alpha_np: np.ndarray, area_threshold: int = 300, dilation: int =
# Dilate
if dilation > 0:
kernel_size = int(dilation * 2 + 1)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
cleaned_mask = cv2.dilate(cleaned_mask, kernel)
cleaned_mask = cv2.dilate(cleaned_mask, _get_ellipse_kernel(kernel_size))

# Blur
if blur_size > 0:
97 changes: 85 additions & 12 deletions CorridorKeyModule/core/model_transformer.py
@@ -144,16 +144,31 @@ def __init__(
encoder_name: str = "hiera_base_plus_224.mae_in1k_ft_in1k",
in_channels: int = 4,
img_size: int = 512,
backbone_size: int | None = None,
use_refiner: bool = True,
refiner_tile_size: int | None = None,
refiner_tile_overlap: int = 64,
) -> None:
super().__init__()

# Backbone resolution — None means same as img_size (no downsampling)
self.backbone_size = backbone_size
encoder_img_size = backbone_size or img_size

# Tiled refiner config — reduces peak VRAM by processing tiles sequentially
self.refiner_tile_size = refiner_tile_size
self.refiner_tile_overlap = refiner_tile_overlap
if refiner_tile_size is not None:
self._tent_weight = self._build_tent_weight(refiner_tile_size, refiner_tile_overlap)
else:
self._tent_weight = None

# --- Encoder ---
# Load Pretrained Hiera
# 1. Create Target Model (512x512, Random Weights)
# 1. Create Target Model (Random Weights)
# We use features_only=True, which wraps it in FeatureGetterNet
print(f"Initializing {encoder_name} (img_size={img_size})...")
self.encoder = timm.create_model(encoder_name, pretrained=False, features_only=True, img_size=img_size)
print(f"Initializing {encoder_name} (img_size={img_size}, backbone_size={encoder_img_size})...")
self.encoder = timm.create_model(encoder_name, pretrained=False, features_only=True, img_size=encoder_img_size)
# We skip downloading/loading base weights because the user's checkpoint
# (loaded immediately after this) contains all weights, including correctly
# trained/sized PosEmbeds. This keeps the project offline-capable using only local assets.
@@ -235,12 +250,65 @@ def _patch_input_layer(self, in_channels: int) -> None:

print(f"Patched input layer: 3 channels -> {in_channels} channels (Extra initialized to 0)")

@staticmethod
def _build_tent_weight(tile_size: int, overlap: int) -> torch.Tensor:
"""Build 2D tent (linear ramp) weight map for tile seam blending."""
ramp = torch.linspace(0, 1, overlap + 2)[1:-1] # (0, 1) exclusive
center = torch.ones(tile_size - 2 * overlap)
w1d = torch.cat([ramp, center, ramp.flip(0)])
return (w1d.unsqueeze(1) * w1d.unsqueeze(0)).unsqueeze(0).unsqueeze(0) # [1, 1, ts, ts]

def _tiled_refine(self, rgb: torch.Tensor, coarse_pred: torch.Tensor) -> torch.Tensor:
"""Run refiner in tiles to reduce peak VRAM. Blends overlaps with tent weights."""
tile_size = self.refiner_tile_size
overlap = self.refiner_tile_overlap
stride = tile_size - overlap
_, _, h, w = rgb.shape
device = rgb.device

# CPU accumulators — tiles offloaded immediately to save VRAM
output_acc = torch.zeros(1, 4, h, w, dtype=torch.float32)
weight_acc = torch.zeros(1, 1, h, w, dtype=torch.float32)
tent = self._tent_weight # [1, 1, tile_size, tile_size]

def _starts(length: int) -> list[int]:
"""Tile start positions — last tile end-aligns with image edge."""
s = list(range(0, length - tile_size + 1, stride))
if not s or s[-1] + tile_size < length:
s.append(length - tile_size)
return sorted(set(s))

for y in _starts(h):
for x in _starts(w):
rgb_tile = rgb[:, :, y : y + tile_size, x : x + tile_size]
coarse_tile = coarse_pred[:, :, y : y + tile_size, x : x + tile_size]

delta = self.refiner(rgb_tile, coarse_tile)
delta_cpu = delta.cpu().float()

output_acc[:, :, y : y + tile_size, x : x + tile_size] += delta_cpu * tent
weight_acc[:, :, y : y + tile_size, x : x + tile_size] += tent

del delta, rgb_tile, coarse_tile

return (output_acc / weight_acc.clamp(min=1e-8)).to(device)

def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
# x: [B, 4, H, W]
input_size = x.shape[2:]

# Encode
features = self.encoder(x) # Returns list of features
# Optionally downsample for backbone (encoder runs at lower res)
if self.backbone_size is not None and (
input_size[0] != self.backbone_size or input_size[1] != self.backbone_size
):
x_backbone = F.interpolate(
x, size=(self.backbone_size, self.backbone_size), mode="bilinear", align_corners=False
)
else:
x_backbone = x

# Encode (at backbone resolution)
features = self.encoder(x_backbone) # Returns list of features

# Decode Streams
alpha_logits = self.alpha_decoder(features) # [B, 1, H/4, W/4]
@@ -251,20 +319,19 @@ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
alpha_logits_up = F.interpolate(alpha_logits, size=input_size, mode="bilinear", align_corners=False)
fg_logits_up = F.interpolate(fg_logits, size=input_size, mode="bilinear", align_corners=False)

# --- HUMILITY CLAMP REMOVED (Phase 3) ---
# User requested NO CLAMPING to preserve all backbone detail.
# Refiner sees raw logits (-inf to +inf).
# alpha_logits_up = torch.clamp(alpha_logits_up, -3.0, 3.0)
# fg_logits_up = torch.clamp(fg_logits_up, -3.0, 3.0)
# Humility clamp removed: clamping logits to [-3, 3] limited refiner correction
# range, causing visible banding in low-contrast regions. Raw logits are safe
# because FP16 autocast handles numerical stability and sigmoid saturates gracefully.

# Coarse Probs (for Loss and Refiner Input)
alpha_coarse = torch.sigmoid(alpha_logits_up)
fg_coarse = torch.sigmoid(fg_logits_up)

# --- Refinement (CNN Hybrid) ---
# 4. Refine (CNN)
# Input to refiner: RGB Image (first 3 channels of x) + Coarse Predictions (Probs)
# Input to refiner: RGB Image (first 3 channels of ORIGINAL x) + Coarse Predictions (Probs)
# We give the refiner 'Probs' as input features because they are normalized [0,1]
# Always use full-res RGB — refiner recovers fine detail lost by backbone downsampling
rgb = x[:, :3, :, :]

# Feed the Refiner
Expand All @@ -273,7 +340,13 @@ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
# Refiner outputs DELTA LOGITS
# The refiner predicts the correction in valid score space (-inf, inf)
if self.use_refiner and self.refiner is not None:
delta_logits = self.refiner(rgb, coarse_pred)
use_tiling = self.refiner_tile_size is not None and (
input_size[0] > self.refiner_tile_size or input_size[1] > self.refiner_tile_size
)
if use_tiling:
delta_logits = self._tiled_refine(rgb, coarse_pred)
else:
delta_logits = self.refiner(rgb, coarse_pred)
else:
# Zero Deltas
delta_logits = torch.zeros_like(coarse_pred)
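A standalone numpy sketch of the tiling scheme in this file (tent weights plus end-aligned tile starts), showing that blending an identity refinement reconstructs the input exactly. It mirrors the logic of `_build_tent_weight` and `_starts` but is not imported from the repo:

```python
import numpy as np


def tent_weight(tile_size: int, overlap: int) -> np.ndarray:
    # Linear ramp over the overlap region, flat in the centre
    ramp = np.linspace(0, 1, overlap + 2)[1:-1]  # (0, 1) exclusive endpoints
    w1d = np.concatenate([ramp, np.ones(tile_size - 2 * overlap), ramp[::-1]])
    return np.outer(w1d, w1d)  # [tile_size, tile_size], strictly positive


def tile_starts(length: int, tile_size: int, stride: int) -> list[int]:
    # Regular grid of starts; last tile end-aligns with the image edge
    s = list(range(0, length - tile_size + 1, stride))
    if not s or s[-1] + tile_size < length:
        s.append(length - tile_size)
    return sorted(set(s))


def blend_identity(img: np.ndarray, tile_size: int, overlap: int) -> np.ndarray:
    # Tile the image, weight each tile by the tent, accumulate, normalize.
    # With an identity "refiner", output must equal the input exactly.
    h, w = img.shape
    out = np.zeros_like(img)
    acc = np.zeros_like(img)
    tent = tent_weight(tile_size, overlap)
    stride = tile_size - overlap
    for y in tile_starts(h, tile_size, stride):
        for x in tile_starts(w, tile_size, stride):
            out[y:y + tile_size, x:x + tile_size] += img[y:y + tile_size, x:x + tile_size] * tent
            acc[y:y + tile_size, x:x + tile_size] += tent
    return out / np.clip(acc, 1e-8, None)
```

Because every pixel receives a strictly positive accumulated weight, the normalization step cancels the tent exactly wherever overlapping tiles agree, so seams only appear where tile predictions genuinely differ.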