From 056989a73beeef20541c7023239c6136b941d804 Mon Sep 17 00:00:00 2001 From: yifan <462449660@qq.com> Date: Wed, 25 Mar 2026 09:09:25 +0800 Subject: [PATCH] modify --- src/openworldlib/pipelines/pi3/__init__.py | 3 +- .../pipelines/pi3/pipeline_loger.py | 526 +++++++ .../pi3/loger/__init__.py | 0 .../pi3/loger/layers/attention.py | 291 ++++ .../pi3/loger/layers/block.py | 163 +++ .../pi3/loger/layers/camera_head.py | 159 +++ .../point_clouds_generation/pi3/loger/pi3.py | 1266 +++++++++++++++++ .../point_clouds_generation/pi3/loger/ttt.py | 323 +++++ .../pi3/loger/utils/basic.py | 90 ++ .../pi3/loger/utils/geometry.py | 36 + .../pi3/loger/utils/rotation.py | 180 +++ .../pi3/loger/utils/viser_utils.py | 781 ++++++++++ .../pi3/loger/utils/visual_util.py | 710 +++++++++ .../pi3/loger_representation.py | 362 +++++ test/test_loger.py | 54 + 15 files changed, 4943 insertions(+), 1 deletion(-) create mode 100644 src/openworldlib/pipelines/pi3/pipeline_loger.py create mode 100644 src/openworldlib/representations/point_clouds_generation/pi3/loger/__init__.py create mode 100644 src/openworldlib/representations/point_clouds_generation/pi3/loger/layers/attention.py create mode 100644 src/openworldlib/representations/point_clouds_generation/pi3/loger/layers/block.py create mode 100644 src/openworldlib/representations/point_clouds_generation/pi3/loger/layers/camera_head.py create mode 100644 src/openworldlib/representations/point_clouds_generation/pi3/loger/pi3.py create mode 100644 src/openworldlib/representations/point_clouds_generation/pi3/loger/ttt.py create mode 100644 src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/basic.py create mode 100644 src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/geometry.py create mode 100644 src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/rotation.py create mode 100644 src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/viser_utils.py create mode 100644 src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/visual_util.py create mode 100644 src/openworldlib/representations/point_clouds_generation/pi3/loger_representation.py create mode 100644 test/test_loger.py diff --git a/src/openworldlib/pipelines/pi3/__init__.py b/src/openworldlib/pipelines/pi3/__init__.py index fb2c939d..4443b3ae 100644 --- a/src/openworldlib/pipelines/pi3/__init__.py +++ b/src/openworldlib/pipelines/pi3/__init__.py @@ -1,3 +1,4 @@ from .pipeline_pi3 import Pi3Pipeline, Pi3Result +from .pipeline_loger import LoGeRPipeline, LoGeRResult -__all__ = ["Pi3Pipeline", "Pi3Result"] +__all__ = ["Pi3Pipeline", "Pi3Result", "LoGeRPipeline", "LoGeRResult"] diff --git a/src/openworldlib/pipelines/pi3/pipeline_loger.py b/src/openworldlib/pipelines/pi3/pipeline_loger.py new file mode 100644 index 00000000..25f00972 --- /dev/null +++ b/src/openworldlib/pipelines/pi3/pipeline_loger.py @@ -0,0 +1,526 @@ +import os +import json +import math +from typing import List, Optional, Union, Dict, Any + +import numpy as np +from PIL import Image + +from ...operators.pi3_operator import Pi3Operator +from ...representations.point_clouds_generation.pi3.loger_representation import LoGeRRepresentation + +def _apply_camera_delta(c2w: np.ndarray, delta: List[float]) -> np.ndarray: + """Apply a [dx,dy,dz,theta_x,theta_z] delta to a 4x4 camera-to-world matrix.""" + dx, dy, dz, theta_x, theta_z = delta + result = c2w.copy() + + rx = np.eye(4) + cx, sx = math.cos(theta_x), math.sin(theta_x) + rx[1, 1], rx[1, 2] = cx, -sx + rx[2, 1], rx[2, 2] = sx, cx + + rz = np.eye(4) + cz, sz = math.cos(theta_z), math.sin(theta_z) + rz[0, 0], rz[0, 1] = cz, -sz + rz[1, 0], rz[1, 1] = sz, cz + + result = result @ rx @ rz + result[0, 3] += dx + result[1, 3] += dy + result[2, 3] += dz + return result + + +def render_point_cloud( + points: np.ndarray, + colors: np.ndarray, + camera_to_world: np.ndarray, + height: int, + width: int, + focal_scale: float = 1.0, + splat_radius: int = 3, +) -> Image.Image: + """Render a point cloud with strict z-buffer and front-to-back splatting. + + Points are sorted front-to-back. Each point is splatted as a disk. + Only the closest point at each pixel is kept (strict z-buffer), + which eliminates ghosting/layering artifacts. + """ + c2w = camera_to_world.astype(np.float64) + w2c = np.linalg.inv(c2w) + R, t = w2c[:3, :3], w2c[:3, 3] + + pts_cam = (R @ points.T).T + t + valid = pts_cam[:, 2] > 1e-4 + pts_cam = pts_cam[valid] + cols = colors[valid] + if cols.dtype == np.float64 or cols.dtype == np.float32: + if cols.max() <= 1.0: + cols = (cols * 255).clip(0, 255).astype(np.uint8) + else: + cols = cols.clip(0, 255).astype(np.uint8) + + fx = fy = focal_scale * max(height, width) + cx_img, cy_img = width / 2.0, height / 2.0 + + u = np.round(fx * pts_cam[:, 0] / pts_cam[:, 2] + cx_img).astype(np.int32) + v = np.round(fy * pts_cam[:, 1] / pts_cam[:, 2] + cy_img).astype(np.int32) + z = pts_cam[:, 2].astype(np.float32) + + sort_idx = np.argsort(z) + u, v, z, cols = u[sort_idx], v[sort_idx], z[sort_idx], cols[sort_idx] + + canvas = np.zeros((height, width, 3), dtype=np.uint8) + z_buf = np.full((height, width), np.inf, dtype=np.float32) + + r = splat_radius + for dy in range(-r, r + 1): + for dx in range(-r, r + 1): + if dx * dx + dy * dy > r * r: + continue + py = v + dy + px = u + dx + mask = (px >= 0) & (px < width) & (py >= 0) & (py < height) + px_m, py_m, z_m, cols_m = px[mask], py[mask], z[mask], cols[mask] + closer = z_m < z_buf[py_m, px_m] + px_c, py_c = px_m[closer], py_m[closer] + z_buf[py_c, px_c] = z_m[closer] + canvas[py_c, px_c] = cols_m[closer] + + return Image.fromarray(canvas) + + +class LoGeRResult: + + def __init__( + self, + depth_images: List[Image.Image], + numpy_data: Dict[str, np.ndarray], + camera_params: List[Dict[str, Any]], + camera_range: Dict[str, Any], + input_images: Optional[List[np.ndarray]] = None, + data_type: str = "image", + ): + self.depth_images = depth_images + self.numpy_data = numpy_data + self.camera_params = camera_params + self.camera_range = camera_range + self.input_images = input_images + self.data_type = data_type + + def __len__(self): + return len(self.depth_images) + + def __getitem__(self, idx): + return { + "depth_image": self.depth_images[idx] if idx < len(self.depth_images) else None, + "camera_params": self.camera_params[idx] if idx < len(self.camera_params) else None, + } + + def save(self, output_dir: Optional[str] = None) -> List[str]: + if output_dir is None: + output_dir = "./loger_output" + + os.makedirs(output_dir, exist_ok=True) + saved_files: List[str] = [] + + # Point cloud (PLY) + ply_dir = os.path.join(output_dir, "point_cloud") + os.makedirs(ply_dir, exist_ok=True) + if "points" in self.numpy_data and "masks" in self.numpy_data and self.input_images is not None: + try: + from plyfile import PlyData, PlyElement + + points_b0 = self.numpy_data["points"][0] + masks_b0 = self.numpy_data["masks"][0].astype(bool) + colors = np.stack(self.input_images, axis=0) + + pts = points_b0[masks_b0].astype(np.float32) + col = (colors[masks_b0] * 255).clip(0, 255).astype(np.uint8) + + vertices = np.zeros( + pts.shape[0], + dtype=[ + ("x", "f4"), ("y", "f4"), ("z", "f4"), + ("nx", "f4"), ("ny", "f4"), ("nz", "f4"), + ("red", "u1"), ("green", "u1"), ("blue", "u1"), + ], + ) + vertices["x"], vertices["y"], vertices["z"] = pts[:, 0], pts[:, 1], pts[:, 2] + vertices["nx"], vertices["ny"], vertices["nz"] = 0.0, 0.0, 0.0 + vertices["red"], vertices["green"], vertices["blue"] = col[:, 0], col[:, 1], col[:, 2] + + ply_path = os.path.join(ply_dir, "result.ply") + PlyData([PlyElement.describe(vertices, "vertex")]).write(ply_path) + saved_files.append(ply_path) + except ImportError: + pass + + # Raw numpy data + raw_dir = os.path.join(output_dir, "raw_data") + os.makedirs(raw_dir, exist_ok=True) + for key, value in self.numpy_data.items(): + if isinstance(value, np.ndarray): + npy_path = os.path.join(raw_dir, f"{key}.npy") + np.save(npy_path, value) + saved_files.append(npy_path) + + # Depth map visualizations + depth_dir = os.path.join(output_dir, "depth") + os.makedirs(depth_dir, exist_ok=True) + for i, img in enumerate(self.depth_images): + depth_path = os.path.join(depth_dir, f"depth_{i:04d}.png") + img.save(depth_path) + saved_files.append(depth_path) + + # Input RGB frames + if self.input_images is not None and len(self.input_images) > 0: + rgb_dir = os.path.join(output_dir, "rgb") + os.makedirs(rgb_dir, exist_ok=True) + for i, img_arr in enumerate(self.input_images): + img_uint8 = (img_arr * 255).clip(0, 255).astype(np.uint8) + rgb_path = os.path.join(rgb_dir, f"frame_{i:04d}.png") + Image.fromarray(img_uint8).save(rgb_path) + saved_files.append(rgb_path) + + # Camera poses + poses_dir = os.path.join(output_dir, "camera_poses") + os.makedirs(poses_dir, exist_ok=True) + for i, cam in enumerate(self.camera_params): + pose_path = os.path.join(poses_dir, f"pose_{i:04d}.json") + with open(pose_path, "w") as f: + json.dump(cam, f, indent=2) + saved_files.append(pose_path) + + # meta.json with camera_range + meta_path = os.path.join(output_dir, "meta.json") + with open(meta_path, "w") as f: + json.dump({"camera_range": self.camera_range}, f, indent=2) + saved_files.append(meta_path) + + return saved_files + + +def _build_camera_range(camera_params: List[Dict[str, Any]]) -> Dict[str, Any]: + """Compute camera parameter range from a list of camera_to_world matrices.""" + n = len(camera_params) + if n == 0: + return {} + + translations = np.array([np.array(c["camera_to_world"])[:3, 3] for c in camera_params]) + return { + "available_view_indices": list(range(n)), + "default_view_index": 0, + "num_views": n, + "translation_min": translations.min(axis=0).tolist(), + "translation_max": translations.max(axis=0).tolist(), + } + + +class LoGeRPipeline: + + def __init__( + self, + representation_model=None, + reasoning_model: Optional[Any] = None, + synthesis_model: Optional[Any] = None, + operator: Optional[Pi3Operator] = None, + ) -> None: + self.representation_model = representation_model + self.reasoning_model = reasoning_model + self.synthesis_model = synthesis_model + self.operator = operator or Pi3Operator() + self._cached_result: Optional[LoGeRResult] = None + self._current_camera: Optional[np.ndarray] = None + + @classmethod + def from_pretrained( + cls, + model_path: Optional[str] = None, + required_components: Optional[Dict[str, str]] = None, + mode: str = "loger", + device: Optional[str] = None, + weight_dtype: Optional[str] = None, + representation_path: Optional[str] = None, + model_type: Optional[str] = None, + **kwargs, + ) -> "LoGeRPipeline": + path = model_path or representation_path + if path is None: + raise ValueError("model_path is required.") + m = model_type or mode + + if m in ("loger", "loger_star"): + subfolder = "LoGeR" if m == "loger" else "LoGeR_star" + representation_model = LoGeRRepresentation.from_pretrained( + pretrained_model_path=path, device=device, subfolder=subfolder, **kwargs, + ) + else: + raise ValueError(f"Unknown mode: {m}. Choose 'loger' or 'loger_star'.") + return cls(representation_model=representation_model) + + def _run_inference( + self, + images: Union[str, np.ndarray, List[str], List[np.ndarray]], + **kwargs, + ) -> LoGeRResult: + """Run LoGeR model inference (single-shot, produces all outputs).""" + if self.representation_model is None: + raise RuntimeError("Representation model not loaded. Use from_pretrained() first.") + + interval = kwargs.get("interval", -1) + images_data = self.operator.process_perception(images, interval=interval) + if not isinstance(images_data, list): + images_data = [images_data] + + device = self.representation_model.device + imgs_tensor = self.operator.images_to_tensor(images_data, device=device) + + resized_images = [ + imgs_tensor[0, i].permute(1, 2, 0).cpu().numpy() + for i in range(imgs_tensor.shape[1]) + ] + + data = { + "images": imgs_tensor, + "conf_threshold": kwargs.get("conf_threshold", 0.1), + "edge_rtol": kwargs.get("edge_rtol", 0.03), + } + + # ── LoGeR / Pi3 推理控制参数透传 ────────────────────────── + _LOGER_KEYS = ( + "window_size", "overlap_size", "num_iterations", + "sim3", "se3", "reset_every", + "turn_off_ttt", "turn_off_swa", "sim3_scale_mode", + ) + for k in _LOGER_KEYS: + if k in kwargs: + data[k] = kwargs[k] + + conditions_path = kwargs.get("conditions_path") + if conditions_path is not None and os.path.exists(conditions_path): + import torch as _torch + cond_data = np.load(conditions_path, allow_pickle=True) + if "poses" in cond_data: + data["poses"] = _torch.from_numpy(cond_data["poses"]).float().unsqueeze(0) + if "depths" in cond_data: + data["depths"] = _torch.from_numpy(cond_data["depths"]).float().unsqueeze(0) + if "intrinsics" in cond_data: + data["intrinsics"] = _torch.from_numpy(cond_data["intrinsics"]).float().unsqueeze(0) + + results = self.representation_model.get_representation(data) + + depth_images = [] + depth_maps = results.get("depth_map") + if depth_maps is not None: + depth_b0 = depth_maps[0] + if depth_b0.ndim == 2: + depth_b0 = depth_b0[np.newaxis, ...] + for i in range(depth_b0.shape[0]): + d = depth_b0[i].astype(np.float64) + d_min, d_max = d.min(), d.max() + d_norm = (d - d_min) / (d_max - d_min + 1e-8) + d_uint8 = (d_norm * 255).astype(np.uint8) + depth_images.append(Image.fromarray(d_uint8, mode="L")) + + camera_params = [] + cam_poses = results.get("camera_poses") + if cam_poses is not None: + for i in range(cam_poses[0].shape[0]): + camera_params.append({ + "camera_to_world": cam_poses[0][i].tolist(), + }) + + camera_range = _build_camera_range(camera_params) + + result = LoGeRResult( + depth_images=depth_images, + numpy_data=results, + camera_params=camera_params, + camera_range=camera_range, + input_images=resized_images, + data_type="image", + ) + self._cached_result = result + if camera_params: + self._current_camera = np.array(camera_params[0]["camera_to_world"]) + return result + + def render_view( + self, + result: Optional["LoGeRResult"] = None, + camera_view=None, + camera_to_world: Optional[np.ndarray] = None, + ) -> Image.Image: + """Render a view from cached point cloud (no model inference). + + Args: + camera_view: Supports multiple formats: + - int: index into result.camera_params (e.g., 0, 1, 2) + - list of 5 floats: [dx,dy,dz,theta_x,theta_z] delta from default camera + - None: use the default (first) camera + camera_to_world: explicit 4x4 matrix (overrides camera_view if provided) + """ + res = result or self._cached_result + if res is None: + raise RuntimeError("No result available. Run inference first via __call__().") + + if "points" not in res.numpy_data or "masks" not in res.numpy_data: + raise RuntimeError("Result does not contain point cloud data.") + + pts_all = res.numpy_data["points"][0] + masks = res.numpy_data["masks"][0].astype(bool) + colors_all = np.stack(res.input_images, axis=0) if res.input_images else None + if colors_all is None: + raise RuntimeError("No input images in result for coloring.") + + pts = pts_all[masks].astype(np.float64) + cols = (colors_all[masks] * 255).clip(0, 255).astype(np.uint8) + + if camera_to_world is not None: + c2w = np.array(camera_to_world, dtype=np.float64) + elif isinstance(camera_view, int): + c2w = np.array(res.camera_params[camera_view]["camera_to_world"], dtype=np.float64) + elif isinstance(camera_view, (list, tuple)): + base = np.array(res.camera_params[0]["camera_to_world"], dtype=np.float64) + c2w = _apply_camera_delta(base, camera_view) + else: + c2w = np.array(res.camera_params[0]["camera_to_world"], dtype=np.float64) + + h = pts_all.shape[1] if pts_all.ndim >= 3 else 480 + w = pts_all.shape[2] if pts_all.ndim >= 4 else 640 + if pts_all.ndim >= 4: + h, w = pts_all.shape[1], pts_all.shape[2] + + return render_point_cloud(pts, cols, c2w, h, w) + + def _render_trajectory(self, n_interp: int = 15, fps: int = 15, **kwargs) -> List[Image.Image]: + """Render a trajectory video by interpolating between original camera poses. + Returns a list of PIL.Image frames. + """ + res = self._cached_result + if res is None: + raise RuntimeError("No result available. Run reconstruction first.") + + pts_all = res.numpy_data["points"][0] + masks = res.numpy_data["masks"][0].astype(bool) + colors_all = np.stack(res.input_images, axis=0) + pts = pts_all[masks].astype(np.float64) + cols = (colors_all[masks] * 255).clip(0, 255).astype(np.uint8) + h, w = pts_all.shape[1], pts_all.shape[2] + + c2ws = [np.array(c["camera_to_world"], dtype=np.float64) for c in res.camera_params] + frames = [] + for vi in range(len(c2ws) - 1): + for j in range(n_interp): + t = j / n_interp + c2w = c2ws[vi] * (1 - t) + c2ws[vi + 1] * t + frames.append(render_point_cloud(pts, cols, c2w, h, w)) + frames.append(render_point_cloud(pts, cols, c2ws[-1], h, w)) + return frames + + def __call__( + self, + images: Optional[Union[str, np.ndarray, List[str], List[np.ndarray]]] = None, + videos: Optional[Union[str, List[str]]] = None, + image_path: Optional[str] = None, + video_path: Optional[str] = None, + task_type: str = "reconstruction", + interactions: Optional[List[str]] = None, + camera_view=None, + visualize_ops: bool = True, + **kwargs, + ): + """Unified call interface. Behavior is determined by task_type. + + Args: + images: Image input path/list/tensor/array. + videos: Video input path/list. If provided, takes precedence over images. + image_path: Alias for a single image path. + video_path: Alias for a single video path. + task_type: One of "reconstruction", "render_view", "render_trajectory". + interactions: Navigation signals like ["forward", "left", "camera_r"]. + When provided with task_type="render_view", generates a video + with smooth transitions between each interaction. + camera_view: Supports int (view index) or list [dx,dy,dz,theta_x,theta_z]. + visualize_ops: Whether to generate visualizations. + + Returns: + - task_type="reconstruction": LoGeRResult + - task_type="render_view": PIL.Image or List[PIL.Image] (when interactions given) + - task_type="render_trajectory": List[PIL.Image] + """ + visual_input = videos or video_path or images or image_path + + if task_type == "reconstruction": + if visual_input is None: + raise ValueError("images is required for task_type='reconstruction'.") + return self._run_inference(visual_input, **kwargs) + + elif task_type == "render_view": + if interactions is not None: + n_move = kwargs.get("frames_per_interaction", 30) + n_hold = kwargs.get("hold_frames", 10) + frames = [] + if self._current_camera is None and self._cached_result is not None: + self._current_camera = np.array( + self._cached_result.camera_params[0]["camera_to_world"] + ) + for sig in interactions: + hold_img = self.render_view(camera_to_world=self._current_camera) + for _ in range(n_hold): + frames.append(hold_img) + delta = self.operator.process_interaction_single(sig) + sub_delta = [d / n_move for d in delta] + for _ in range(n_move): + self._current_camera = _apply_camera_delta( + self._current_camera, sub_delta + ) + frames.append(self.render_view( + camera_to_world=self._current_camera + )) + hold_img = self.render_view(camera_to_world=self._current_camera) + for _ in range(n_hold): + frames.append(hold_img) + return frames + return self.render_view(camera_view=camera_view, **kwargs) + + elif task_type == "render_trajectory": + return self._render_trajectory(**kwargs) + + else: + raise ValueError( + f"Unknown task_type: {task_type}. " + "Choose 'reconstruction', 'render_view', or 'render_trajectory'." + ) + + def stream( + self, + interaction_signal: Union[str, List[str]], + result: Optional[LoGeRResult] = None, + **kwargs, + ) -> Image.Image: + """Interactive rendering: apply navigation interaction to current camera and render. + + This does NOT run model inference. It uses the cached point cloud from a + previous __call__() and applies the interaction delta to move the camera. + """ + res = result or self._cached_result + if res is None: + raise RuntimeError("No result available. Run reconstruction first via __call__().") + + if self._current_camera is None: + if res.camera_params: + self._current_camera = np.array(res.camera_params[0]["camera_to_world"]) + else: + raise RuntimeError("No camera parameters available.") + + if isinstance(interaction_signal, str): + interaction_signal = [interaction_signal] + + self.operator.get_interaction(interaction_signal) + delta = self.operator.process_interaction() + + self._current_camera = _apply_camera_delta(self._current_camera, delta) + + return self.render_view(result=res, camera_to_world=self._current_camera) diff --git a/src/openworldlib/representations/point_clouds_generation/pi3/loger/__init__.py b/src/openworldlib/representations/point_clouds_generation/pi3/loger/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/openworldlib/representations/point_clouds_generation/pi3/loger/layers/attention.py b/src/openworldlib/representations/point_clouds_generation/pi3/loger/layers/attention.py new file mode 100644 index 00000000..73f0786f --- /dev/null +++ b/src/openworldlib/representations/point_clouds_generation/pi3/loger/layers/attention.py @@ -0,0 +1,291 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn +import torch + +from torch.nn.functional import scaled_dot_product_attention +from torch.nn.attention import SDPBackend + +from ...pi3.layers.attention import ( + Attention, MemEffAttention, FlashAttention, CrossAttentionRope, MemEffCrossAttentionRope, AttentionRope, get_attn_score, PRopeFlashAttention, FlashCrossAttentionRope +) + +try: + from torch.nn.attention.flex_attention import flex_attention, create_block_mask + FLEX_ATTENTION_AVAILABLE = True +except ImportError: + FLEX_ATTENTION_AVAILABLE = False + flex_attention = None + create_block_mask = None + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +# Cache for block masks to avoid recreation +_BLOCK_MASK_CACHE = {} + + +def get_causal_block_mask(P, B, H, M, N, device="cuda", _compile=True): + """ + Get causal block mask with efficient caching based on logical parameters. + + Args: + P: tokens per frame (image) + B: batch size (not used in cache key since mask can be reused across batch sizes) + H: number of heads + M: query sequence length (num_frames * P) + N: key sequence length (num_frames * P) + device: target device + _compile: whether to compile + + Returns: + Block mask where tokens within the same image can see each other, + but tokens from different images can only see previous images. + """ + if not FLEX_ATTENTION_AVAILABLE: + return None + + # Create cache key based on logical parameters + device_idx = device.index if hasattr(device, 'index') else 0 + cache_key = (P, H, M, N, device_idx, _compile) + + if cache_key in _BLOCK_MASK_CACHE: + cached_mask = _BLOCK_MASK_CACHE[cache_key] + return cached_mask + + # Create the score function + # Tokens within the same frame can see each other + # Tokens from frame i can see all tokens from frames 0 to i + def causal_mask(b, h, q_idx, kv_idx): + q_frame = q_idx // P + kv_frame = kv_idx // P + return q_frame >= kv_frame + + # Create new block mask + block_mask = create_block_mask(causal_mask, B, H, M, N, device=device, _compile=_compile) + + # Cache it + _BLOCK_MASK_CACHE[cache_key] = block_mask + + return block_mask + +class MemEffAttentionRope(AttentionRope): + def forward(self, x: Tensor, attn_bias=None, xpos=None, attn_mask=None) -> Tensor: + # If attn_mask is provided and flex_attention is available, use flex_attention + if attn_mask is not None and FLEX_ATTENTION_AVAILABLE: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3) + q, k, v = [qkv[:,:,i] for i in range(3)] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + # Ensure all tensors have the same dtype + target_dtype = v.dtype + if q.dtype != target_dtype: + q = q.to(target_dtype) + if k.dtype != target_dtype: + k = k.to(target_dtype) + + x = flex_attention( + q, k, v, + block_mask=attn_mask, + scale=None, + enable_gqa=False, + return_lse=False + ) + x = x.transpose(1, 2).reshape([B, N, C]) + x = self.proj(x) + x = self.proj_drop(x) + return x + + # Otherwise use xformers memory_efficient_attention + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x, attn_bias=attn_bias, xpos=xpos, attn_mask=attn_mask) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + qkv = qkv.transpose(1, 3) + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:,:,i] for i in range(3)] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1).reshape(frame_num, 261, frame_num, 261).mean(dim=[1, 3]).sum(1) # for frame attention matrix + # global_valid_id = torch.where(score_matrix > 0) + # score_matrix = (q.permute(0, 2, 1, 3) * self.scale @ k.permute(0, 2, 1, 3).transpose(-2, -1)).sum(dim=1) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class FlashAttentionRope(AttentionRope): + def compute_kv(self, x: Tensor, xpos=None) -> tuple[Tensor, Tensor]: + """Compute K, V for caching. Returns (K, V) after norm and RoPE.""" + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3) + q, k, v = [qkv[:,:,i] for i in range(3)] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + k = self.rope(k, xpos) + + return k, v + + def forward_with_kv_cache( + self, + x: Tensor, + k_cache: Tensor, + v_cache: Tensor, + xpos=None, + xpos_cache=None, + attn_mask=None + ) -> Tensor: + """Forward with pre-computed KV cache for history tokens. + + Args: + x: Current tokens [B, N_curr, C] + k_cache: Cached K from history [B, num_heads, N_hist, head_dim] + v_cache: Cached V from history [B, num_heads, N_hist, head_dim] + xpos: Position info for current tokens + xpos_cache: Position info for cached tokens (unused, positions already applied) + attn_mask: Optional attention mask + + Returns: + Output for current tokens only [B, N_curr, C] + """ + B, N_curr, C = x.shape + + # Compute Q, K, V for current tokens + qkv = self.qkv(x).reshape(B, N_curr, 3, self.num_heads, C // self.num_heads).transpose(1, 3) + q, k, v = [qkv[:,:,i] for i in range(3)] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + # Concatenate cached KV with current KV + # k_cache, v_cache: [B, num_heads, N_hist, head_dim] + # k, v: [B, num_heads, N_curr, head_dim] + k_full = torch.cat([k_cache, k], dim=2) + v_full = torch.cat([v_cache, v], dim=2) + + # Compute attention + is_float_mask = (attn_mask is not None and torch.is_floating_point(attn_mask)) + + if attn_mask is not None and FLEX_ATTENTION_AVAILABLE and not is_float_mask: + target_dtype = v_full.dtype + if q.dtype != target_dtype: + q = q.to(target_dtype) + if k_full.dtype != target_dtype: + k_full = k_full.to(target_dtype) + + x = flex_attention( + q, k_full, v_full, + block_mask=attn_mask, + scale=None, + enable_gqa=False, + return_lse=False + ) + else: + if q.dtype == torch.bfloat16 and not is_float_mask: + with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): + x = scaled_dot_product_attention(q, k_full, v_full) + else: + with nn.attention.sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]): + x = scaled_dot_product_attention(q, k_full, v_full, attn_mask=attn_mask) + + x = x.transpose(1, 2).reshape([B, N_curr, C]) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def forward(self, x: Tensor, attn_bias=None, xpos=None, attn_mask=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3) + + # q, k, v = unbind(qkv, 2) + q, k, v = [qkv[:,:,i] for i in range(3)] + q, k = self.q_norm(q).to(v.dtype), self.k_norm(k).to(v.dtype) + + if self.rope is not None: + q = self.rope(q, xpos) + k = self.rope(k, xpos) + + # If attn_mask (block_mask) is provided and flex_attention is available, use it + # If attn_mask (block_mask) is provided and flex_attention is available, use it + # [MODIFIED] Check if attn_mask is a float tensor (bias). If so, skip flex_attention + # because flex_attention typically expects a BlockMask or boolean mask. + is_float_mask = (attn_mask is not None and torch.is_floating_point(attn_mask)) + + if attn_mask is not None and FLEX_ATTENTION_AVAILABLE and not is_float_mask: + # Ensure all tensors have the same dtype for flex_attention + target_dtype = v.dtype + if q.dtype != target_dtype: + q = q.to(target_dtype) + if k.dtype != target_dtype: + k = k.to(target_dtype) + + x = flex_attention( + q, k, v, + block_mask=attn_mask, + scale=None, # flex_attention applies 1/sqrt(d) automatically + enable_gqa=False, + return_lse=False + ) + else: + # Use standard scaled_dot_product_attention + if q.dtype == torch.bfloat16 and not is_float_mask: + with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION): + x = scaled_dot_product_attention(q, k, v) + else: + # Fallback to MATH/EFFICIENT if using float mask or other dtypes + with nn.attention.sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]): + x = scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + + x = x.transpose(1, 2).reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/src/openworldlib/representations/point_clouds_generation/pi3/loger/layers/block.py b/src/openworldlib/representations/point_clouds_generation/pi3/loger/layers/block.py new file mode 100644 index 00000000..1983e53e --- /dev/null +++ b/src/openworldlib/representations/point_clouds_generation/pi3/loger/layers/block.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention, CrossAttentionRope, MemEffCrossAttentionRope, FlashAttentionRope +from ......base_models.perception_core.general_perception.dinov2.layers.drop_path import DropPath +from ......base_models.perception_core.general_perception.dinov2.layers.layer_scale import LayerScale +from ......base_models.perception_core.general_perception.dinov2.layers.mlp import Mlp +from ...pi3.layers.block import ( + Block, drop_add_residual_stochastic_depth, get_branges_scales, add_residual, get_attn_bias_and_cat, drop_add_residual_stochastic_depth_list, NestedTensorBlock, CrossBlockRope, PoseInjectBlock, CrossOnlyBlockRope +) + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Block)") + else: + # warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Block)") + + +attn_bias_cache: Dict[Tuple, Any] = {} + +class BlockRope(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool=False, + rope=None + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + rope=rope + ) + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def compute_kv_cache(self, x: Tensor, xpos=None) -> tuple[Tensor, Tensor]: + """Compute K, V for caching from input x. + + Args: + x: Input tensor [B, N, C] + xpos: Position info for RoPE + + Returns: + (k, v): Cached K and V tensors, each [B, num_heads, N, head_dim] + """ + x_normed = self.norm1(x) + return self.attn.compute_kv(x_normed, xpos=xpos) + + def forward_with_kv_cache( + self, + x: Tensor, + k_cache: Tensor, + v_cache: Tensor, + xpos=None, + attn_mask=None + ) -> Tensor: + """Forward with pre-computed KV cache for history tokens. + + Args: + x: Current tokens [B, N_curr, C] + k_cache: Cached K from history [B, num_heads, N_hist, head_dim] + v_cache: Cached V from history [B, num_heads, N_hist, head_dim] + xpos: Position info for current tokens + attn_mask: Optional attention mask + + Returns: + Output for current tokens [B, N_curr, C] + """ + # Attention with KV cache + x_normed = self.norm1(x) + attn_out = self.attn.forward_with_kv_cache( + x_normed, k_cache, v_cache, xpos=xpos, attn_mask=attn_mask + ) + x = x + self.ls1(attn_out) + + # MLP (only on current tokens) + x = x + self.ls2(self.mlp(self.norm2(x))) + return x + + def forward(self, x: Tensor, xpos=None, attn_mask=None) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x), xpos=xpos, attn_mask=attn_mask)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x diff --git a/src/openworldlib/representations/point_clouds_generation/pi3/loger/layers/camera_head.py b/src/openworldlib/representations/point_clouds_generation/pi3/loger/layers/camera_head.py new file mode 100644 index 00000000..1add654e --- /dev/null +++ b/src/openworldlib/representations/point_clouds_generation/pi3/loger/layers/camera_head.py @@ -0,0 +1,159 @@ +import torch +import torch.nn as nn +from copy import deepcopy +import torch.nn.functional as F + +from ...pi3.layers.camera_head import ResConvBlock + +# code adapted from 'https://github.com/nianticlabs/marepo/blob/9a45e2bb07e5bb8cb997620088d352b439b13e0e/transformer/transformer.py#L172' + +class CameraHead(nn.Module): + def __init__(self, dim=512, output_quat=False): + super().__init__() + output_dim = dim + self.output_quat = output_quat + self.res_conv = nn.ModuleList([deepcopy(ResConvBlock(output_dim, output_dim)) + for _ in range(2)]) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.more_mlps = nn.Sequential( + nn.Linear(output_dim,output_dim), + nn.ReLU(), + nn.Linear(output_dim,output_dim), + nn.ReLU() + ) + self.fc_t = nn.Linear(output_dim, 3) + if self.output_quat: + self.fc_rot_qvec = nn.Linear(output_dim, 4) + else: + self.fc_rot = nn.Linear(output_dim, 9) + + def forward(self, feat, patch_h, patch_w): + BN, hw, c = feat.shape + + for i in range(2): + feat = self.res_conv[i](feat) + + # feat = self.avgpool(feat) + feat = self.avgpool(feat.permute(0, 2, 1).reshape(BN, -1, patch_h, patch_w).contiguous()) ########## + feat = feat.view(feat.size(0), -1) + + feat = self.more_mlps(feat) # [B, D_] + with torch.amp.autocast(device_type='cuda', enabled=False): + out_t = self.fc_t(feat.float()) # [B,3] + if self.output_quat: + out_r = self.fc_rot_qvec(feat.float()) # [B,4] + pose = self.convert_quat_to_4x4(BN, out_r, out_t, feat.device) + return pose, out_r + else: + out_r = self.fc_rot(feat.float()) # [B,9] or [B,4] + pose = self.convert_pose_to_4x4(BN, out_r, out_t, feat.device) + return pose + + def convert_quat_to_4x4(self, B, q, t, device): + # q: [B, 4] (w, x, y, z) + # t: [B, 3] + + q = torch.nn.functional.normalize(q, dim=-1) + w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3] + + # Rotation matrix elements + xx = x * x + yy = y * y + zz = z * z + xy = x * y + xz = x * z + yz = y * z + xw = x * w + yw = y * w + zw = z * w + + row0 = torch.stack([1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw)], dim=-1) + row1 = torch.stack([2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw)], dim=-1) + row2 = torch.stack([2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy)], dim=-1) + + R = torch.stack([row0, row1, row2], dim=1) # [B, 3, 3] + + pose = torch.zeros((B, 4, 4), device=device) + pose[:, :3, :3] = R + pose[:, :3, 3] = t + pose[:, 3, 3] = 1.0 + + return pose + + def convert_pose_to_4x4(self, B, out_r, out_t, device): + out_r = self.svd_orthogonalize(out_r) # [N,3,3] + pose = torch.zeros((B, 4, 4), device=device) + pose[:, :3, :3] = out_r + pose[:, :3, 3] = out_t + pose[:, 3, 3] = 1. + return pose + + def svd_orthogonalize_old(self, m): + """Convert 9D representation to SO(3) using SVD orthogonalization. + + Args: + m: [BATCH, 3, 3] 3x3 matrices. + + Returns: + [BATCH, 3, 3] SO(3) rotation matrices. + """ + if m.dim() < 3: + m = m.reshape((-1, 3, 3)) + m_transpose = torch.transpose(torch.nn.functional.normalize(m, p=2, dim=-1), dim0=-1, dim1=-2) + u, s, v = torch.svd(m_transpose) + det = torch.det(torch.matmul(v, u.transpose(-2, -1))) + # Check orientation reflection. + r = torch.matmul( + torch.cat([v[:, :, :-1], v[:, :, -1:] * det.view(-1, 1, 1)], dim=2), + u.transpose(-2, -1) + ) + return r + + def svd_orthogonalize(self, m): + """ + Convert 9D representation to SO(3) using SVD orthogonalization. + This is a more stable implementation using torch.linalg.svd. + """ + if m.dim() < 3: + m = m.reshape((-1, 3, 3)) + + B = m.shape[0] + + # 1. 和原来一样: 归一化 m 的每一行,然后转置 + # m_transpose 的列向量是单位向量 + m_norm_rows = torch.nn.functional.normalize(m, p=2, dim=-1) + m_transpose = m_norm_rows.transpose(-1, -2) + + # 2. 使用 torch.linalg.svd 替换 torch.svd + # A = U S Vh (其中 Vh = V^T) + try: + u, s, vh = torch.linalg.svd(m_transpose) + except torch.linalg.LinAlgError as e: + # SVD 失败的罕见情况 (例如,如果输入是全零或NaN) + print(f"SVD failed: {e}. Returning identity.") + # 返回一个 batch 的单位矩阵 + return torch.eye(3, device=m.device, dtype=m.dtype).unsqueeze(0).expand(B, 3, 3) + + # 3. 计算 R = U @ Vh (这是正交矩阵,但可能 det(R) = -1) + R_ortho = u @ vh + + # 4. 计算行列式 det(R) + det = torch.det(R_ortho) + + # 5. 创建修正矩阵 D = diag(1, 1, det(R)) + # 这会处理反射(reflection)情况 (det = -1) + # 我们需要为 batch 中的每个元素单独创建 D + D_vec = torch.stack([torch.ones_like(det), torch.ones_like(det), det], dim=-1) + D = torch.diag_embed(D_vec) + + # 6. 计算 R_so3 = U @ D @ Vh + # 这是最终的 SO(3) 旋转矩阵 R + R = u @ D @ vh + + # 7. 你的原始实现返回的是 R 的转置 (R^T) + # R^T = (U @ D @ Vh)^T = Vh.T @ D.T @ U.T = V @ D @ U^T + # 为了作为你原始代码的“直接替换”,我们返回 R^T + R_T = R.transpose(-1, -2) + + return R_T + \ No newline at end of file diff --git a/src/openworldlib/representations/point_clouds_generation/pi3/loger/pi3.py b/src/openworldlib/representations/point_clouds_generation/pi3/loger/pi3.py new file mode 100644 index 00000000..f70e58f5 --- /dev/null +++ b/src/openworldlib/representations/point_clouds_generation/pi3/loger/pi3.py @@ -0,0 +1,1266 @@ +import torch +import torch.nn as nn +from functools import partial +from copy import deepcopy +from typing import Optional, Union, List + +from .....base_models.perception_core.general_perception.dinov2.layers import Mlp +from .utils.geometry import homogenize_points, robust_scale_estimation +from ..pi3.layers.pos_embed import RoPE2D, PositionGetter +from .layers.block import BlockRope +from .layers.attention import FlashAttentionRope +from ..pi3.layers.transformer_head import TransformerDecoder, LinearPts3d, ContextOnlyTransformerDecoder +from .layers.camera_head import CameraHead +from ..pi3x.layers.conv_head import ConvHead +from .....base_models.perception_core.general_perception.dinov2.hub.backbones import dinov2_vitl14, dinov2_vitl14_reg +from huggingface_hub import PyTorchModelHubMixin +from .ttt import FastWeightGluMLPMultihead, TTTOperator + +class Pi3(nn.Module, PyTorchModelHubMixin): + def __init__( + self, + pos_type='rope100', + decoder_size='large', + ttt_insert_after: Union[int, List[int]] = None, + ttt_head_dim: int = 512, + ttt_inter_multi: int = 2, + num_muon_update_steps: int = 5, + use_momentum: bool = False, + ttt_update_steps: int = 1, + conf: bool = True, + attn_insert_after: Union[int, List[int], None] = None, + ttt_pre_norm: bool = False, + pi3x: bool = False, + pi3x_metric: bool = True, + ): + super().__init__() + + # ---------------------- + # Encoder + # ---------------------- + def _normalize_insert_positions(value: Union[int, List[int], None]) -> List[int]: + if isinstance(value, (int, float)): + return [int(value)] + if isinstance(value, (list, tuple)): + return [int(x) for x in value] + return [] + + parsed_ttt_insert_after = _normalize_insert_positions(ttt_insert_after) + parsed_attn_insert_after = _normalize_insert_positions(attn_insert_after) + + if not parsed_attn_insert_after: + parsed_attn_insert_after = parsed_ttt_insert_after.copy() + + self.ttt_insert_after = parsed_ttt_insert_after + self.attn_insert_after = parsed_attn_insert_after + self.detach_swa_history = False + self.initialize_swa_from_global = True + self.encoder = dinov2_vitl14_reg(pretrained=False) + self.patch_size = 14 + self.num_muon_update_steps = int(num_muon_update_steps) + self.num_pe_tokens = 3 + self.use_momentum = use_momentum + self.ttt_update_steps = int(ttt_update_steps) + self.use_conf = bool(conf) + self.ttt_pre_norm = ttt_pre_norm + self.pi3x = pi3x + self.pi3x_metric = pi3x_metric + del self.encoder.mask_token + + # ---------------------- + # Positonal Encoding + # ---------------------- + self.pos_type = pos_type if pos_type is not None else 'none' + self.rope=None + if self.pos_type.startswith('rope'): # eg rope100 + if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions") + freq = float(self.pos_type[len('rope'):]) + self.rope = RoPE2D(freq=freq) + self.position_getter = PositionGetter() + else: + raise NotImplementedError + + + # ---------------------- + # Decoder + # ---------------------- + enc_embed_dim = self.encoder.blocks[0].attn.qkv.in_features # 1024 + if decoder_size == 'small': + dec_embed_dim = 384 + dec_num_heads = 6 + mlp_ratio = 4 + dec_depth = 24 + elif decoder_size == 'base': + dec_embed_dim = 768 + dec_num_heads = 12 + mlp_ratio = 4 + dec_depth = 24 + elif decoder_size == 'large': + dec_embed_dim = 1024 + dec_num_heads = 16 + mlp_ratio = 4 + dec_depth = 36 + else: + raise NotImplementedError + self.decoder = nn.ModuleList([ + BlockRope( + dim=dec_embed_dim, + num_heads=dec_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + ffn_layer=Mlp, + init_values=0.01, + qk_norm=True, + attn_class=FlashAttentionRope, + rope=self.rope + ) for _ in range(dec_depth)]) + self.dec_embed_dim = dec_embed_dim + + # ---------------------- + # Register_token + # ---------------------- + num_register_tokens = 5 + self.patch_start_idx = num_register_tokens + self.register_token = nn.Parameter(torch.randn(1, 1, num_register_tokens, self.dec_embed_dim)) + nn.init.normal_(self.register_token, std=1e-6) + + for i in range(3): + pe_token = nn.Parameter(torch.randn(1, 1, 1, self.dec_embed_dim)) + nn.init.normal_(pe_token, std=1e-6) + self.register_parameter(f'pe_token_{i}', pe_token) + self.patch_start_idx += 1 + + # ---------------------- + # Local Points Decoder + # ---------------------- + self.point_decoder = TransformerDecoder( + in_dim=2*self.dec_embed_dim, + dec_embed_dim=1024, + dec_num_heads=16, + out_dim=1024, + rope=self.rope, + ) + if self.pi3x: + self.point_head = ConvHead( + num_features=4, + dim_in=1024, + projects=nn.Identity(), + dim_out=[2, 1], + dim_proj=1024, + dim_upsample=[256, 128, 64], + dim_times_res_block_hidden=2, + num_res_blocks=2, + res_block_norm='group_norm', + last_res_blocks=0, + last_conv_channels=32, + last_conv_size=1, + using_uv=True + ) + else: + self.point_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=3) + + # ---------------------- + # Conf Decoder + # ---------------------- + if self.use_conf: + self.conf_decoder = deepcopy(self.point_decoder) + self.conf_head = LinearPts3d(patch_size=14, dec_embed_dim=1024, output_dim=1) + else: + self.conf_decoder = None + self.conf_head = None + + # ---------------------- + # Metric Decoder + # ---------------------- + if self.pi3x and self.pi3x_metric: + self.metric_token = nn.Parameter(torch.randn(1, 1, 2*self.dec_embed_dim)) + self.metric_decoder = ContextOnlyTransformerDecoder( + in_dim=2*self.dec_embed_dim, + dec_embed_dim=512, + dec_num_heads=8, # 8 + out_dim=512, + rope=self.rope, + ) + self.metric_head = nn.Linear(512, 1) + nn.init.normal_(self.metric_token, std=1e-6) + else: + self.metric_token = None + self.metric_decoder = None + self.metric_head = None + + # ---------------------- + # Camera Pose Decoder + # ---------------------- + self.camera_decoder = TransformerDecoder( + in_dim=2*self.dec_embed_dim, + dec_embed_dim=1024, + dec_num_heads=16, # 8 + out_dim=512, + rope=self.rope, + use_checkpoint=False + ) + self.camera_head = CameraHead(dim=512, output_quat=False) + + # For ImageNet Normalize + image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + + self.register_buffer("image_mean", image_mean) + self.register_buffer("image_std", image_std) + + # ---------------------- + # TTT + # ---------------------- + + self.ttt_layers = None + self.ttt_gate_projs = None + self.ttt_op_order = None + + self.ttt_layers = nn.ModuleList([ + FastWeightGluMLPMultihead( + dim=dec_embed_dim, + head_dim=ttt_head_dim, + inter_multi=ttt_inter_multi, + bias=False, + base_lr=0.01, + muon_update_steps=self.num_muon_update_steps, + use_momentum=self.use_momentum, + ttt_update_steps=self.ttt_update_steps, + ttt_pre_norm=self.ttt_pre_norm, + ) + for _ in self.ttt_insert_after + ]) + self.ttt_gate_projs = nn.ModuleList([ + nn.Linear(dec_embed_dim, 1) + for _ in self.ttt_insert_after + ]) + + for gate_proj in self.ttt_gate_projs: + torch.nn.init.zeros_(gate_proj.weight) + if gate_proj.bias is not None: + torch.nn.init.zeros_(gate_proj.bias) + + self.ttt_op_order = [ + TTTOperator(start=0, end=None, update=False, apply=True), + TTTOperator(start=0, end=None, update=True, apply=False), + ] + + # ---------------------- + # Attention Adapters + # ---------------------- + self.swa_layers = nn.ModuleList([ + BlockRope( + dim=dec_embed_dim, + num_heads=dec_num_heads, + mlp_ratio=ttt_inter_multi, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + ffn_layer=Mlp, + init_values=0.01, + qk_norm=True, + attn_class=FlashAttentionRope, + rope=self.rope, + ) + for _ in self.attn_insert_after + ]) + self.swa_gate_projs = nn.ModuleList([ + nn.Linear(dec_embed_dim, 1) + for _ in self.attn_insert_after + ]) + + for gate_proj in self.swa_gate_projs: + torch.nn.init.zeros_(gate_proj.weight) + if gate_proj.bias is not None: + torch.nn.init.zeros_(gate_proj.bias) + + def _initialize_ttt_layers_from_global( + self, + layers: Optional[nn.ModuleList], + kind: str, + insert_after: Optional[List[int]] = None, + ) -> None: + """Helper for initializing adapter layers from decoder global attention weights.""" + if layers is None or len(layers) == 0: + print(f"{kind} initialization skipped: no target layers defined.") + return + + insert_positions = insert_after if insert_after is not None else self.ttt_insert_after + if not insert_positions: + print(f"{kind} initialization skipped: no insert positions defined.") + return + + num_decoder_layers = len(self.decoder) + print(f"Initializing {len(layers)} {kind} layers from decoder attention blocks") + print(f" Insert positions: {insert_positions}") + + + for layer_idx, insert_idx in enumerate(insert_positions): + decoder_idx = int(insert_idx) + if decoder_idx % 2 == 0: + decoder_idx += 1 # move to the subsequent global-attention layer + + if decoder_idx >= num_decoder_layers: + raise IndexError( + f"Decoder index {decoder_idx} out of range for {kind} initialization (decoder has {num_decoder_layers} layers)." + ) + + if decoder_idx % 2 == 0: + raise AssertionError( + f"Decoder index {decoder_idx} is not a global-attention layer after adjustment." + ) + + source_layer = self.decoder[decoder_idx] + target_layer = layers[layer_idx] + target_layer.load_state_dict(source_layer.state_dict()) + + print(f" Initialized {kind}_layer[{layer_idx}] from decoder[{decoder_idx}]") + + def _initialize_swa_from_global(self): + if self.swa_layers is None: + return + self._initialize_ttt_layers_from_global(self.swa_layers, "swa", self.attn_insert_after) + + def decode(self, hidden, N, H, W, ttt_dict: Optional[dict] = None, window_size: Optional[int] = None, overlap_size: Optional[int] = None, is_first_window: bool = False, + turn_off_ttt=False, turn_off_swa=False) -> torch.Tensor: + BN, hw, _ = hidden.shape + B = BN // N + + final_output = [] + + hidden = hidden.reshape(B*N, hw, -1) + + register_token = self.register_token.repeat(B, N, 1, 1).reshape(B*N, *self.register_token.shape[-2:]) + + pe_token_0 = getattr(self, 'pe_token_0') # (1, 1, 1, dim) + pe_token_1 = getattr(self, 'pe_token_1') # (1, 1, 1, dim) + pe_token_2 = getattr(self, 'pe_token_2') # (1, 1, 1, dim) + if overlap_size is None or window_size is None: + raise ValueError("overlap_size and window_size must be provided when num_pe_tokens > 0") + num_overlap_with_previous = min(overlap_size, N) + num_other_frames = min(max(window_size - 2 * overlap_size, 0), N - num_overlap_with_previous) + num_overlap_with_later = max(min(overlap_size, N, N - num_overlap_with_previous - num_other_frames), 0) + pe_tokens = torch.cat([ + pe_token_0.repeat(B, num_overlap_with_previous, 1, 1), + pe_token_1.repeat(B, num_other_frames, 1, 1), + pe_token_2.repeat(B, num_overlap_with_later, 1, 1) + ], dim=1).to(hidden.device).to(hidden.dtype).reshape(B*N, *pe_token_0.shape[-2:]) # (B*N, 1, dim) + hidden = torch.cat([pe_tokens, hidden], dim=1) + + # Concatenate special tokens with patch tokens + hidden = torch.cat([register_token, hidden], dim=1) + hw = hidden.shape[1] + + if self.pos_type.startswith('rope'): + pos = self.position_getter(B * N, H//self.patch_size, W//self.patch_size, hidden.device) + + if self.patch_start_idx > 0: + # do not use position embedding for special tokens (camera and register tokens) + # so set pos to 0 for the special tokens + pos = pos + torch.ones_like(pos) + pos_special = torch.zeros(B * N, self.patch_start_idx, 2).to(hidden.device).to(pos.dtype) + pos = torch.cat([pos_special, pos], dim=1) + + ttt_output_info = None + ttt_state = ttt_dict.get("ttt") if ttt_dict is not None else None + attn_state = ttt_dict.get("attn") if ttt_dict is not None else None + gate_scales: List[torch.Tensor] = [] + attn_gate_scales: List[torch.Tensor] = [] + for i in range(len(self.decoder)): + blk = self.decoder[i] + + if i % 2 == 0: + # frame attention + pos_reshaped = pos.reshape(B*N, hw, -1) if pos is not None else None + hidden = hidden.reshape(B*N, hw, -1) + hidden_for_block = hidden + pos_for_block = pos_reshaped + else: + # global attention + pos_reshaped = pos.reshape(B, N*hw, -1) if pos is not None else None + hidden = hidden.reshape(B, N*hw, -1) + hidden_for_block = hidden + pos_for_block = pos_reshaped + + # Save pre-block hidden for the fixed no-skip-residual path. + # With skip0 config removed, default behavior is skip0=False. + layer_skip0 = ( + len(self.ttt_insert_after) == 36 + and i in self.ttt_insert_after + and self.ttt_insert_after.index(i) % 2 == 0 + ) + + if i % 2 == 1 and not layer_skip0: + hidden_before_block = hidden_for_block + elif i % 2 == 0 and layer_skip0: + hidden_before_block = hidden_for_block + else: + hidden_before_block = hidden_for_block # dummy + + hidden = blk(hidden_for_block, xpos=pos_for_block) + + if ttt_state is not None and i in ttt_state.get("insert_after", []): + # Help static analyzers: ensure non-None + assert self.ttt_gate_projs is not None and self.ttt_layers is not None + insert_after_list = ttt_state.get("insert_after", []) + layer_idx = insert_after_list.index(i) + + x_for_residual = hidden.view(B, N, hw, -1) + tokens_post = x_for_residual + tokens_in = tokens_post + + gate_scale = torch.nn.functional.silu(self.ttt_gate_projs[layer_idx](tokens_in)) + # keep the gate scale to be always 0 + # if i <= 19: gate_scale = torch.zeros_like(gate_scale) # turn off ttt + if turn_off_ttt: gate_scale = torch.zeros_like(gate_scale) # turn off ttt + gate_scales.append(gate_scale) + info = { + "ttt_op_order": ttt_state.get("ttt_op_order", []), + "w0": ttt_state["w0"][layer_idx], + "w1": ttt_state["w1"][layer_idx], + "w2": ttt_state["w2"][layer_idx], + } + ttt_output, output = self.ttt_layers[layer_idx](tokens_in, info) + + update_term = ttt_output * gate_scale + + tokens_out = update_term + tokens_post + + hidden = tokens_out + + if ttt_output_info is None: + ttt_output_info = { + "w0": [None] * len(insert_after_list), + "w1": [None] * len(insert_after_list), + "w2": [None] * len(insert_after_list), + } + ttt_output_info["w0"][layer_idx] = output["w0"] + ttt_output_info["w1"][layer_idx] = output["w1"] + ttt_output_info["w2"][layer_idx] = output["w2"] + + # Sliding Window Attention (SWA) + if attn_state is not None and i in attn_state.get("insert_after", []): + assert self.swa_gate_projs is not None and self.swa_layers is not None + insert_after_list = attn_state.get("insert_after", []) + layer_idx = insert_after_list.index(i) + + patch_tokens_post_block = hidden + x_for_residual = patch_tokens_post_block.view(B, N, hw, -1) + x_in = x_for_residual + + history_list = attn_state.get("history", [None] * len(insert_after_list)) + history = history_list[layer_idx] + x_in_for_layer = x_in + + # Prepare position embeddings for current tokens + if pos is not None: + pos_current = pos.reshape(B, N, hw, -1).reshape(B, N * hw, -1) + else: + pos_current = None + + # Check if we have KV cache from history + use_kv_cache = ( + history is not None + and isinstance(history, dict) + and "k" in history + ) + + if use_kv_cache: + # Use KV cache path + k_cache = history["k"] # [B, num_heads, N_hist * hw, head_dim] + v_cache = history["v"] # [B, num_heads, N_hist * hw, head_dim] + # Forward with KV cache + x_curr_flat = x_in_for_layer.reshape(B, N * hw, -1) + swa_output_flat = self.swa_layers[layer_idx].forward_with_kv_cache( + x_curr_flat, k_cache, v_cache, + xpos=pos_current, + ) + swa_output = swa_output_flat.reshape(B, N, hw, -1) + else: + # Original path (no history or legacy format) + # Handle legacy history format (raw tensor instead of dict) + history_raw = history if history is not None and not isinstance(history, dict) else None + + if history_raw is not None: + x_with_history = torch.cat([history_raw, x_in_for_layer], dim=1) + else: + x_with_history = x_in_for_layer + + N_total = x_with_history.shape[1] + x_swa = x_with_history.reshape(B, N_total * hw, -1) + + if pos is not None: + pos_swa = pos.reshape(B, N, hw, -1) + if history_raw is not None: + N_hist = history_raw.shape[1] + pos_hist = pos_swa[:, :1].repeat(1, N_hist, 1, 1) + pos_swa = torch.cat([pos_hist, pos_swa], dim=1) + pos_swa = pos_swa.reshape(B, N_total * hw, -1) + else: + pos_swa = None + + swa_output_full = self.swa_layers[layer_idx]( + x_swa, + xpos=pos_swa, + ) + swa_output_full = swa_output_full.reshape(B, N_total, hw, x_in.shape[-1]) + if history_raw is not None: + N_hist = history_raw.shape[1] + swa_output = swa_output_full[:, N_hist:, :, :] + else: + swa_output = swa_output_full + + gate_scale = torch.nn.functional.silu(self.swa_gate_projs[layer_idx](swa_output)) + if turn_off_swa: gate_scale = torch.zeros_like(gate_scale) + attn_gate_scales.append(gate_scale) + + update_term = swa_output * gate_scale + x_out_patch = update_term + x_for_residual + x_out_patch_flat = x_out_patch.reshape(B, N * hw, -1) + hidden = x_out_patch_flat.reshape(B * N, hw, -1) + + # Store KV cache for next window + # Compute KV for current x_in with history_pe (since it will be history next time) + if ttt_output_info is None: + ttt_output_info = {"history": [None] * len(insert_after_list)} + elif "history" not in ttt_output_info: + ttt_output_info["history"] = [None] * len(insert_after_list) + + x_for_cache = x_in + x_for_cache_flat = x_for_cache.reshape(B, N * hw, -1) + + # Position for cache: use first frame's position repeated (same as original logic) + if pos is not None: + pos_for_cache = pos.reshape(B, N, hw, -1)[:, :1].repeat(1, N, 1, 1).reshape(B, N * hw, -1) + else: + pos_for_cache = None + + k_new, v_new = self.swa_layers[layer_idx].compute_kv_cache(x_for_cache_flat, xpos=pos_for_cache) + + if getattr(self, "detach_swa_history", False): + k_new = k_new.detach() + v_new = v_new.detach() + + ttt_output_info["history"][layer_idx] = {"k": k_new, "v": v_new} + + if i+1 in [len(self.decoder)-1, len(self.decoder)]: + final_output.append(hidden.reshape(B*N, hw, -1)) + + avg_gate_scale = torch.tensor(0.0, device=hidden.device, dtype=torch.float32) + avg_attn_gate_scale: Optional[torch.Tensor] = None + if gate_scales: + all_gate_scales = torch.cat([g.flatten() for g in gate_scales]) + if all_gate_scales.numel() > 0: + avg_gate_scale = all_gate_scales.abs().mean() + if attn_gate_scales: + all_attn_gate_scales = torch.cat([g.flatten() for g in attn_gate_scales]) + if all_attn_gate_scales.numel() > 0: + avg_attn_gate_scale = all_attn_gate_scales.abs().mean() + + if len(final_output) < 2: + raise RuntimeError( + f"Decoder expected to collect two final outputs but got {len(final_output)}." + ) + + return ( + torch.cat([final_output[0], final_output[1]], dim=-1), + (pos.reshape(B*N, hw, -1) if pos is not None else None), + ttt_output_info, + avg_gate_scale, + avg_attn_gate_scale, + gate_scales, + ) + + def forward(self, imgs, *args, **kwargs): + # Windowing controls (optional) + window_size = kwargs.pop('window_size', -1) + overlap_size = kwargs.pop('overlap_size', 1) + num_iterations = kwargs.pop('num_iterations', 1) + no_detach = kwargs.pop('no_detach', False) + sim3 = kwargs.pop('sim3', False) + se3 = kwargs.pop('se3', False) + reset_every = kwargs.pop('reset_every', 0) # reset TTT / adapter state every N windows (0 disables) + turn_off_ttt = kwargs.pop('turn_off_ttt', False) + turn_off_swa = kwargs.pop('turn_off_swa', False) + sim3_scale_mode = kwargs.pop('sim3_scale_mode', 'median') + + if sim3 and se3: + raise ValueError("'sim3' and 'se3' alignments are mutually exclusive; enable only one.") + + # Ensure at least one decode iteration so that 'hidden' is always defined + try: + num_iterations = int(num_iterations) + except Exception: + num_iterations = 1 + if num_iterations < 1: + num_iterations = 1 + try: + reset_every = int(reset_every) + except Exception: + reset_every = 0 + if reset_every < 0: + reset_every = 0 + + # Ensure batch dimension + if imgs.dim() == 4: + imgs = imgs.unsqueeze(0) + + # Normalize + # imgs = (imgs - self.image_mean) / self.image_std + + B, N, C, H, W = imgs.shape + patch_h, patch_w = H // 14, W // 14 + + # --- Unified Windowed Inference --- + if window_size <= 0 or window_size >= N: + windows = [(0, N)] + eff_overlap = 0 + eff_window_size = N + else: + windows = [] + step = max(window_size - overlap_size, 1) + for start_idx in range(0, N, step): + end_idx = min(start_idx + window_size, N) + if end_idx - start_idx >= overlap_size or (end_idx == N and start_idx < N): + windows.append((start_idx, end_idx)) + if end_idx == N: + break + eff_overlap = overlap_size + eff_window_size = window_size + + # Cache the effective window and overlap sizes for downstream merging utilities + self._last_window_size = eff_window_size + self._last_overlap_size = eff_overlap + + # Prepare TTT states across windows + if self.ttt_layers is not None: + w0 = [None] * len(self.ttt_insert_after) + w1 = [None] * len(self.ttt_insert_after) + w2 = [None] * len(self.ttt_insert_after) + else: + w0 = w1 = w2 = None + + # Prepare SWA history states across windows + swa_history = [None] * len(self.attn_insert_after) if self.swa_layers is not None else None + + def reset_adaptive_states(): + """Reset fast-weight TTT states only; SWA history is preserved across resets.""" + nonlocal w0, w1, w2 + if self.ttt_layers is not None: + w0 = [None] * len(self.ttt_insert_after) + w1 = [None] * len(self.ttt_insert_after) + w2 = [None] * len(self.ttt_insert_after) + + all_predictions = [] + all_gate_scales: List[torch.Tensor] = [] + all_attn_gate_scales: List[torch.Tensor] = [] + + windows_iter = windows + for window_idx, (start_idx, end_idx) in enumerate(windows_iter): + if reset_every > 0 and window_idx > 0 and window_idx % reset_every == 0: + reset_adaptive_states() + imgs_w = imgs[:, start_idx:end_idx] # (B, Nw, C, H, W) + imgs_w = imgs_w.to(self.image_mean.device) + imgs_w = (imgs_w - self.image_mean) / self.image_std + Nw = imgs_w.shape[1] + + # Initialize to satisfy static analyzers; will be set inside decode loop + hidden = None # type: ignore[assignment] + pos = None # type: ignore[assignment] + + for _ in range(num_iterations): + if self.ttt_layers is not None and w0 is None: + w0 = [None] * len(self.ttt_insert_after) + w1 = [None] * len(self.ttt_insert_after) + w2 = [None] * len(self.ttt_insert_after) + + if self.swa_layers is not None and swa_history is None: + swa_history = [None] * len(self.attn_insert_after) + + imgs_flat = imgs_w.reshape(B * Nw, C, H, W) + hidden_input = self.encoder(imgs_flat, is_training=True) + if isinstance(hidden_input, dict): + hidden_input = hidden_input["x_norm_patchtokens"] + + # Prepare adapter control dictionaries for decode + ttt_state = None + attn_state = None + + if self.ttt_layers is not None: + ttt_state = { + "ttt_op_order": self.ttt_op_order if self.ttt_op_order is not None else [], + "insert_after": self.ttt_insert_after, + "w0": w0, + "w1": w1, + "w2": w2, + } + + if self.swa_layers is not None: + attn_state = { + "insert_after": self.attn_insert_after, + "history": swa_history, + } + + if ttt_state is None and attn_state is None: + ttt_dict = None + else: + ttt_dict = { + "ttt": ttt_state, + "attn": attn_state, + } + hidden, pos, ttt_output_info, decode_avg_gate_scale, decode_avg_attn_gate_scale, _decode_gate_scales = self.decode( + hidden_input, Nw, H, W, + ttt_dict=ttt_dict, + window_size=window_size, + overlap_size=overlap_size, + is_first_window=(start_idx == 0), + turn_off_ttt=turn_off_ttt, + turn_off_swa=turn_off_swa, + ) + if decode_avg_gate_scale is not None: + all_gate_scales.append(decode_avg_gate_scale.detach().cpu()) + if decode_avg_attn_gate_scale is not None: + all_attn_gate_scales.append(decode_avg_attn_gate_scale.detach().cpu()) + + # TODO: get the updated state from the ttt layer + if self.ttt_layers is not None and ttt_output_info is not None: + w0, w1, w2 = ttt_output_info["w0"], ttt_output_info["w1"], ttt_output_info["w2"] + + # TODO: get the updated history from the swa layer + if ttt_output_info is not None: + swa_history = ttt_output_info.get("history", swa_history) + + # If for some reason decoding didn't produce hidden (e.g., empty window), skip this window + if hidden is None: + continue + + point_hidden = self.point_decoder(hidden, xpos=pos) + if self.use_conf and self.conf_decoder is not None: + conf_hidden = self.conf_decoder(hidden, xpos=pos) + else: + conf_hidden = None + + if self.pi3x and self.pi3x_metric: + hw = hidden.shape[1] + pos_hw = pos.reshape(B, Nw*hw, -1) + metric_hidden = self.metric_decoder(self.metric_token.repeat(B, 1, 1), hidden.reshape(B, Nw*hw, -1), xpos=pos_hw[:, 0:1], ypos=pos_hw) + else: + metric_hidden = None + + camera_hidden = self.camera_decoder(hidden, xpos=pos) + + global_camera_hidden = camera_hidden + + with torch.autocast(device_type='cuda', enabled=False): + # local points + point_hidden = point_hidden.float() + if self.pi3x: + xy, z = self.point_head(point_hidden[:, self.patch_start_idx:], patch_h=patch_h, patch_w=patch_w) + xy = xy.permute(0, 2, 3, 1).reshape(B, Nw, H, W, -1) + z = z.permute(0, 2, 3, 1).reshape(B, Nw, H, W, -1) + z = torch.exp(z.clamp(max=15.0)) + local_points = torch.cat([xy * z, z], dim=-1) + else: + ret = self.point_head([point_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, Nw, H, W, -1) + xy, z = ret.split([2, 1], dim=-1) + z = torch.exp(z) + local_points = torch.cat([xy * z, z], dim=-1) + + # confidence + if conf_hidden is not None and self.conf_head is not None: + conf_hidden = conf_hidden.float() + conf = self.conf_head([conf_hidden[:, self.patch_start_idx:]], (H, W)).reshape(B, Nw, H, W, -1) + else: + conf = None + + # camera + global_camera_hidden = global_camera_hidden.float() + camera_poses = self.camera_head(global_camera_hidden[:, self.patch_start_idx:], patch_h, patch_w).reshape(B, Nw, 4, 4) + camera_qvec = None + local_camera_poses = None + local_camera_qvec = None + + # metric + if self.pi3x and self.pi3x_metric and metric_hidden is not None: + metric = self.metric_head(metric_hidden.float()).reshape(B).exp() + + # apply metric to points and camera poses + # points = torch.einsum('bnij, bnhwj -> bnhwi', camera_poses, homogenize_points(local_points))[..., :3] * metric.view(B, 1, 1, 1, 1) + camera_poses[..., :3, 3] = camera_poses[..., :3, 3] * metric.view(B, 1, 1) + local_points = local_points * metric.view(B, 1, 1, 1, 1) + if local_camera_poses is not None: + local_camera_poses[..., :3, 3] = local_camera_poses[..., :3, 3] * metric.view(B, 1, 1) + else: + metric = None + + + # unproject local points using camera poses + with torch.autocast(device_type='cuda', enabled=False): + points = torch.einsum('bnij, bnhwj -> bnhwi', camera_poses, homogenize_points(local_points))[..., :3] + + + def maybe_detach(t, no_detach=no_detach): + if t is None: + return None + return t if self.training or no_detach else t.detach().cpu() + + pred_dict = dict( + points=maybe_detach(points, no_detach=no_detach), + local_points=maybe_detach(local_points, no_detach=no_detach), + conf=maybe_detach(conf, no_detach=no_detach), + camera_poses=maybe_detach(camera_poses, no_detach=no_detach), + local_camera_poses=maybe_detach(local_camera_poses, no_detach=no_detach), + camera_qvec=maybe_detach(camera_qvec, no_detach=no_detach), + local_camera_qvec=maybe_detach(local_camera_qvec, no_detach=no_detach), + metric=maybe_detach(metric, no_detach=no_detach), + ) + all_predictions.append(pred_dict) + + # Merge windowed predictions + # When reset is enabled but explicit Sim3/SE3 alignment is off, keep each reset block + # in a stable rigid frame by applying one estimated transform per block. + align_on_resets_without_explicit_pose = reset_every > 0 and not sim3 and not se3 + if sim3: + merged = self._merge_windowed_predictions_sim3( + all_predictions, + allow_scale=True, + scale_mode=sim3_scale_mode, + ) + elif se3 or align_on_resets_without_explicit_pose: + merged = self._merge_windowed_predictions_sim3( + all_predictions, + allow_scale=False, + reset_every=reset_every, + reuse_transform_within_reset_block=align_on_resets_without_explicit_pose, + ) + else: + merged = self._merge_windowed_predictions(all_predictions, eff_window_size, eff_overlap) + if all_gate_scales: + merged["avg_gate_scale"] = torch.stack(all_gate_scales).mean() + if all_attn_gate_scales: + merged["attn_gate_scale"] = torch.stack(all_attn_gate_scales).mean() + + return merged + + def _merge_windowed_predictions(self, all_predictions, window_size, overlap_size): + """ + Merge predictions from multiple windows by concatenating along the time dimension + while removing overlapping frames. + """ + if not all_predictions: + return {} + if len(all_predictions) == 1: + return all_predictions[0] + + merged_predictions = {} + keys = list(all_predictions[0].keys()) + sequence_keys = {"points", "local_points", "conf", "camera_poses", "local_camera_poses", "camera_qvec", "local_camera_qvec"} + for key in keys: + # Collect window tensors + window_tensors = [pred.get(key, None) for pred in all_predictions] + + # Skip if all windows have None for this key + if all(t is None for t in window_tensors): + continue + + # Only perform overlap-aware concatenation for known sequence-shaped tensors + if key in sequence_keys: + # Filter out None windows safely while preserving positions for slicing + result_parts = [] + + # First window: drop last overlap_size frames + first = window_tensors[0] + if first is not None: + if overlap_size > 0 and first.shape[1] > overlap_size: + result_parts.append(first[:, :-overlap_size]) + elif overlap_size > 0 and first.shape[1] <= overlap_size: + # If window shorter or equal to overlap, drop completely + pass + else: + result_parts.append(first) + + # Middle windows: drop last overlap_size frames + for tensor in window_tensors[1:-1]: + if tensor is None: + continue + if overlap_size > 0 and tensor.shape[1] > overlap_size: + result_parts.append(tensor[:, :-overlap_size]) + elif overlap_size > 0 and tensor.shape[1] <= overlap_size: + # If window shorter or equal to overlap, drop completely + continue + else: + result_parts.append(tensor) + + # Last window: keep all frames + last_tensor = window_tensors[-1] + if last_tensor is not None: + result_parts.append(last_tensor) + + if result_parts: + merged_predictions[key] = torch.cat(result_parts, dim=1) + else: + # Fallback: if everything was dropped due to tiny windows, keep last non-None + for t in reversed(window_tensors): + if t is not None: + merged_predictions[key] = t + break + else: + # Non-sequence keys: keep the last non-None + for t in reversed(window_tensors): + if t is not None: + merged_predictions[key] = t + break + + # Instead of computing overlap losses here, export overlap prev/next tensors for trainer-side chunk losses + if overlap_size > 0 and len(all_predictions) > 1: + prev_cam_chunks = [] + next_cam_chunks = [] + prev_pcd_chunks = [] + next_pcd_chunks = [] + next_conf_chunks = [] + + for i in range(len(all_predictions) - 1): + pred_a = all_predictions[i] + pred_b = all_predictions[i + 1] + + cam_a = pred_a.get("camera_poses", None) + cam_b = pred_b.get("camera_poses", None) + lpts_a = pred_a.get("local_points", None) + lpts_b = pred_b.get("local_points", None) + conf_a = pred_a.get("conf", None) + conf_b = pred_b.get("conf", None) + + # Only collect when both sides have enough frames for a full overlap window + if cam_a is not None and cam_b is not None and cam_a.shape[1] >= overlap_size and cam_b.shape[1] >= overlap_size: + S_a = cam_a.shape[1] + # Take last overlap_size from A and first overlap_size from B + prev_cam_chunks.append(cam_a[:, S_a - overlap_size: S_a]) # (B, O, 4, 4) + next_cam_chunks.append(cam_b[:, 0: overlap_size]) # (B, O, 4, 4) + + if lpts_a is not None and lpts_b is not None and lpts_a.shape[1] >= overlap_size and lpts_b.shape[1] >= overlap_size: + S_a = lpts_a.shape[1] + prev_pcd_chunks.append(lpts_a[:, S_a - overlap_size: S_a]) # (B, O, H, W, 3) + next_pcd_chunks.append(lpts_b[:, 0: overlap_size]) # (B, O, H, W, 3) + if conf_b is not None and conf_b.shape[1] >= overlap_size: + next_conf_chunks.append(conf_b[:, 0: overlap_size].squeeze(-1)) # (B, O, H, W) + + # Stack along a new chunk dimension if any collected + if prev_cam_chunks and next_cam_chunks: + merged_predictions["overlap_prev_cam"] = torch.stack(prev_cam_chunks, dim=1) # (B, K, O, 4, 4) + merged_predictions["overlap_next_cam"] = torch.stack(next_cam_chunks, dim=1) # (B, K, O, 4, 4) + if prev_pcd_chunks and next_pcd_chunks: + merged_predictions["overlap_prev_pcd"] = torch.stack(prev_pcd_chunks, dim=1) # (B, K, O, H, W, 3) + merged_predictions["overlap_next_pcd"] = torch.stack(next_pcd_chunks, dim=1) # (B, K, O, H, W, 3) + if next_conf_chunks: + merged_predictions["overlap_next_conf"] = torch.stack(next_conf_chunks, dim=1) # (B, K, O, H, W) + + return merged_predictions + + def _merge_windowed_predictions_sim3( + self, + all_predictions, + allow_scale: bool = True, + scale_mode: str = 'median', + reset_every: int = 0, + reuse_transform_within_reset_block: bool = False, + ): + """ + Merge windowed predictions by estimating relative poses between overlaps. + When ``allow_scale`` is True this performs Sim(3) alignment (scale+SE(3)); + when False it reduces to SE(3) alignment by keeping the scale fixed to 1. + If ``reuse_transform_within_reset_block`` is enabled with ``reset_every > 0``, + one transform is estimated at each reset boundary and reused for the rest of + that reset block. + """ + # print("allow_scale -----------------------------", allow_scale) + if not all_predictions: + return {} + if len(all_predictions) == 1: + return all_predictions[0] + + # Locate a reference tensor to determine batch/device/dtype information + sample_tensor = None + for pred in all_predictions: + for key in ("points", "camera_poses", "local_points", "conf"): + tensor = pred.get(key, None) + if tensor is not None: + sample_tensor = tensor + break + if sample_tensor is not None: + break + if sample_tensor is None: + raise ValueError("Sim3 merge requires at least one tensor prediction") + + device = sample_tensor.device + dtype = sample_tensor.dtype + batch_size = sample_tensor.shape[0] + + identity_rot = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).repeat(batch_size, 1, 1) + zero_trans = torch.zeros(batch_size, 3, device=device, dtype=dtype) + one_scale = torch.ones(batch_size, device=device, dtype=dtype) + + aligned_predictions: List[dict] = [] + sim3_scales: Optional[List[torch.Tensor]] = [] if allow_scale else None + sim3_poses: List[torch.Tensor] = [] + + window_size = getattr(self, "_last_window_size", -1) + overlap_size = getattr(self, "_last_overlap_size", 0) + + def _estimate_relative_sim3(prev_aligned: dict, curr_raw: dict, overlap: int, current_allow_scale: bool, forced_scale: Optional[torch.Tensor] = None): + if overlap <= 0: + return torch.ones_like(one_scale), identity_rot, zero_trans + + prev_cam = prev_aligned.get("camera_poses", None) + curr_cam = curr_raw.get("camera_poses", None) + if prev_cam is None or curr_cam is None or prev_cam.shape[1] == 0 or curr_cam.shape[1] == 0: + return torch.ones_like(one_scale), identity_rot, zero_trans + + prev_frames = prev_cam.shape[1] + prev_idx = max(prev_frames - overlap, 0) + + prev_pose = prev_cam[:, prev_idx] + curr_pose = curr_cam[:, 0] + + R_prev = prev_pose[:, :3, :3] + t_prev = prev_pose[:, :3, 3] + R_curr = curr_pose[:, :3, :3] + t_curr = curr_pose[:, :3, 3] + + relative_rot = torch.matmul(R_prev, R_curr.transpose(-1, -2)) + + relative_scale = torch.ones_like(one_scale) + if forced_scale is not None: + relative_scale = forced_scale + elif current_allow_scale: + prev_local_raw = prev_aligned.get("local_points", None) + if prev_local_raw is None: + prev_local_raw = prev_aligned.get("_local_points_raw", None) + curr_local_raw = curr_raw.get("local_points", None) + + if ( + prev_local_raw is not None + and curr_local_raw is not None + and prev_local_raw.shape[1] > prev_idx + and curr_local_raw.shape[1] > 0 + ): + if scale_mode in ['median_all', 'trimmed_mean_all']: + # Use all overlapping frames + actual_overlap = min(overlap, prev_local_raw.shape[1] - prev_idx, curr_local_raw.shape[1]) + if actual_overlap > 0: + prev_depth = prev_local_raw[:, prev_idx : prev_idx + actual_overlap, ..., 2] + curr_depth = curr_local_raw[:, :actual_overlap, ..., 2] + else: + # Fallback to single frame if overlap calculation fails (should not happen given checks above) + prev_depth = prev_local_raw[:, prev_idx, ..., 2] + curr_depth = curr_local_raw[:, 0, ..., 2] + else: + # Use only the first overlapping frame (standard behavior) + prev_depth = prev_local_raw[:, prev_idx, ..., 2] + curr_depth = curr_local_raw[:, 0, ..., 2] + + prev_depth_f32 = prev_depth.to(torch.float32) + curr_depth_f32 = curr_depth.to(torch.float32) + eps_depth = torch.finfo(torch.float32).eps + valid = ( + torch.isfinite(prev_depth_f32) + & torch.isfinite(curr_depth_f32) + & (curr_depth_f32.abs() > eps_depth) + ) + + prev_depth_flat = prev_depth_f32.reshape(batch_size, -1) + curr_depth_flat = curr_depth_f32.reshape(batch_size, -1) + valid_flat = valid.reshape(batch_size, -1) + + if scale_mode in ['median', 'median_all']: + scale_values = [] + for b in range(batch_size): + valid_idx = valid_flat[b] + if valid_idx.any(): + ratios = prev_depth_flat[b, valid_idx] / curr_depth_flat[b, valid_idx] + scale_values.append(ratios.median()) + else: + scale_values.append(torch.tensor(1.0, device=device, dtype=torch.float32)) + relative_scale = torch.stack(scale_values).to(dtype) + elif scale_mode in ['trimmed_mean', 'trimmed_mean_all']: + # Vectorized implementation for trimmed mean + # Mask invalid entries with NaN or filter before passing? + # robust_scale_estimation expects (B, N) + # Since N varies per batch due to validity, we might still need a loop or careful padding. + # However, valid_flat is (B, N_pixels). + + # To keep it simple and consistent with the median loop structure for now (which handles varying valid counts per batch): + scale_values = [] + for b in range(batch_size): + valid_idx = valid_flat[b] + if valid_idx.any(): + ratios = prev_depth_flat[b, valid_idx] / curr_depth_flat[b, valid_idx] + # ratios is 1D tensor of valid pixels + # We need to pass (1, N) to robust_scale_estimation to reuse it, or just use it directly if we modify it to handle 1D + # robust_scale_estimation expects (B, N). Let's reshape. + scale_val = robust_scale_estimation(ratios.unsqueeze(0), trim_ratio=0.25).squeeze(0) + scale_values.append(scale_val) + else: + scale_values.append(torch.tensor(1.0, device=device, dtype=torch.float32)) + relative_scale = torch.stack(scale_values).to(dtype) + elif scale_mode in ['sim3_avg1']: + scale_values = [] + for b in range(batch_size): + valid_idx = valid_flat[b] + if valid_idx.any(): + ratios = prev_depth_flat[b, valid_idx] / curr_depth_flat[b, valid_idx] + scale_values.append(ratios.median()) + else: + scale_values.append(torch.tensor(1.0, device=device, dtype=torch.float32)) + relative_scale = torch.stack(scale_values).to(dtype) + relative_scale = (relative_scale + 1.0) / 2.0 + else: + raise ValueError(f"Unknown scale_mode: {scale_mode}") + + relative_scale = torch.clamp(relative_scale, min=1e-3, max=1e3) + + rotated_curr_centers = torch.matmul(relative_rot, t_curr.unsqueeze(-1)).squeeze(-1) + relative_trans = t_prev - relative_scale.unsqueeze(-1) * rotated_curr_centers + + return relative_scale, relative_rot.to(dtype), relative_trans.to(dtype) + + block_scale: Optional[torch.Tensor] = None + block_rot: Optional[torch.Tensor] = None + block_trans: Optional[torch.Tensor] = None + + for window_idx, pred in enumerate(all_predictions): + if window_idx == 0: + current_scale = torch.ones_like(one_scale) + current_rot = identity_rot.clone() + current_trans = zero_trans.clone() + if reuse_transform_within_reset_block and reset_every > 0: + block_scale = current_scale.clone() + block_rot = current_rot.clone() + block_trans = current_trans.clone() + else: + prev_aligned = aligned_predictions[-1] + reuse_block_transform = ( + reuse_transform_within_reset_block + and reset_every > 0 + and window_idx % reset_every != 0 + and block_rot is not None + and block_trans is not None + ) + if reuse_block_transform: + current_rot = block_rot.clone() + current_trans = block_trans.clone() + if allow_scale and block_scale is not None: + current_scale = block_scale.clone() + else: + current_scale = torch.ones_like(one_scale) + else: + current_scale, current_rot, current_trans = _estimate_relative_sim3( + prev_aligned, pred, overlap_size, allow_scale + ) + if reuse_transform_within_reset_block and reset_every > 0: + block_scale = current_scale.clone() + block_rot = current_rot.clone() + block_trans = current_trans.clone() + + if allow_scale and sim3_scales is not None: + sim3_scales.append(current_scale.clone()) + # print(current_scale, 'current_scale-----------------') + pose_mat = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).repeat(batch_size, 1, 1) + pose_mat[:, :3, :3] = current_rot + pose_mat[:, :3, 3] = current_trans + sim3_poses.append(pose_mat) + + aligned_pred: dict = {} + + original_local_points = pred.get("local_points", None) + aligned_pred["_local_points_raw"] = original_local_points + + if original_local_points is not None: + if allow_scale: # Keep using global allow_scale for applying scale if we have it, or maybe we should track per-window scale application? + # Actually, current_scale will be 1.0 if current_allow_scale was False. + # So we can just always apply current_scale. + scale_factor = current_scale.view(batch_size, 1, 1, 1, 1) + aligned_local_points = original_local_points * scale_factor + else: + aligned_local_points = original_local_points + else: + aligned_local_points = None + aligned_pred["local_points"] = aligned_local_points + + def _transform_camera(cam_tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if cam_tensor is None: + return None + frames = cam_tensor.shape[1] + rot_local = cam_tensor[..., :3, :3] + trans_local = cam_tensor[..., :3, 3] + rot_global = torch.matmul( + current_rot.unsqueeze(1).expand(-1, frames, -1, -1), + rot_local + ) + rotated_trans = torch.matmul( + current_rot.unsqueeze(1).expand(-1, frames, -1, -1), + trans_local.unsqueeze(-1) + ).squeeze(-1) + if allow_scale: + rotated_trans = rotated_trans * current_scale.view(batch_size, 1, 1) + trans_global = rotated_trans + current_trans.unsqueeze(1) + cam_out = cam_tensor.clone() + cam_out[..., :3, :3] = rot_global + cam_out[..., :3, 3] = trans_global + return cam_out + + camera_global = _transform_camera(pred.get("camera_poses", None)) + aligned_pred["camera_poses"] = camera_global + + local_camera_global = _transform_camera(pred.get("local_camera_poses", None)) + aligned_pred["local_camera_poses"] = local_camera_global + + if camera_global is not None and aligned_local_points is not None: + aligned_points = torch.einsum( + 'bnij, bnhwj -> bnhwi', + camera_global, + homogenize_points(aligned_local_points) + )[..., :3] + else: + points = pred.get("points", None) + if points is not None: + rotated_points = torch.einsum('bij, bnhwj -> bnhwi', current_rot, points) + if allow_scale: + rotated_points = rotated_points * current_scale.view(batch_size, 1, 1, 1, 1) + aligned_points = rotated_points + current_trans.view(batch_size, 1, 1, 1, 3) + else: + aligned_points = None + aligned_pred["points"] = aligned_points + + aligned_pred["conf"] = pred.get("conf", None) + + for key, value in pred.items(): + if key in aligned_pred: + continue + aligned_pred[key] = value + + aligned_predictions.append(aligned_pred) + + aligned_predictions_clean = [] + for pred in aligned_predictions: + cleaned = pred.copy() + cleaned.pop("_local_points_raw", None) + aligned_predictions_clean.append(cleaned) + + merged = self._merge_windowed_predictions(aligned_predictions_clean, window_size, overlap_size) + + pose_key = "chunk_sim3_poses" if allow_scale else "chunk_se3_poses" + if allow_scale and sim3_scales: + merged["chunk_sim3_scales"] = torch.stack(sim3_scales, dim=1) + if sim3_poses: + merged[pose_key] = torch.stack(sim3_poses, dim=1) + merged["alignment_mode"] = "sim3" if allow_scale else "se3" + + return merged \ No newline at end of file diff --git a/src/openworldlib/representations/point_clouds_generation/pi3/loger/ttt.py b/src/openworldlib/representations/point_clouds_generation/pi3/loger/ttt.py new file mode 100644 index 00000000..22bf0965 --- /dev/null +++ b/src/openworldlib/representations/point_clouds_generation/pi3/loger/ttt.py @@ -0,0 +1,323 @@ +# LaCT + +import math +from einops import rearrange +import torch +import torch.nn as nn +from torch.nn import functional as F + +import collections + + +TTTOperator = collections.namedtuple("TTTOperator", ["start", "end", "update", "apply"]) + + +def inv_softplus(x): + y = x + math.log(-math.expm1(-x)) + return y + + +def silu_backprop(dy: torch.Tensor, x: torch.Tensor): + """ + Args: + dy: [b, d, l], gradient of the outer loss wrt the y + x: [b, d, l], input of the silu activation + outs: + dx: [b, d, l], gradient of the outer loss wrt the x + dx = dy * sigma * (1 + x * (1 - sigma)) + """ + sigma = torch.sigmoid(x) + dx = dy * sigma * (1 + x * (1 - sigma)) + return dx + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps): + """ + modified from https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py#L49 + Major change: G is [b, d, d] rather than [d, d] + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. + Args: + G: [b, d, d] + steps: int + Returns: + X: [b, d, d] + """ + # TODO: log the update loss + assert len(G.shape) == 3 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(1) > G.size(2): + X = X.transpose(1, 2) + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(1, 2), keepdim=True) + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.transpose(1, 2) + B = ( + b * A + c * A @ A + ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + X = a * X + B @ X + + if G.size(1) > G.size(2): + X = X.transpose(1, 2) + return X + + + +@torch.compile +# TODO: add a version that uses the torch.compile +def fast_weight_swish_glu_weight_norm_mini_batch_apply( + w0: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + lr0: torch.Tensor, + lr1: torch.Tensor, + lr2: torch.Tensor, + ttt_ua_order: list, + muon_update_steps: int = 0, + momentum: torch.Tensor | None = None, + ttt_update_steps: int = 1, +): + """ + Note: + Forward: + (silu(x @ w0) * (x @ w2)) @ w1 + + w0, w2: [b, d, dh] + w1: [b, dh, d] + q: [b, l, d] + k: [b, l, d] + v: [b, l, d] + lr0, lr1, lr2: [b, l, 1] + + """ + w0_norm = w0.detach().norm(dim=1, keepdim=True) + w1_norm = w1.detach().norm(dim=1, keepdim=True) + w2_norm = w2.detach().norm(dim=1, keepdim=True) + + if momentum is not None: + dw0_momentum = torch.zeros_like(w0) + dw1_momentum = torch.zeros_like(w1) + dw2_momentum = torch.zeros_like(w2) + + output = [] + + for start, end, update, apply in ttt_ua_order: + w0_now, w1_now, w2_now = w0, w1, w2 + + if update: + ki, vi = k[:, start:end, :], v[:, start:end, :] # bf16 [b, l, d] + + lr0i = lr0[:, start:end, :] # [b, l, d/1] fp32 + lr1i = lr1[:, start:end, :] # [b, l, d/1] fp32 + lr2i = lr2[:, start:end, :] # [b, l, d/1] fp32 + + gate_before_act = ki @ w0_now # b[b, l, dh] = [b, l, d] @ [b, d, dh] + hidden_before_mul = ki @ w2_now # b[b, l, dh] = [b, l, d] @ [b, d, dh] + hidden = F.silu(gate_before_act, inplace=False) * hidden_before_mul + + for _ in range(ttt_update_steps): + # Fixed objective: neg_dot_product (gradient ascent) + dhidden = vi @ w1_now.transpose(-1, -2) # [b, l, dh] = [b, l, d] @ [b, d, dh] + dhidden_before_mul = dhidden * F.silu(gate_before_act, inplace=False) + dgate = dhidden * hidden_before_mul + dgate_before_act = silu_backprop(dgate, gate_before_act) + + w1_grad = zeropower_via_newtonschulz5( + (hidden * lr1i).transpose(-1, -2) @ vi, muon_update_steps + ) + w0_grad = zeropower_via_newtonschulz5( + (ki * lr0i).transpose(-1, -2) @ dgate_before_act, muon_update_steps + ) + w2_grad = zeropower_via_newtonschulz5( + (ki * lr2i).transpose(-1, -2) @ dhidden_before_mul, muon_update_steps + ) + + if momentum is not None: + m_i = momentum[:, start:end, :].mean(dim=1, keepdim=True) + w0_grad = w0_grad + dw0_momentum * m_i + w1_grad = w1_grad + dw1_momentum * m_i + w2_grad = w2_grad + dw2_momentum * m_i + dw0_momentum = w0_grad + dw1_momentum = w1_grad + dw2_momentum = w2_grad + + # Gradient ascent: add gradients + w1_now = w1_now + w1_grad + w0_now = w0_now + w0_grad + w2_now = w2_now + w2_grad + + # do weight norm here + w0_now = w0_now / (w0_now.norm(dim=1, keepdim=True) + 1e-5) * w0_norm + w1_now = w1_now / (w1_now.norm(dim=1, keepdim=True) + 1e-5) * w1_norm + w2_now = w2_now / (w2_now.norm(dim=1, keepdim=True) + 1e-5) * w2_norm + + w0, w1, w2 = w0_now, w1_now, w2_now + + if apply: + # Only calculate the output in the last repeat. + qi = q[:, start:end, :] + oi = (F.silu(qi @ w0_now, inplace=True) * (qi @ w2_now)) @ w1_now + output.append(oi) + + output = torch.cat(output, dim=1) + + return output, w0, w1, w2 + + +class FastWeightGluMLPMultihead(nn.Module): + """ + On init of fast_weight: + + Let's start with the magnitude of the value. + value_proj is initialized with uniform distribution with range [-1.0/sqrt(d), 1.0/sqrt(d)] + x is layernormed. So during init, value is unit norm total (not per head, per head is 1.0/sqrt(num_head)) + After silu, value is around norm of 2.7 per head. (why? seems wired) + + Then for the fast weight, assume initial lr = 0. + Then with l2_norm of q,k, input is unit normed. + if w0 is initialized with kaiming, relu(w0 @ q) is unit normed. + Then w1 is initialized with kaiming, so w1 @ relu(w0 @ q) is of norm sqrt(2) per head + Since I compute total norm, it is sqrt(2) * sqrt(num_head), which is around 2.7 for dim=512, num_head=4. + """ + + def __init__( + self, + dim: int, + head_dim: int, + inter_multi: int = 1, + bias: bool = False, + base_lr=0.01, + muon_update_steps=0, + use_momentum: bool = False, + ttt_update_steps: int = 1, + ttt_pre_norm: bool = False, + ): + super().__init__() + self.dim = dim + assert dim % head_dim == 0 + self.num_heads = dim // head_dim + self.muon_update_steps = muon_update_steps + self.use_momentum = use_momentum + self.ttt_update_steps = ttt_update_steps + self.ttt_pre_norm = ttt_pre_norm + + d_in = d_out = head_dim + d_h = int(head_dim * inter_multi) + + gain = math.sqrt(2) # for relu activations + self.w0 = nn.Parameter( + torch.randn(self.num_heads, d_in, d_h) * gain / math.sqrt(d_in) + ) # [d_h * num_heads, d_in] + self.w1 = nn.Parameter( + torch.randn(self.num_heads, d_h, d_out) * gain / math.sqrt(d_h) + ) # [d_in * num_heads, d_h] + self.w2 = nn.Parameter( + torch.randn(self.num_heads, d_in, d_h) * gain / math.sqrt(d_in) + ) # [d_h * num_heads, d_in] + + self.to_qkv = nn.Linear(dim, 3 * dim, bias=bias) + # Backward-compatibility for old checkpoints that contain + # "to_qkv_stack.0.*" even though we now use a fixed single-path forward. + self.to_qkv_stack = nn.Sequential(nn.Linear(dim, 3 * dim, bias=bias)) + self.c_proj = nn.Linear(dim, dim, bias=bias) + + self.lr_dim = self.num_heads + if self.use_momentum: + self.lr_fc = nn.Linear(dim, self.lr_dim * 3 + 1) + else: + self.lr_fc = nn.Linear(dim, self.lr_dim * 3) + self.base_lr_inv = inv_softplus(base_lr) + + if self.ttt_pre_norm: + self.pre_norm = torch.nn.RMSNorm(dim, eps=1e-5, elementwise_affine=True) + + self.o_norm = torch.nn.RMSNorm(head_dim, eps=1e-5, elementwise_affine=True) + + def forward(self, x: torch.Tensor, info: dict | None = None, *args): + """ + x: (b, t, l, d) -> (b, t*l, d) + """ + num_dims = len(x.shape) + if num_dims == 3: + x = x.unsqueeze(1) + + b, t, l, d = x.shape + + if self.ttt_pre_norm: + x = self.pre_norm(x) + + x = rearrange(x, "b t l d -> b (t l) d") + qkv = F.silu(self.to_qkv(x), inplace=True) + + q, k, v = rearrange( + qkv, "b l (qkv h d) -> qkv (b h) l d", + qkv=3, h=self.num_heads + ) + q = q / (q.norm(dim=2, keepdim=True) + 1e-5) + k = k / (k.norm(dim=2, keepdim=True) + 1e-5) + + with torch.autocast(device_type="cuda", enabled=False): + lr = self.lr_fc(x.float()) # [b, l, lr_dim] + + if self.use_momentum: + momentum = torch.sigmoid(lr[..., -1:]) + lr = lr[..., :-1] + else: + momentum = None + + lr = torch.nn.functional.softplus(lr.float() + self.base_lr_inv) + + lr0, lr1, lr2 = rearrange( + lr, "b l (lrs h d) -> lrs (b h) l d", + lrs=3, h=self.num_heads + ) + + if info and "w0" in info and info.get("w0") is not None: + assert "w1" in info and "w2" in info + w0 = info["w0"] + w1 = info["w1"] + w2 = info["w2"] + else: + w0 = self.w0.repeat(x.shape[0], 1, 1) + w1 = self.w1.repeat(x.shape[0], 1, 1) + w2 = self.w2.repeat(x.shape[0], 1, 1) + + output, w0, w1, w2 = fast_weight_swish_glu_weight_norm_mini_batch_apply( + w0, w1, w2, q, k, v, lr0, lr1, lr2, + info["ttt_op_order"] if info else [], + muon_update_steps=self.muon_update_steps, + momentum=momentum, + ttt_update_steps=self.ttt_update_steps, + ) + + output = self.o_norm(output) + + output = rearrange( + output, "(b h) l d -> b l (h d)", h=self.num_heads, b=x.shape[0] + ) + + output = self.c_proj(output) + output = rearrange(output, "b (t l) d -> b t l d", t=t).to(x.dtype) + + if num_dims == 3: + output = rearrange(output, "b t l d -> b (t l) d", t=t) + + return output, { + "w0": w0, "w1": w1, "w2": w2, + } + + def extra_repr(self) -> str: + return (f"w0 shape: {self.w0.shape}, w1 shape: {self.w1.shape}, w2 shape: {self.w2.shape}, " + f"Muon update steps: {self.muon_update_steps}, " + f"Base lr: {math.log(1 + math.exp(self.base_lr_inv))}, ") diff --git a/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/basic.py b/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/basic.py new file mode 100644 index 00000000..759439ea --- /dev/null +++ b/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/basic.py @@ -0,0 +1,90 @@ +import os +import os.path as osp +import math +import cv2 +from PIL import Image +import torch +from torchvision import transforms +from plyfile import PlyData, PlyElement +import numpy as np + +from ...pi3.utils.basic import ( + load_multimodal_data, tensor_to_pil, array_to_pil, rotate_target_dim_to_last_axis, write_ply +) + +def load_images_as_tensor(path='data/truck', interval=1, PIXEL_LIMIT=255000, Target_W=None, Target_H=None): + """ + Loads images from a directory or video, resizes them to a uniform size, + then converts and stacks them into a single [N, 3, H, W] PyTorch tensor. + """ + sources = [] + + # --- 1. Load image paths or video frames --- + if osp.isdir(path): + print(f"Loading images from directory: {path}") + filenames = sorted([x for x in os.listdir(path) if x.lower().endswith(('.png', '.jpg', '.jpeg'))]) + for i in range(0, len(filenames), interval): + img_path = osp.join(path, filenames[i]) + try: + sources.append(Image.open(img_path).convert('RGB')) + except Exception as e: + print(f"Could not load image {filenames[i]}: {e}") + elif path.lower().endswith('.mp4'): + print(f"Loading frames from video: {path}") + cap = cv2.VideoCapture(path) + if not cap.isOpened(): raise IOError(f"Cannot open video file: {path}") + frame_idx = 0 + while True: + ret, frame = cap.read() + if not ret: break + if frame_idx % interval == 0: + rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + sources.append(Image.fromarray(rgb_frame)) + frame_idx += 1 + cap.release() + else: + raise ValueError(f"Unsupported path. Must be a directory or a .mp4 file: {path}") + + if not sources: + print("No images found or loaded.") + return torch.empty(0) + + print(f"Found {len(sources)} images/frames. Processing...") + + # --- 2. Determine a uniform target size for all images based on the first image --- + # This is necessary to ensure all tensors have the same dimensions for stacking. + if Target_W is None and Target_H is None: + first_img = sources[0] + W_orig, H_orig = first_img.size + scale = math.sqrt(PIXEL_LIMIT / (W_orig * H_orig)) if W_orig * H_orig > 0 else 1 + W_target, H_target = W_orig * scale, H_orig * scale + k, m = round(W_target / 14), round(H_target / 14) + while (k * 14) * (m * 14) > PIXEL_LIMIT: + if k / m > W_target / H_target: k -= 1 + else: m -= 1 + TARGET_W, TARGET_H = max(1, k) * 14, max(1, m) * 14 + else: + TARGET_W, TARGET_H = Target_W, Target_H + print(f"All images will be resized to a uniform size: ({TARGET_W}, {TARGET_H})") + + # --- 3. Resize images and convert them to tensors in the [0, 1] range --- + tensor_list = [] + # Define a transform to convert a PIL Image to a CxHxW tensor and normalize to [0,1] + to_tensor_transform = transforms.ToTensor() + + for img_pil in sources: + try: + # Resize to the uniform target size + resized_img = img_pil.resize((TARGET_W, TARGET_H), Image.Resampling.LANCZOS) + # Convert to tensor + img_tensor = to_tensor_transform(resized_img) + tensor_list.append(img_tensor) + except Exception as e: + print(f"Error processing an image: {e}") + + if not tensor_list: + print("No images were successfully processed.") + return torch.empty(0) + + # --- 4. Stack the list of tensors into a single [N, C, H, W] batch tensor --- + return torch.stack(tensor_list, dim=0) diff --git a/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/geometry.py b/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/geometry.py new file mode 100644 index 00000000..3ca81a6c --- /dev/null +++ b/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/geometry.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +import torch.nn.functional as F + +from ...pi3.utils.geometry import ( + se3_inverse, get_pixel, depthmap_to_absolute_camera_coordinates, depthmap_to_camera_coordinates, homogenize_points, get_gt_warp, warp_kpts, geotrf, inv, opencv_camera_to_plucker, depth_edge +) + +def robust_scale_estimation(ratios: torch.Tensor, trim_ratio: float = 0.25) -> torch.Tensor: + """ + Compute a robust mean of ratios by trimming the top and bottom trim_ratio fraction. + Args: + ratios: (B, N) tensor of ratios + trim_ratio: fraction to trim from each end (0.0 to 0.5) + Returns: + (B,) tensor of robust means + """ + B, N = ratios.shape + if N == 0: + return torch.ones(B, device=ratios.device, dtype=ratios.dtype) + + # Sort ratios along the last dimension + sorted_ratios, _ = torch.sort(ratios, dim=-1) + + # Determine indices to keep + trim_cnt = int(N * trim_ratio) + start_idx = trim_cnt + end_idx = N - trim_cnt + + if start_idx >= end_idx: + # Fallback to median if trimming removes everything (shouldn't happen with reasonable N and trim_ratio < 0.5) + return sorted_ratios[:, N // 2] + + # Slice and compute mean + valid_ratios = sorted_ratios[:, start_idx:end_idx] + return valid_ratios.mean(dim=-1) diff --git a/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/rotation.py b/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/rotation.py new file mode 100644 index 00000000..eec4b45a --- /dev/null +++ b/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/rotation.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d + +import torch +import numpy as np +import torch.nn.functional as F + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """ + Quaternion Order: XYZW or say ijkr, scalar-last + + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). + Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) + + # Convert from rijk to ijkr + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) + + +def quat_multiply(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: + """ + Multiply two quaternions. + + Args: + q1: First quaternion with shape (..., 4), order XYZW (ijkr, scalar-last) + q2: Second quaternion with shape (..., 4), order XYZW (ijkr, scalar-last) + + Returns: + Product quaternion q1 * q2 with shape (..., 4) + """ + x1, y1, z1, w1 = torch.unbind(q1, dim=-1) + x2, y2, z2, w2 = torch.unbind(q2, dim=-1) + + # Quaternion multiplication formula + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 + z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 + + return torch.stack([x, y, z, w], dim=-1) + + +def quat_inverse(q: torch.Tensor) -> torch.Tensor: + """ + Compute the inverse of a quaternion. + + Args: + q: Quaternion with shape (..., 4), order XYZW (ijkr, scalar-last) + + Returns: + Inverse quaternion with shape (..., 4) + """ + # For unit quaternions, inverse is just conjugate: (x, y, z, w) -> (-x, -y, -z, w) + q_conj = q.clone() + q_conj[..., :3] = -q_conj[..., :3] # Negate x, y, z components + + # Normalize to handle potential numerical errors + norm_sq = (q * q).sum(dim=-1, keepdim=True) + return q_conj / norm_sq diff --git a/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/viser_utils.py b/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/viser_utils.py new file mode 100644 index 00000000..1465e848 --- /dev/null +++ b/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/viser_utils.py @@ -0,0 +1,781 @@ + +import os +import glob +import time +import threading +from typing import List, Optional, Tuple, Callable, Union + +import numpy as np +import torch +from tqdm.auto import tqdm +import viser +import viser.transforms as vt +import cv2 +import matplotlib.cm as cm + +try: + import onnxruntime +except ImportError: + print("onnxruntime not found. Sky segmentation may not work.") + +from loger.utils.visual_util import segment_sky, download_file_from_url + + +def apply_ema(data: np.ndarray, alpha: float) -> np.ndarray: + """Apply exponential moving average smoothing over time.""" + if not (0 < alpha <= 1.0): + raise ValueError("EMA alpha must be between 0 and 1") + if data.ndim == 1: + data = data[:, None] + smoothed_data = np.zeros_like(data) + smoothed_data[0] = data[0] + for i in range(1, len(data)): + smoothed_data[i] = alpha * data[i] + (1 - alpha) * smoothed_data[i - 1] + return smoothed_data + + +def setup_camera_follow( + server: viser.ViserServer, + slider: viser.GuiSliderHandle, + target_positions: np.ndarray, + camera_positions: Optional[np.ndarray] = None, + camera_wxyz: Optional[np.ndarray] = None, + camera_distance: float = 2.0, + camera_height: float = 1.0, + camera_angle: float = -30.0, + up_direction: Tuple[float, float, float] = (0.0, 1.0, 0.0), + fov: float = 45.0, + target_ema_alpha: float = 0.05, + camera_ema_alpha: Union[float, Callable[[], float]] = 0.05, + frame_lag: Union[int, Callable[[], int]] = 0, + backoff_distance: Union[float, Callable[[], float]] = 0.0, + camera_forward: Optional[np.ndarray] = None, +) -> Tuple[Callable[[], None], Callable[[], None]]: + """Set up camera follow behavior driven by a frame slider.""" + smoothed_target_positions = apply_ema(target_positions, target_ema_alpha) + + def _resolve_lag() -> int: + value = frame_lag() if callable(frame_lag) else frame_lag + return max(0, int(value)) + + def _resolve_backoff() -> float: + value = backoff_distance() if callable(backoff_distance) else backoff_distance + return max(0.0, float(value)) + + def _resolve_camera_ema_alpha() -> float: + value = camera_ema_alpha() if callable(camera_ema_alpha) else camera_ema_alpha + return float(np.clip(value, 1e-4, 1.0)) + + if camera_positions is not None: + if camera_wxyz is None: + raise ValueError("camera_wxyz must be provided when camera_positions is given.") + if len(camera_positions) != len(smoothed_target_positions) or len(camera_wxyz) != len(smoothed_target_positions): + raise ValueError("camera_positions/camera_wxyz and target_positions must have the same length.") + if camera_forward is not None and len(camera_forward) != len(smoothed_target_positions): + raise ValueError("camera_forward and target_positions must have the same length.") + # Apply EMA to explicit camera trajectory so lag mode still looks smooth. + smoothed_camera_positions = camera_positions.copy() + smoothed_camera_forward = None + last_camera_ema_alpha: Optional[float] = None + + def refresh_camera_ema_if_needed(): + nonlocal smoothed_camera_positions, smoothed_camera_forward, last_camera_ema_alpha + alpha = _resolve_camera_ema_alpha() + if last_camera_ema_alpha is not None and np.isclose(alpha, last_camera_ema_alpha): + return + smoothed_camera_positions = apply_ema(camera_positions, alpha) + if camera_forward is not None: + smoothed_camera_forward = apply_ema(camera_forward, alpha) + norms = np.linalg.norm(smoothed_camera_forward, axis=1, keepdims=True) + smoothed_camera_forward = smoothed_camera_forward / np.clip(norms, 1e-8, None) + else: + smoothed_camera_forward = None + last_camera_ema_alpha = alpha + + def update_camera_for_target(client: viser.ClientHandle, t: int): + refresh_camera_ema_if_needed() + t_follow = max(0, t - _resolve_lag()) + cam_pos = smoothed_camera_positions[t_follow].copy() + backoff = _resolve_backoff() + if smoothed_camera_forward is not None and backoff > 0.0: + cam_pos = cam_pos - smoothed_camera_forward[t_follow] * backoff + + client.camera.position = cam_pos + client.camera.wxyz = camera_wxyz[t_follow] + client.camera.fov = np.radians(fov) + else: + angle_rad = np.radians(camera_angle) + + def update_camera_for_target(client: viser.ClientHandle, t: int): + target_pos = smoothed_target_positions[t] + cam_offset = np.array( + [ + -camera_distance * np.cos(angle_rad), + camera_height, + -camera_distance * np.sin(angle_rad), + ] + ) + if tuple(up_direction) == (0.0, 1.0, 0.0): + final_cam_offset = np.array([cam_offset[0], cam_offset[1], cam_offset[2]]) + elif tuple(up_direction) == (0.0, 0.0, 1.0): + final_cam_offset = np.array([cam_offset[0], cam_offset[2], cam_offset[1]]) + else: + final_cam_offset = cam_offset + + client.camera.position = target_pos + final_cam_offset + client.camera.look_at = target_pos + client.camera.up_direction = up_direction + client.camera.fov = np.radians(fov) + + original_callback: Optional[Callable] = None + + def stop_camera_follow(): + nonlocal original_callback + if original_callback is not None: + slider.remove_update_callback(original_callback) + original_callback = None + + def resume_camera_follow(): + nonlocal original_callback + if original_callback is None: + @slider.on_update + def callback(_): + t = int(max(0, min(slider.value, len(smoothed_target_positions) - 1))) + for client in server.get_clients().values(): + update_camera_for_target(client, t) + + original_callback = callback + + return stop_camera_follow, resume_camera_follow + + +def viser_wrapper( + pred_dict: dict, + port: int = 8080, + init_conf_threshold: float = 50.0, # Low confidence percentage filter + background_mode: bool = False, + mask_sky: bool = False, + image_folder_for_sky_mask: str | None = None, + subsample: int = 1, + video_width: int = 320, # Video display width + share: bool = False, + point_size: float = 0.001, + canonical_first_frame: bool = True, # Use first frame as canonical (identity pose) +): + """Visualize predictions using Viser. + + Handles multiple camera inputs (cam0, cam01, ..., cam05). + Point clouds are placed in Frame child nodes. + Camera 01 (and others) are rendered and synchronized with Camera 0 playback. + """ + + # ───────────────────────────── Parse Data ──────────────────────────── + img_data = {} + xyz_data = {} + conf_data = {} + cam2world_data = {} + pcd_handles = {} + frustums = {} + frames_roots = {} + video_previews = {} + gui_show_cams = {} + + # Store original (unsubsampled) data for online subsample adjustment + img_data_original = {} + xyz_data_original = {} + conf_data_original = {} + current_subsample = [subsample] # Use list to allow modification in nested functions + + # Main camera (cam0) + # Pi3 outputs images as (S,H,W,3) or (S,C,H,W) depending on processing. + # Viser expects (H,W,3) for image previews. Let's standardize. + # The permute in demo_viser_pi3.py results in (S, H, W, 3) + cam0_images = pred_dict["images"] + if cam0_images.shape[-1] != 3: # If not (S,H,W,3), assume (S,C,H,W) + cam0_images = cam0_images.transpose(0, 2, 3, 1) + + img_data_original["cam0"] = cam0_images + xyz_data_original["cam0"] = pred_dict["points"] + conf_data_original["cam0"] = pred_dict["conf"] + + img_data["cam0"] = cam0_images[:, ::subsample, ::subsample] + xyz_data["cam0"] = pred_dict["points"][:, ::subsample, ::subsample] # (S,H,W,3) + conf_data["cam0"] = pred_dict["conf"][:, ::subsample, ::subsample] # (S,H,W) + S = xyz_data["cam0"].shape[0] + + cam_ids = ["cam0"] + for i in range(1, 6): # Check for cam01 to cam05 + cam_id = f"cam{i:02d}" + if cam_id in pred_dict: + cam_ids.append(cam_id) + + other_cam_images = pred_dict[cam_id]["images"] + if other_cam_images.shape[-1] != 3: # If not (S,H,W,3), assume (S,C,H,W) + other_cam_images = other_cam_images.transpose(0, 2, 3, 1) + + img_data_original[cam_id] = other_cam_images + xyz_data_original[cam_id] = pred_dict[cam_id]["points"] + conf_data_original[cam_id] = pred_dict[cam_id]["conf"] + + img_data[cam_id] = other_cam_images[:, :, ::subsample, ::subsample] + xyz_data[cam_id] = pred_dict[cam_id]["points"][:, ::subsample, ::subsample] + conf_data[cam_id] = pred_dict[cam_id]["conf"][:, ::subsample, ::subsample] + S = min(S, xyz_data[cam_id].shape[0]) # Unify frame count + + # Trim all data to unified frame count S + for cam_id in cam_ids: + img_data[cam_id] = img_data[cam_id][:S] + xyz_data[cam_id] = xyz_data[cam_id][:S] + conf_data[cam_id] = conf_data[cam_id][:S] + img_data_original[cam_id] = img_data_original[cam_id][:S] + xyz_data_original[cam_id] = xyz_data_original[cam_id][:S] + conf_data_original[cam_id] = conf_data_original[cam_id][:S] + + # ───────────────────────────── Sky Mask ─────────────────────────── + if mask_sky and image_folder_for_sky_mask is not None: + sky_masks_for_conf = apply_sky_segmentation(conf_data["cam0"], image_folder_for_sky_mask, is_conf_scores=True) + for cam_id in cam_ids: + if conf_data[cam_id].shape == sky_masks_for_conf.shape: + conf_data[cam_id] = conf_data[cam_id] * sky_masks_for_conf + else: + print(f"Warning: Shape mismatch for sky masking on {cam_id}. Skipping sky mask for this camera.") + + # ───────────────────────────── Setup Server ────────────────────────── + server = viser.ViserServer(host="0.0.0.0", port=port) + if share: server.request_share_url() + server.scene.set_up_direction("-y") + server.scene.add_frame("/frames", show_axes=False) # Root node + + H_main, W_main = xyz_data["cam0"].shape[1:3] + + def process_video_frame(frame_idx, cam_id_to_process="cam0"): + frame = img_data[cam_id_to_process][frame_idx] + if frame.max() <= 1.0: + frame = (frame * 255).astype(np.uint8) + else: + frame = frame.astype(np.uint8) + h, w = frame.shape[:2] + new_w = video_width + new_h = int(h * (new_w / w)) + resized_frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_AREA) + return resized_frame + + for cam_id in cam_ids: + video_previews[cam_id] = server.gui.add_image( + process_video_frame(0, cam_id), + format="jpeg", + label=f"Camera {cam_id.replace('cam', '')}" + ) + + # ───────────── GUI – Playback ───────────── + with server.gui.add_folder("Playback"): + gui_play = server.gui.add_checkbox("Playing", True) + gui_frame = server.gui.add_slider("Frame", 0, S-1, 1, 0, disabled=True) + gui_next = server.gui.add_button("Next", disabled=True) + gui_prev = server.gui.add_button("Prev", disabled=True) + gui_fps = server.gui.add_slider("FPS", 1, 60, 0.1, 20) + gui_fps_btn = server.gui.add_button_group("FPS options", ("10","20","30","60")) + gui_all = server.gui.add_checkbox("Show all frames", False) + gui_accumulate_play = server.gui.add_checkbox("Accumulate on play", False) + gui_stride = server.gui.add_slider("Stride", 1, S, 1, min(10, max(0, S - 1)), disabled=True) + + # ───────────── GUI – Visualization ───────── + with server.gui.add_folder("Visualization"): + gui_show_all_cams_master = server.gui.add_checkbox("Show All Cameras", True) + gui_conf = server.gui.add_slider("Confidence Percent", 0,100,0.1, init_conf_threshold) + gui_point_size = server.gui.add_slider("Point Size", 0.0001, 0.05, 0.0001, point_size) + gui_camera_size = server.gui.add_slider("Camera Size", 0.01, 0.3, 0.01, 0.03) + gui_camera_follow = server.gui.add_checkbox("Camera Follow Cam0", False) + gui_camera_follow_lag = server.gui.add_slider("Follow Lag (frames)", 0, min(30, max(0, S - 1)), 1, min(10, max(0, S - 1))) + gui_camera_follow_backoff = server.gui.add_slider("Follow Backoff", 0.0, 3.0, 0.01, 0.25) + gui_camera_follow_ema = server.gui.add_slider("Follow EMA Alpha", 0.001, 1.0, 0.001, 0.05) + for cam_id in cam_ids: + gui_show_cams[cam_id] = server.gui.add_checkbox(f"Show {cam_id.upper()}", True) + + # ───────────── GUI – Frame Range & Subsample ───────── + with server.gui.add_folder("Frame Range & Subsample"): + gui_start_frame = server.gui.add_slider("Start Frame", 0, S-1, 1, 0) + gui_end_frame = server.gui.add_slider("End Frame", 0, S-1, 1, S-1) + gui_subsample = server.gui.add_slider("Subsample", 1, 10, 1, subsample) + gui_apply_range = server.gui.add_button("Apply Range & Subsample") + + # ───────────── Helper Function: Confidence Filtering ──────── + def gen_mask(conf_array, percent): + if conf_array.size == 0 or not np.any(np.isfinite(conf_array)): + return conf_array > -np.inf + + # In Pi3, conf is already a probability in [0,1]. Percentile is not the right tool. + # We should use the percentage as a direct threshold. + thresh = percent / 100.0 + return (conf_array >= thresh) & (conf_array > 1e-5) + + # ───────────── Create All Frame Nodes ───────────── + # 不再进行中心化处理,直接使用原始坐标 + xyz_centered_data = {} + + # First, collect all camera poses to find the canonical transform + T0_inv = None + if canonical_first_frame: + # Get the first camera pose from cam0 to use as canonical frame + cam0_poses = pred_dict.get("camera_poses") + if cam0_poses is not None and len(cam0_poses) > 0: + T0 = cam0_poses[0] # First frame's camera-to-world transform (4x4) + if T0.shape == (4, 4): + T0_inv = np.linalg.inv(T0) + print("Using first frame as canonical frame (identity pose).") + elif T0.shape == (3, 4): + T0_full = np.eye(4) + T0_full[:3, :] = T0 + T0_inv = np.linalg.inv(T0_full) + print("Using first frame as canonical frame (identity pose).") + + for cam_id in cam_ids: + pcd_handles[cam_id] = [] + frustums[cam_id] = [] + frames_roots[cam_id] = [] + + # Pi3 provides camera_poses directly, assuming they are camera-to-world 4x4 matrices + if cam_id == "cam0": + poses = pred_dict.get("camera_poses") + else: + poses = pred_dict.get(cam_id, {}).get("camera_poses") + + if poses is None: + print(f"Warning: camera_poses not found for {cam_id}. Using identity.") + poses = np.tile(np.eye(4), (S, 1, 1)) + + # Truncate poses + poses = poses[:S] + + # Apply canonical transform if enabled + if canonical_first_frame and T0_inv is not None: + # Transform all poses: T'_i = T0_inv @ T_i + transformed_poses = [] + for i in range(len(poses)): + if poses[i].shape == (4, 4): + T_new = T0_inv @ poses[i] + elif poses[i].shape == (3, 4): + T_full = np.eye(4) + T_full[:3, :] = poses[i] + T_new = T0_inv @ T_full + else: + T_new = poses[i] + transformed_poses.append(T_new) + poses = np.array(transformed_poses) + + # Convert to (S, 3, 4) for viser + if poses.shape[-2:] == (4, 4): + cam2world_data[cam_id] = poses[:, :3, :] + else: + cam2world_data[cam_id] = poses + + # Transform point clouds to canonical frame if enabled + if canonical_first_frame and T0_inv is not None: + # Transform points: p' = R0_inv @ p + t0_inv (apply inverse of first pose) + R0_inv = T0_inv[:3, :3] + t0_inv = T0_inv[:3, 3] + xyz_original = xyz_data[cam_id] # (S, H, W, 3) + # Reshape for matrix multiplication: (S*H*W, 3) + original_shape = xyz_original.shape + xyz_flat = xyz_original.reshape(-1, 3) + # Apply rotation and translation + xyz_transformed = (R0_inv @ xyz_flat.T).T + t0_inv + xyz_centered_data[cam_id] = xyz_transformed.reshape(original_shape) + else: + # 直接使用原始点云坐标,不进行中心化 + xyz_centered_data[cam_id] = xyz_data[cam_id] + + # Camera follow uses cam0 camera trajectory over the frame slider. + cam0_poses = cam2world_data["cam0"] # (S, 3, 4) + cam0_positions = cam0_poses[:, :3, 3] + cam0_wxyz = np.array([vt.SO3.from_matrix(pose[:, :3]).wxyz for pose in cam0_poses], dtype=np.float32) + cam0_forward = np.array([pose[:, :3] @ np.array([0.0, 0.0, 1.0]) for pose in cam0_poses], dtype=np.float32) + cam0_lookat = cam0_positions + cam0_forward + stop_camera_follow, resume_camera_follow = setup_camera_follow( + server=server, + slider=gui_frame, + target_positions=cam0_lookat, + camera_positions=cam0_positions, + camera_wxyz=cam0_wxyz, + camera_forward=cam0_forward, + camera_ema_alpha=lambda: gui_camera_follow_ema.value, + frame_lag=lambda: gui_camera_follow_lag.value, + backoff_distance=lambda: gui_camera_follow_backoff.value, + up_direction=(0.0, -1.0, 0.0), + fov=60.0, + ) + + print("Building frames / point clouds …") + for i in tqdm(range(S)): + f_root_timestep = server.scene.add_frame(f"/frames/t{i}", show_axes=False) + + for cam_id in cam_ids: + frames_roots[cam_id].append(f_root_timestep) + + # Point Cloud + current_conf = conf_data[cam_id][i] + current_xyz_c = xyz_centered_data[cam_id][i] + current_img = img_data[cam_id][i] + + if gui_show_cams[cam_id].value: + mask = gen_mask(current_conf, gui_conf.value) + # Reshape arrays for masking: (H,W,3) -> (H*W,3) and (H,W) -> (H*W,) + pts_flat = current_xyz_c.reshape(-1, 3) + mask_flat = mask.reshape(-1) + rgb_img_for_pts = current_img + if rgb_img_for_pts.max() <= 1.0: rgb_img_for_pts = rgb_img_for_pts * 255 + rgb_flat = rgb_img_for_pts.astype(np.uint8).reshape(-1, 3) + + # Apply mask + pts = pts_flat[mask_flat] + rgb = rgb_flat[mask_flat] + else: + pts = np.zeros((0,3), np.float32) + rgb = np.zeros((0,3), np.uint8) + + pcd_handle = server.scene.add_point_cloud( + f"/frames/t{i}/pc_{cam_id}", pts, rgb, point_size=point_size*(subsample**(1/2)), point_shape="rounded" + ) + pcd_handles[cam_id].append(pcd_handle) + + # Frustum + norm_i = i/(S-1) if S>1 else 0.0 + col = cm.get_cmap('gist_rainbow')(norm_i)[:3] + + # Since Pi3 doesn't provide intrinsics, we have to use a heuristic for FOV. + # This is a limitation. A fixed FOV is a reasonable fallback. + h_img_cam, w_img_cam = img_data[cam_id].shape[-3:-1] + # Use a more reasonable FOV - around 60 degrees (1.047 radians) + fov_cam = 1.047 # 60 degrees in radians, typical camera FOV + aspect_cam = w_img_cam / h_img_cam + + # Reconstruct 4x4 matrix from 3x4 for SE3 + cam_pose_3x4 = cam2world_data[cam_id][i] + cam_pose_4x4 = np.eye(4) + cam_pose_4x4[:3, :] = cam_pose_3x4 + T_cam = vt.SE3.from_matrix(cam_pose_4x4) + + # Use processed image for frustum view + frustum_img = current_img + if frustum_img.max() <= 1.0: frustum_img = frustum_img * 255 + frustum_img = frustum_img.astype(np.uint8) + + frustum_handle = server.scene.add_camera_frustum( + f"/frames/t{i}/frustum_{cam_id}", fov_cam, aspect_cam, scale=gui_camera_size.value, + image=frustum_img, + wxyz=T_cam.rotation().wxyz, position=T_cam.translation(), + color=col, line_width=2.0 + ) + frustums[cam_id].append(frustum_handle) + + # ───────────── Update Visibility ───────────── + def set_visibility(): + show_all_ts = gui_all.value + accumulate_on_play = gui_accumulate_play.value and gui_play.value and (not show_all_ts) + stride_ts = gui_stride.value + current_ts = gui_frame.value + master_show_frustums = gui_show_all_cams_master.value + start_f = gui_start_frame.value + end_f = gui_end_frame.value + + for i in range(S): + in_range = (start_f <= i <= end_f) + if show_all_ts: + vis_timestep_level = (i % stride_ts == 0) and in_range + elif accumulate_on_play: + vis_timestep_level = (start_f <= i <= current_ts) and (((i - start_f) % stride_ts) == 0) and in_range + else: + vis_timestep_level = (i == current_ts) and in_range + + if len(cam_ids) > 0 and frames_roots[cam_ids[0]][i] is not None: + frames_roots[cam_ids[0]][i].visible = vis_timestep_level + + for cam_id in cam_ids: + individual_cam_active = gui_show_cams[cam_id].value + + if pcd_handles[cam_id][i] is not None: + pcd_handles[cam_id][i].visible = vis_timestep_level and individual_cam_active + + if frustums[cam_id][i] is not None: + frustums[cam_id][i].visible = vis_timestep_level and individual_cam_active and master_show_frustums + + set_visibility() + + # ───────────── Refresh Point Clouds (Confidence Slider) ───────────── + def refresh_pointclouds(): + pct = gui_conf.value + cur_subsample = current_subsample[0] + new_point_size = gui_point_size.value * (cur_subsample ** 0.5) + + for i in tqdm(range(S), leave=False, desc="Refreshing PCs"): + for cam_id in cam_ids: + if gui_show_cams[cam_id].value : + current_conf = conf_data[cam_id][i] + current_xyz_c = xyz_centered_data[cam_id][i] + current_img = img_data[cam_id][i] + + mask = gen_mask(current_conf, pct) + # Reshape arrays for masking: (H,W,3) -> (H*W,3) and (H,W) -> (H*W,) + pts_flat = current_xyz_c.reshape(-1, 3) + mask_flat = mask.reshape(-1) + rgb_img_for_pts = current_img + if rgb_img_for_pts.max() <= 1.0: rgb_img_for_pts = rgb_img_for_pts * 255 + rgb_flat = rgb_img_for_pts.astype(np.uint8).reshape(-1, 3) + + # Apply mask + pts = pts_flat[mask_flat] + rgb = rgb_flat[mask_flat] + + pcd_handles[cam_id][i].points = pts + pcd_handles[cam_id][i].colors = rgb + pcd_handles[cam_id][i].point_size = new_point_size + + # ───────────── GUI Callback Bindings ───────────── + @gui_next.on_click + def _(_): gui_frame.value = (gui_frame.value+1)%S + + @gui_prev.on_click + def _(_): gui_frame.value = (gui_frame.value-1+S)%S + + @gui_fps_btn.on_click + def _(_): gui_fps.value = float(gui_fps_btn.value) + + @gui_play.on_update + def _(_): + controls_disabled = gui_play.value or gui_all.value + gui_frame.disabled = controls_disabled + gui_next.disabled = controls_disabled + gui_prev.disabled = controls_disabled + gui_stride.disabled = not (gui_all.value or gui_accumulate_play.value) + set_visibility() + + @gui_conf.on_update + def _(_): refresh_pointclouds() + + @gui_point_size.on_update + def _(_): + new_point_size = gui_point_size.value + for cam_id in cam_ids: + for handle in pcd_handles[cam_id]: + if handle is not None: + handle.point_size = new_point_size + + @gui_camera_size.on_update + def _(_): + new_camera_size = gui_camera_size.value + for cam_id in cam_ids: + for handle in frustums[cam_id]: + if handle is not None: + handle.scale = new_camera_size + + @gui_frame.on_update + def _(_): + set_visibility() + current_frame_val = gui_frame.value + for cam_id in cam_ids: + video_previews[cam_id].image = process_video_frame(current_frame_val, cam_id) + + @gui_all.on_update + def _(_): + gui_stride.disabled = not (gui_all.value or gui_accumulate_play.value) + controls_disabled = gui_play.value or gui_all.value + gui_frame.disabled = controls_disabled + gui_next.disabled = controls_disabled + gui_prev.disabled = controls_disabled + set_visibility() + + @gui_stride.on_update + def _(_): set_visibility() + + @gui_accumulate_play.on_update + def _(_): + gui_stride.disabled = not (gui_all.value or gui_accumulate_play.value) + set_visibility() + + @gui_show_all_cams_master.on_update + def _(_): + set_visibility() + + @gui_start_frame.on_update + def _(_): set_visibility() + + @gui_end_frame.on_update + def _(_): set_visibility() + + @gui_apply_range.on_click + def _(_): + new_subsample = gui_subsample.value + if new_subsample != current_subsample[0]: + print(f"Applying new subsample: {new_subsample} (was {current_subsample[0]})") + current_subsample[0] = new_subsample + + # Update subsampled data + for cam_id in cam_ids: + img_data[cam_id] = img_data_original[cam_id][:, ::new_subsample, ::new_subsample] + xyz_data[cam_id] = xyz_data_original[cam_id][:, ::new_subsample, ::new_subsample] + conf_data[cam_id] = conf_data_original[cam_id][:, ::new_subsample, ::new_subsample] + + # Update xyz_centered_data with canonical transform + if canonical_first_frame and T0_inv is not None: + R0_inv = T0_inv[:3, :3] + t0_inv = T0_inv[:3, 3] + xyz_original = xyz_data[cam_id] + original_shape = xyz_original.shape + xyz_flat = xyz_original.reshape(-1, 3) + xyz_transformed = (R0_inv @ xyz_flat.T).T + t0_inv + xyz_centered_data[cam_id] = xyz_transformed.reshape(original_shape) + else: + xyz_centered_data[cam_id] = xyz_data[cam_id] + + # Refresh point clouds with new subsample + refresh_pointclouds() + print(f"Subsample updated to {new_subsample}") + else: + # Just refresh visibility for frame range + set_visibility() + print(f"Frame range applied: {gui_start_frame.value} - {gui_end_frame.value}") + + for cam_id in cam_ids: + # Use a closure to capture the correct cam_id for the callback + def make_callback(cam_id_captured): + def callback(_): + set_visibility() + refresh_pointclouds() + return callback + gui_show_cams[cam_id].on_update(make_callback(cam_id)) + + @gui_camera_follow.on_update + def _(_): + if gui_camera_follow.value: + resume_camera_follow() + # Apply once immediately so manual frame value is reflected. + gui_frame.value = gui_frame.value + else: + stop_camera_follow() + + @gui_camera_follow_lag.on_update + def _(_): + if gui_camera_follow.value: + gui_frame.value = gui_frame.value + + @gui_camera_follow_backoff.on_update + def _(_): + if gui_camera_follow.value: + gui_frame.value = gui_frame.value + + @gui_camera_follow_ema.on_update + def _(_): + if gui_camera_follow.value: + gui_frame.value = gui_frame.value + + # ───────────── Playback Loop ───────────── + def loop(): + prev_time = time.time() + while True: + if gui_play.value and not gui_all.value: + now = time.time() + if now - prev_time >= 1.0/gui_fps.value: + start_f = gui_start_frame.value + end_f = gui_end_frame.value + next_frame = gui_frame.value + 1 + if next_frame > end_f: + next_frame = start_f + elif next_frame < start_f: + next_frame = start_f + gui_frame.value = next_frame + prev_time = now + time.sleep(0.005) + + if background_mode: + threading.Thread(target=loop, daemon=True).start() + print(f"Viser server running in background on port {port}") + else: + print(f"Viser server running in foreground on port {port}. Press Ctrl+C to stop.") + loop() + + return server + +def apply_sky_segmentation( + data_to_mask: np.ndarray, + image_folder_for_sky_mask: str, + is_conf_scores: bool = False +) -> np.ndarray: + """ + Apply sky segmentation. If is_conf_scores is True, `data_to_mask` are confidence scores (S, H, W) + and the function returns a binary mask (0 for sky, 1 for non-sky) of the same shape. + Otherwise, it assumes data_to_mask are images and directly masks them (not implemented here for that path). + Args: + data_to_mask (np.ndarray): Data to apply mask to, typically confidence (S, H, W) or potentially images. + image_folder_for_sky_mask (str): Path to the folder containing original input images. + is_conf_scores (bool): If true, data_to_mask is confidence and a binary mask is returned. + Returns: + np.ndarray: If is_conf_scores, returns binary non-sky mask (S,H,W). Otherwise, modifies data_to_mask (not fully implemented). + """ + S_data, H_data, W_data = data_to_mask.shape + sky_masks_dir = image_folder_for_sky_mask.rstrip("/") + "_sky_masks" + os.makedirs(sky_masks_dir, exist_ok=True) + + onnx_path = "skyseg.onnx" + if not os.path.exists(onnx_path): + print("Downloading skyseg.onnx...") + try: + download_file_from_url("https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", onnx_path) + except Exception as e: + print(f"Failed to download skyseg.onnx: {e}. Sky segmentation will be skipped.") + return np.ones_like(data_to_mask) if is_conf_scores else data_to_mask + + try: + skyseg_session = onnxruntime.InferenceSession(onnx_path) + except Exception as e: + print(f"Error loading ONNX model: {e}. Sky segmentation will be skipped.") + return np.ones_like(data_to_mask) if is_conf_scores else data_to_mask + + try: + from natsort import natsorted + except ImportError: + print("natsort library not found. File sorting may be incorrect for sky masks.") + def natsorted(x): + return sorted(x) + + source_image_files = natsorted(glob.glob(os.path.join(image_folder_for_sky_mask, "*"))) + if not source_image_files: + print(f"No images found in {image_folder_for_sky_mask} for sky segmentation. Sky segmentation skipped.") + return np.ones_like(data_to_mask) if is_conf_scores else data_to_mask + + sky_mask_list = [] + print("Generating sky masks...") + num_images_to_process = min(S_data, len(source_image_files)) + + for i in tqdm(range(num_images_to_process), desc="Sky Segmentation"): + image_path = source_image_files[i] + image_name = os.path.basename(image_path) + mask_filename = os.path.splitext(image_name)[0] + ".png" + mask_filepath = os.path.join(sky_masks_dir, mask_filename) + + if os.path.exists(mask_filepath): + sky_mask_individual = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE) + else: + sky_mask_individual = segment_sky(image_path, skyseg_session, mask_filepath) + + if sky_mask_individual is None: + print(f"Warning: Sky mask for {image_name} could not be generated/loaded. Using no-sky mask for this frame.") + sky_mask_individual = np.zeros((H_data, W_data), dtype=np.uint8) + + if sky_mask_individual.shape[0] != H_data or sky_mask_individual.shape[1] != W_data: + sky_mask_individual = cv2.resize(sky_mask_individual, (W_data, H_data), interpolation=cv2.INTER_NEAREST) + sky_mask_list.append(sky_mask_individual) + + while len(sky_mask_list) < S_data: + print(f"Warning: Not enough sky masks ({len(sky_mask_list)}) for all {S_data} frames/depths. Padding with no-sky masks.") + sky_mask_list.append(np.zeros((H_data, W_data), dtype=np.uint8)) + + sky_mask_array_stacked = np.array(sky_mask_list) + + non_sky_mask_binary = (sky_mask_array_stacked < 128).astype(np.float32) + + if is_conf_scores: + print("Sky segmentation applied successfully (returning binary mask).") + return non_sky_mask_binary + else: + print("Warning: Direct image masking in apply_sky_segmentation is not fully implemented for this path.") + return data_to_mask \ No newline at end of file diff --git a/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/visual_util.py b/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/visual_util.py new file mode 100644 index 00000000..431b0c77 --- /dev/null +++ b/src/openworldlib/representations/point_clouds_generation/pi3/loger/utils/visual_util.py @@ -0,0 +1,710 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import trimesh +import gradio as gr +import numpy as np +import matplotlib +from scipy.spatial.transform import Rotation +import copy +import cv2 +import os +import requests + + +def predictions_to_glb( + predictions, + conf_thres=50.0, + filter_by_frames="all", + mask_black_bg=False, + mask_white_bg=False, + show_cam=True, + mask_sky=False, + target_dir=None, + prediction_mode="Predicted Pointmap", +) -> trimesh.Scene: + """ + Converts VGGT predictions to a 3D scene represented as a GLB file. + + Args: + predictions (dict): Dictionary containing model predictions with keys: + - world_points: 3D point coordinates (S, H, W, 3) + - world_points_conf: Confidence scores (S, H, W) + - images: Input images (S, H, W, 3) + - extrinsic: Camera extrinsic matrices (S, 3, 4) + conf_thres (float): Percentage of low-confidence points to filter out (default: 50.0) + filter_by_frames (str): Frame filter specification (default: "all") + mask_black_bg (bool): Mask out black background pixels (default: False) + mask_white_bg (bool): Mask out white background pixels (default: False) + show_cam (bool): Include camera visualization (default: True) + mask_sky (bool): Apply sky segmentation mask (default: False) + target_dir (str): Output directory for intermediate files (default: None) + prediction_mode (str): Prediction mode selector (default: "Predicted Pointmap") + + Returns: + trimesh.Scene: Processed 3D scene containing point cloud and cameras + + Raises: + ValueError: If input predictions structure is invalid + """ + if not isinstance(predictions, dict): + raise ValueError("predictions must be a dictionary") + + if conf_thres is None: + conf_thres = 10.0 + + # print("Building GLB scene") + selected_frame_idx = None + if filter_by_frames != "all" and filter_by_frames != "All": + try: + # Extract the index part before the colon + selected_frame_idx = int(filter_by_frames.split(":")[0]) + except (ValueError, IndexError): + pass + + if "Pointmap" in prediction_mode: + # print("Using Pointmap Branch") + if "world_points" in predictions: + pred_world_points = predictions["world_points"] # No batch dimension to remove + pred_world_points_conf = predictions.get("world_points_conf", np.ones_like(pred_world_points[..., 0])) + else: + print("Warning: world_points not found in predictions, falling back to depth-based points") + pred_world_points = predictions["world_points_from_depth"] + pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0])) + else: + # print("Using Depthmap and Camera Branch") + pred_world_points = predictions["world_points_from_depth"] + pred_world_points_conf = predictions.get("depth_conf", np.ones_like(pred_world_points[..., 0])) + + # Get images from predictions + images = predictions["images"] + # Use extrinsic matrices instead of pred_extrinsic_list + try: + camera_matrices = predictions["extrinsic"] + except: # set dummy camera matrices for pi3 + camera_matrices = np.eye(4).reshape(1, 4, 4).repeat(len(images), axis=0)[:, :3, :4] + + if mask_sky: + if target_dir is not None: + import onnxruntime + + skyseg_session = None + target_dir_images = target_dir + "/images" + image_list = sorted(os.listdir(target_dir_images)) + sky_mask_list = [] + + # Get the shape of pred_world_points_conf to match + S, H, W = ( + pred_world_points_conf.shape + if hasattr(pred_world_points_conf, "shape") + else (len(images), images.shape[1], images.shape[2]) + ) + + # Download skyseg.onnx if it doesn't exist + if not os.path.exists("skyseg.onnx"): + print("Downloading skyseg.onnx...") + download_file_from_url( + "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx" + ) + + for i, image_name in enumerate(image_list): + image_filepath = os.path.join(target_dir_images, image_name) + mask_filepath = os.path.join(target_dir, "sky_masks", image_name) + + # Check if mask already exists + if os.path.exists(mask_filepath): + # Load existing mask + sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE) + else: + # Generate new mask + if skyseg_session is None: + skyseg_session = onnxruntime.InferenceSession("skyseg.onnx") + sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath) + + # Resize mask to match H×W if needed + if sky_mask.shape[0] != H or sky_mask.shape[1] != W: + sky_mask = cv2.resize(sky_mask, (W, H)) + + sky_mask_list.append(sky_mask) + + # Convert list to numpy array with shape S×H×W + sky_mask_array = np.array(sky_mask_list) + + # Apply sky mask to confidence scores + sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32) + pred_world_points_conf = pred_world_points_conf * sky_mask_binary + + if selected_frame_idx is not None: + pred_world_points = pred_world_points[selected_frame_idx][None] + pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None] + images = images[selected_frame_idx][None] + camera_matrices = camera_matrices[selected_frame_idx][None] + + vertices_3d = pred_world_points.reshape(-1, 3) + # Handle different image formats - check if images need transposing + if images.ndim == 4 and images.shape[1] == 3: # NCHW format + colors_rgb = np.transpose(images, (0, 2, 3, 1)) + else: # Assume already in NHWC format + colors_rgb = images + colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8) + + conf = pred_world_points_conf.reshape(-1) + # Convert percentage threshold to actual confidence value + if conf_thres == 0.0: + conf_threshold = 0.0 + else: + conf_threshold = np.percentile(conf, conf_thres) + + conf_mask = (conf >= conf_threshold) & (conf > 1e-5) + + if mask_black_bg: + black_bg_mask = colors_rgb.sum(axis=1) >= 16 + conf_mask = conf_mask & black_bg_mask + + if mask_white_bg: + # Filter out white background pixels (RGB values close to white) + # Consider pixels white if all RGB values are above 240 + white_bg_mask = ~((colors_rgb[:, 0] > 240) & (colors_rgb[:, 1] > 240) & (colors_rgb[:, 2] > 240)) + conf_mask = conf_mask & white_bg_mask + + vertices_3d = vertices_3d[conf_mask] + colors_rgb = colors_rgb[conf_mask] + + if vertices_3d is None or np.asarray(vertices_3d).size == 0: + vertices_3d = np.array([[1, 0, 0]]) + colors_rgb = np.array([[255, 255, 255]]) + scene_scale = 1 + else: + # Calculate the 5th and 95th percentiles along each axis + lower_percentile = np.percentile(vertices_3d, 5, axis=0) + upper_percentile = np.percentile(vertices_3d, 95, axis=0) + + # Calculate the diagonal length of the percentile bounding box + scene_scale = np.linalg.norm(upper_percentile - lower_percentile) + + colormap = matplotlib.colormaps.get_cmap("gist_rainbow") + + # Initialize a 3D scene + scene_3d = trimesh.Scene() + + # Add point cloud data to the scene + point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb) + + scene_3d.add_geometry(point_cloud_data) + + # Prepare 4x4 matrices for camera extrinsics + num_cameras = len(camera_matrices) + extrinsics_matrices = np.zeros((num_cameras, 4, 4)) + extrinsics_matrices[:, :3, :4] = camera_matrices + extrinsics_matrices[:, 3, 3] = 1 + + if show_cam: + # Add camera models to the scene + for i in range(num_cameras): + world_to_camera = extrinsics_matrices[i] + camera_to_world = np.linalg.inv(world_to_camera) + rgba_color = colormap(i / num_cameras) + current_color = tuple(int(255 * x) for x in rgba_color[:3]) + + integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale) + + # Align scene to the observation of the first camera + scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices) + + return scene_3d + + +def predictions_gt_comparison_to_glb( + predictions, + gt_predictions, + conf_thres=30.0, + filter_by_frames="All", + mask_black_bg=False, + mask_white_bg=False, + show_cam=True, + mask_sky=False, + target_dir=None, + prediction_mode="Depthmap and Camera Branch", + spatial_offset=1.0, + subsample_scale=1.0, +) -> trimesh.Scene: + """ + Creates a GLB scene with both GT and predicted point clouds for comparison. + + Args: + predictions (dict): Dictionary containing model predictions + gt_predictions (dict): Dictionary containing ground truth data + conf_thres (float): Confidence threshold for filtering + filter_by_frames (str): Frame filter specification + mask_black_bg (bool): Mask out black background pixels + mask_white_bg (bool): Mask out white background pixels + show_cam (bool): Include camera visualization + mask_sky (bool): Apply sky segmentation mask + target_dir (str): Output directory for intermediate files + prediction_mode (str): Prediction mode selector + spatial_offset (float): Spatial offset between GT and prediction point clouds + subsample_scale (float): Subsample scale for point cloud + Returns: + trimesh.Scene: Scene containing both GT and predicted point clouds + """ + # print("Building GT vs Prediction comparison GLB scene") + + # Initialize scene + scene_3d = trimesh.Scene() + + # Process frame selection + selected_frame_idx = None + if filter_by_frames != "all" and filter_by_frames != "All": + try: + selected_frame_idx = int(filter_by_frames.split(":")[0]) + except (ValueError, IndexError): + pass + + # Helper function to process point cloud data + def process_point_cloud(data_dict, is_gt=False, color_offset=(0, 0, 0), gt_conf_mask=None): + """Process point cloud data and return vertices and colors""" + + # Select data source based on prediction mode + if "Pointmap" in prediction_mode and not is_gt: + if "world_points" in data_dict: + world_points = data_dict["world_points"] + world_points_conf = data_dict.get("world_points_conf", np.ones_like(world_points[..., 0])) + else: + world_points = data_dict["world_points_from_depth"] + world_points_conf = data_dict.get("depth_conf", np.ones_like(world_points[..., 0])) + else: + world_points = data_dict["world_points_from_depth"] + world_points_conf = data_dict.get("depth_conf", np.ones_like(world_points[..., 0])) + if is_gt and "world_points_conf" in data_dict: + # For GT, use world_points_conf if available + world_points_conf = data_dict["world_points_conf"] + + images = data_dict["images"] + + # Apply frame selection + if selected_frame_idx is not None: + world_points = world_points[selected_frame_idx][None] + world_points_conf = world_points_conf[selected_frame_idx][None] + images = images[selected_frame_idx][None] + + # Prepare vertices and colors + vertices = world_points.reshape(-1, 3) + + # Handle different image formats + if images.ndim == 4 and images.shape[1] == 3: # NCHW format + colors = np.transpose(images, (0, 2, 3, 1)) + else: # NHWC format + colors = images + colors = (colors.reshape(-1, 3) * 255).astype(np.uint8) + + # Apply color offset for distinction + colors = colors.astype(np.float32) + colors[:, 0] = np.clip(colors[:, 0] + color_offset[0], 0, 255) + colors[:, 1] = np.clip(colors[:, 1] + color_offset[1], 0, 255) + colors[:, 2] = np.clip(colors[:, 2] + color_offset[2], 0, 255) + colors = colors.astype(np.uint8) + + # Apply confidence filtering + conf = world_points_conf.reshape(-1) + + if gt_conf_mask is not None: + # Use provided GT confidence mask for fair comparison + conf_mask = gt_conf_mask + else: + # Create confidence mask based on data type + if is_gt: + # For GT, use a lower threshold since it's generally more reliable + conf_threshold = 0.01 # Much lower threshold for GT + else: + # For predictions, use the specified threshold + if conf_thres == 0.0: + conf_threshold = 0.0 + else: + conf_threshold = np.percentile(conf, conf_thres) + + conf_mask = (conf >= conf_threshold) & (conf > 1e-5) + + # Apply background masks + if mask_black_bg: + black_bg_mask = colors.sum(axis=1) >= 16 + conf_mask = conf_mask & black_bg_mask + + if mask_white_bg: + white_bg_mask = ~((colors[:, 0] > 240) & (colors[:, 1] > 240) & (colors[:, 2] > 240)) + conf_mask = conf_mask & white_bg_mask + + # Filter vertices and colors + vertices = vertices[conf_mask] + colors = colors[conf_mask] + + return vertices, colors, conf_mask + + # First process GT data to get the confidence mask + gt_vertices, gt_colors, gt_conf_mask = process_point_cloud(gt_predictions, is_gt=True, color_offset=(30, -10, -10)) + + # Process prediction data using GT confidence mask for fair comparison + pred_vertices, pred_colors, _ = process_point_cloud(predictions, is_gt=False, gt_conf_mask=gt_conf_mask) + + # Apply spatial offset to separate GT and prediction + if gt_vertices.size > 0 and pred_vertices.size > 0: + # Calculate scene bounds to determine appropriate offset + all_vertices = np.vstack([pred_vertices, gt_vertices]) + scene_bounds = np.max(all_vertices, axis=0) - np.min(all_vertices, axis=0) + actual_offset = np.max(scene_bounds) * spatial_offset + + # Offset GT along X-axis + gt_vertices_offset = gt_vertices.copy() + gt_vertices_offset[:, 0] += actual_offset + else: + gt_vertices_offset = gt_vertices + actual_offset = spatial_offset + + # Add prediction point cloud (blue-tinted) + if pred_vertices.size > 0: + pred_point_cloud = trimesh.PointCloud(vertices=pred_vertices[::subsample_scale], colors=pred_colors[::subsample_scale]) + scene_3d.add_geometry(pred_point_cloud, node_name="prediction_points") + + # Add GT point cloud (red-tinted) + if gt_vertices_offset.size > 0: + gt_point_cloud = trimesh.PointCloud(vertices=gt_vertices_offset[::subsample_scale], colors=gt_colors[::subsample_scale]) + scene_3d.add_geometry(gt_point_cloud, node_name="gt_points") + + # Calculate scene scale for camera visualization + if pred_vertices.size > 0 or gt_vertices_offset.size > 0: + all_vertices = [] + if pred_vertices.size > 0: + all_vertices.append(pred_vertices) + if gt_vertices_offset.size > 0: + all_vertices.append(gt_vertices_offset) + + combined_vertices = np.vstack(all_vertices) + lower_percentile = np.percentile(combined_vertices, 5, axis=0) + upper_percentile = np.percentile(combined_vertices, 95, axis=0) + scene_scale = np.linalg.norm(upper_percentile - lower_percentile) + else: + scene_scale = 1.0 + + # Add cameras if requested + if show_cam: + colormap = matplotlib.colormaps.get_cmap("gist_rainbow") + + # Add prediction cameras (original positions) + try: + pred_camera_matrices = predictions["extrinsic"] + except: + # Use images from predictions data to determine number of cameras + pred_images = gt_predictions.get("images", np.zeros((1, 3, 224, 224))) + pred_camera_matrices = np.eye(4).reshape(1, 4, 4).repeat(len(pred_images), axis=0)[:, :3, :4] + if selected_frame_idx is not None: + pred_camera_matrices = pred_camera_matrices[selected_frame_idx][None] + + num_cameras = len(pred_camera_matrices) + pred_extrinsics_matrices = np.zeros((num_cameras, 4, 4)) + pred_extrinsics_matrices[:, :3, :4] = pred_camera_matrices + pred_extrinsics_matrices[:, 3, 3] = 1 + + for i in range(num_cameras): + world_to_camera = pred_extrinsics_matrices[i] + camera_to_world = np.linalg.inv(world_to_camera) + rgba_color = colormap(i / num_cameras) + current_color = tuple(int(255 * x) for x in rgba_color[:3]) + integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale) + + # Add GT cameras (offset positions) + gt_camera_matrices = gt_predictions["extrinsic"] + if selected_frame_idx is not None: + gt_camera_matrices = gt_camera_matrices[selected_frame_idx][None] + + gt_extrinsics_matrices = np.zeros((num_cameras, 4, 4)) + gt_extrinsics_matrices[:, :3, :4] = gt_camera_matrices + gt_extrinsics_matrices[:, 3, 3] = 1 + + # Apply spatial offset to GT cameras + offset_transform = np.eye(4) + offset_transform[0, 3] = actual_offset + + for i in range(num_cameras): + world_to_camera = gt_extrinsics_matrices[i] + camera_to_world = np.linalg.inv(world_to_camera) + # Apply offset + camera_to_world_offset = offset_transform @ camera_to_world + + rgba_color = colormap(i / num_cameras) + # Make GT cameras slightly different color (more red) + gt_color = (min(255, int(255 * rgba_color[0]) + 50), + max(0, int(255 * rgba_color[1]) - 30), + max(0, int(255 * rgba_color[2]) - 30)) + integrate_camera_into_scene(scene_3d, camera_to_world_offset, gt_color, scene_scale) + + # Apply scene alignment based on first prediction camera + try: + pred_camera_matrices = predictions["extrinsic"] + except: + # Use images from predictions data to determine number of cameras + pred_images = gt_predictions.get("images", np.zeros((1, 3, 224, 224))) + pred_camera_matrices = np.eye(4).reshape(1, 4, 4).repeat(len(pred_images), axis=0)[:, :3, :4] + if selected_frame_idx is not None: + pred_camera_matrices = pred_camera_matrices[selected_frame_idx][None] + + pred_extrinsics_matrices = np.zeros((len(pred_camera_matrices), 4, 4)) + pred_extrinsics_matrices[:, :3, :4] = pred_camera_matrices + pred_extrinsics_matrices[:, 3, 3] = 1 + + scene_3d = apply_scene_alignment(scene_3d, pred_extrinsics_matrices) + + # print("GT vs Prediction comparison GLB scene built") + # print(f"Prediction points: {pred_vertices.shape[0] if pred_vertices.size > 0 else 0}") + # print(f"GT points: {gt_vertices.shape[0] if gt_vertices.size > 0 else 0}") + # print(f"Spatial offset applied: {actual_offset:.2f}") + + return scene_3d + + +def integrate_camera_into_scene( + scene: trimesh.Scene, + transform: np.ndarray, + face_colors: tuple, + scene_scale: float, +): + """ + Integrates a fake camera mesh into the 3D scene. + + Args: + scene (trimesh.Scene): The 3D scene to add the camera model. + transform (np.ndarray): Transformation matrix for camera positioning. + face_colors (tuple): Color of the camera face. + scene_scale (float): Scale of the scene. + """ + + cam_width = scene_scale * 0.015 + cam_height = scene_scale * 0.03 + + # Create cone shape for camera + rot_45_degree = np.eye(4) + rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix() + rot_45_degree[2, 3] = -cam_height + + opengl_transform = get_opengl_conversion_matrix() + # Combine transformations + complete_transform = transform @ opengl_transform @ rot_45_degree + camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4) + + # Generate mesh for the camera + slight_rotation = np.eye(4) + slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix() + + vertices_combined = np.concatenate( + [ + camera_cone_shape.vertices, + 0.95 * camera_cone_shape.vertices, + transform_points(slight_rotation, camera_cone_shape.vertices), + ] + ) + vertices_transformed = transform_points(complete_transform, vertices_combined) + + mesh_faces = compute_camera_faces(camera_cone_shape) + + # Add the camera mesh to the scene + camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces) + camera_mesh.visual.face_colors[:, :3] = face_colors + scene.add_geometry(camera_mesh) + + +def apply_scene_alignment(scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray) -> trimesh.Scene: + """ + Aligns the 3D scene based on the extrinsics of the first camera. + + Args: + scene_3d (trimesh.Scene): The 3D scene to be aligned. + extrinsics_matrices (np.ndarray): Camera extrinsic matrices. + + Returns: + trimesh.Scene: Aligned 3D scene. + """ + # Set transformations for scene alignment + opengl_conversion_matrix = get_opengl_conversion_matrix() + + # Rotation matrix for alignment (180 degrees around the y-axis) + align_rotation = np.eye(4) + align_rotation[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix() + + # Apply transformation + initial_transformation = np.linalg.inv(extrinsics_matrices[0]) @ opengl_conversion_matrix @ align_rotation + scene_3d.apply_transform(initial_transformation) + return scene_3d + + +def get_opengl_conversion_matrix() -> np.ndarray: + """ + Constructs and returns the OpenGL conversion matrix. + + Returns: + numpy.ndarray: A 4x4 OpenGL conversion matrix. + """ + # Create an identity matrix + matrix = np.identity(4) + + # Flip the y and z axes + matrix[1, 1] = -1 + matrix[2, 2] = -1 + + return matrix + + +def transform_points(transformation: np.ndarray, points: np.ndarray, dim: int = None) -> np.ndarray: + """ + Applies a 4x4 transformation to a set of points. + + Args: + transformation (np.ndarray): Transformation matrix. + points (np.ndarray): Points to be transformed. + dim (int, optional): Dimension for reshaping the result. + + Returns: + np.ndarray: Transformed points. + """ + points = np.asarray(points) + initial_shape = points.shape[:-1] + dim = dim or points.shape[-1] + + # Apply transformation + transformation = transformation.swapaxes(-1, -2) # Transpose the transformation matrix + points = points @ transformation[..., :-1, :] + transformation[..., -1:, :] + + # Reshape the result + result = points[..., :dim].reshape(*initial_shape, dim) + return result + + +def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray: + """ + Computes the faces for the camera mesh. + + Args: + cone_shape (trimesh.Trimesh): The shape of the camera cone. + + Returns: + np.ndarray: Array of faces for the camera mesh. + """ + # Create pseudo cameras + faces_list = [] + num_vertices_cone = len(cone_shape.vertices) + + for face in cone_shape.faces: + if 0 in face: + continue + v1, v2, v3 = face + v1_offset, v2_offset, v3_offset = face + num_vertices_cone + v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone + + faces_list.extend( + [ + (v1, v2, v2_offset), + (v1, v1_offset, v3), + (v3_offset, v2, v3), + (v1, v2, v2_offset_2), + (v1, v1_offset_2, v3), + (v3_offset_2, v2, v3), + ] + ) + + faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list] + return np.array(faces_list) + + +def segment_sky(image_path, onnx_session, mask_filename=None): + """ + Segments sky from an image using an ONNX model. + Thanks for the great model provided by https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing + + Args: + image_path: Path to input image + onnx_session: ONNX runtime session with loaded model + mask_filename: Path to save the output mask + + Returns: + np.ndarray: Binary mask where 255 indicates non-sky regions + """ + + assert mask_filename is not None + image = cv2.imread(image_path) + + result_map = run_skyseg(onnx_session, [320, 320], image) + # resize the result_map to the original image size + result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0])) + + # Fix: Invert the mask so that 255 = non-sky, 0 = sky + # The model outputs low values for sky, high values for non-sky + output_mask = np.zeros_like(result_map_original) + output_mask[result_map_original < 32] = 255 # Use threshold of 32 + + os.makedirs(os.path.dirname(mask_filename), exist_ok=True) + cv2.imwrite(mask_filename, output_mask) + return output_mask + + +def run_skyseg(onnx_session, input_size, image): + """ + Runs sky segmentation inference using ONNX model. + + Args: + onnx_session: ONNX runtime session + input_size: Target size for model input (width, height) + image: Input image in BGR format + + Returns: + np.ndarray: Segmentation mask + """ + + # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast + temp_image = copy.deepcopy(image) + resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1])) + x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB) + x = np.array(x, dtype=np.float32) + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + x = (x / 255 - mean) / std + x = x.transpose(2, 0, 1) + x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32") + + # Inference + input_name = onnx_session.get_inputs()[0].name + output_name = onnx_session.get_outputs()[0].name + onnx_result = onnx_session.run([output_name], {input_name: x}) + + # Post process + onnx_result = np.array(onnx_result).squeeze() + min_value = np.min(onnx_result) + max_value = np.max(onnx_result) + onnx_result = (onnx_result - min_value) / (max_value - min_value) + onnx_result *= 255 + onnx_result = onnx_result.astype("uint8") + + return onnx_result + + +def download_file_from_url(url, filename): + """Downloads a file from a Hugging Face model repo, handling redirects.""" + try: + # Get the redirect URL + response = requests.get(url, allow_redirects=False) + response.raise_for_status() # Raise HTTPError for bad requests (4xx or 5xx) + + if response.status_code == 302: # Expecting a redirect + redirect_url = response.headers["Location"] + response = requests.get(redirect_url, stream=True) + response.raise_for_status() + else: + print(f"Unexpected status code: {response.status_code}") + return + + with open(filename, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + print(f"Downloaded {filename} successfully.") + + except requests.exceptions.RequestException as e: + print(f"Error downloading file: {e}") diff --git a/src/openworldlib/representations/point_clouds_generation/pi3/loger_representation.py b/src/openworldlib/representations/point_clouds_generation/pi3/loger_representation.py new file mode 100644 index 00000000..89f91fed --- /dev/null +++ b/src/openworldlib/representations/point_clouds_generation/pi3/loger_representation.py @@ -0,0 +1,362 @@ +import os +import yaml +import inspect +import torch +import numpy as np +from typing import Dict, Any, Optional, List + +from huggingface_hub import snapshot_download + +from .loger.pi3 import Pi3 +from .loger.utils.geometry import depth_edge +from ...base_representation import BaseRepresentation + + +class LoGeRRepresentation(BaseRepresentation): + """ + Representation class for the LoGeR model. + + Supports windowed inference with optional TTT (Test-Time Training), + SWA (Sliding Window Attention) adapters, and Sim3/SE3 alignment. + + Expected input via get_representation(): + data["images"] : torch.Tensor of shape (B, N, C, H, W), values in [0, 1] + + Optional inference controls in data (override config defaults if provided): + window_size : int sliding window size (-1 = full sequence) + overlap_size : int overlap between consecutive windows + sim3 : bool enable Sim3 scale alignment across windows + se3 : bool enable SE3 (no scale) alignment across windows + reset_every : int reset TTT state every N windows (0 = never) + turn_off_ttt : bool disable TTT even if layers exist + turn_off_swa : bool disable SWA even if layers exist + sim3_scale_mode : str one of median / trimmed_mean / median_all / ... + num_iterations : int number of TTT decode iterations per window + conf_threshold : float confidence sigmoid threshold for mask (default 0.1) + edge_rtol : float relative tolerance for depth-edge filter (default 0.03) + """ + def __init__( + self, + model: Optional[Pi3] = None, + device: Optional[str] = None, + inference_defaults: Optional[Dict[str, Any]] = None, + ): + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self.model = model + self.model_type = "loger" + # Inference-time forward() defaults parsed from original_config.yaml. + # Keys here serve as fallbacks; values in data dict always take priority. + self.inference_defaults: Dict[str, Any] = inference_defaults or {} + + if self.model is not None: + self.model = self.model.to(self.device).eval() + + if self.device == "cuda" and torch.cuda.is_available(): + compute_capability = torch.cuda.get_device_capability()[0] + self.dtype = torch.bfloat16 if compute_capability >= 8 else torch.float16 + else: + self.dtype = torch.float32 + + @classmethod + def from_pretrained( + cls, + pretrained_model_path: str, + device: Optional[str] = None, + config_path: Optional[str] = None, + subfolder: Optional[str] = None, + **kwargs, + ) -> "LoGeRRepresentation": + device = device or ("cuda" if torch.cuda.is_available() else "cpu") + + # ── Resolve paths ────────────────────────────────────────── + if not os.path.exists(pretrained_model_path): + from huggingface_hub import hf_hub_download + + ckpt_filename = f"{subfolder}/latest.pt" if subfolder else "latest.pt" + config_filename = f"{subfolder}/original_config.yaml" if subfolder else "original_config.yaml" + + # 先查本地缓存,没有再联网 + try: + ckpt_file = hf_hub_download( + repo_id=pretrained_model_path, + filename=ckpt_filename, + local_files_only=True, + ) + except Exception: + ckpt_file = hf_hub_download( + repo_id=pretrained_model_path, + filename=ckpt_filename, + ) + print(f"Checkpoint: {ckpt_file}") + model_root = os.path.dirname(ckpt_file) + + if config_path is None: + try: + config_path = hf_hub_download( + repo_id=pretrained_model_path, + filename=config_filename, + local_files_only=True, + ) + except Exception: + try: + config_path = hf_hub_download( + repo_id=pretrained_model_path, + filename=config_filename, + ) + except Exception: + pass + + elif os.path.isfile(pretrained_model_path): + ckpt_file = pretrained_model_path + model_root = os.path.dirname(ckpt_file) + if config_path is None: + candidate = os.path.join(model_root, "original_config.yaml") + config_path = candidate if os.path.exists(candidate) else None + + else: + model_root = pretrained_model_path + if subfolder: + model_root = os.path.join(model_root, subfolder) + ckpt_file = os.path.join(model_root, "latest.pt") + if not os.path.exists(ckpt_file): + pts = sorted(f for f in os.listdir(model_root) if f.endswith(".pt")) + if not pts: + raise FileNotFoundError(f"No .pt checkpoint found in {model_root}") + ckpt_file = os.path.join(model_root, pts[0]) + if config_path is None: + candidate = os.path.join(model_root, "original_config.yaml") + config_path = candidate if os.path.exists(candidate) else None + + # ── Parse config ─────────────────────────────────────────── + model_kwargs: Dict[str, Any] = {} + inference_defaults: Dict[str, Any] = {} + + if config_path: + model_kwargs, inference_defaults = cls._parse_config(config_path) + + model_kwargs.update(kwargs) + + # ── Instantiate and load ─────────────────────────────────── + model = Pi3(**model_kwargs) + + print(f"Loading checkpoint from {ckpt_file} ...") + checkpoint = torch.load(ckpt_file, map_location="cpu", weights_only=False) + state_dict = ( + checkpoint["model_state_dict"] + if "model_state_dict" in checkpoint + else checkpoint + ) + state_dict = { + (k[7:] if k.startswith("module.") else k): v + for k, v in state_dict.items() + } + model.load_state_dict(state_dict, strict=True) + print("Checkpoint loaded successfully.") + + return cls(model=model, device=device, inference_defaults=inference_defaults) + + def api_init(self, api_key: str, endpoint: str): + """Placeholder for future API-based inference.""" + pass + + @staticmethod + def _to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]: + """Move a tensor to CPU and convert to float32 numpy array.""" + if tensor is None: + return None + if isinstance(tensor, torch.Tensor): + return tensor.detach().cpu().float().numpy() + return np.asarray(tensor, dtype=np.float32) + + @staticmethod + def _parse_config(config_path: str): + """ + Parse original_config.yaml and return: + (model_kwargs, inference_defaults) + + model_kwargs : valid Pi3.__init__ parameters + inference_defaults: forward() parameters (se3, window_size, overlap_size, ...) + """ + model_kwargs: Dict[str, Any] = {} + inference_defaults: Dict[str, Any] = {} + + try: + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + # ── Model init kwargs ────────────────────────────────── + model_cfg = config.get("model", {}) + pi3_sig = inspect.signature(Pi3.__init__) + valid_init_keys = { + name + for name, param in pi3_sig.parameters.items() + if name not in {"self", "args", "kwargs"} + and param.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + } + + def _maybe_parse_sequence(value): + if isinstance(value, str): + stripped = value.strip() + if stripped.startswith("[") and stripped.endswith("]"): + try: + parsed = yaml.safe_load(stripped) + if isinstance(parsed, (list, tuple)): + return list(parsed) + except Exception: + pass + return value + + for key in sorted(valid_init_keys): + if key in model_cfg: + value = model_cfg[key] + if key in {"ttt_insert_after", "attn_insert_after"}: + value = _maybe_parse_sequence(value) + model_kwargs[key] = value + + # ── Inference / forward() defaults ──────────────────── + # Priority: model section > training_settings section > top-level + training_cfg = config.get("training_settings", {}) + + def _get(key, default): + """Look up key in model_cfg, then training_cfg, then top-level config.""" + if key in model_cfg: + return model_cfg[key] + if key in training_cfg: + return training_cfg[key] + return config.get(key, default) + + inference_defaults = { + "se3": bool(_get("se3", True)), + "sim3": bool(_get("sim3", False)), + "window_size": int(_get("window_size", 32)), + "overlap_size": int(_get("overlap_size", 3)), + "reset_every": int(_get("reset_every", 0)), + "num_iterations": int(_get("num_iterations", 1)), + "sim3_scale_mode": str(_get("sim3_scale_mode", "median")), + "turn_off_ttt": bool(_get("turn_off_ttt", False)), + "turn_off_swa": bool(_get("turn_off_swa", False)), + } + + except Exception as exc: + print(f"Warning: could not parse config {config_path}: {exc}. " + "Using default Pi3 init parameters.") + + return model_kwargs, inference_defaults + + @torch.no_grad() + def get_representation(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Run LoGeR inference and return all scene representation outputs. + + Args: + data : dict with at minimum: + "images" : torch.Tensor (B, N, C, H, W) in [0, 1] + + Returns: + dict with numpy arrays: + points : (B, N, H, W, 3) global 3-D point cloud + local_points : (B, N, H, W, 3) camera-frame point cloud + camera_poses : (B, N, 4, 4) camera-to-world SE3 + conf : (B, N, H, W, 1) raw confidence logits + masks : (B, N, H, W) binary quality mask + depth_map : (B, N, H, W) z-depth (local frame) + + Plus any extra keys forwarded directly from the model output, e.g.: + avg_gate_scale, attn_gate_scale, + chunk_sim3_scales, chunk_sim3_poses / chunk_se3_poses, + alignment_mode + """ + if self.model is None: + raise RuntimeError("Model not loaded. Call from_pretrained() first.") + + imgs = data["images"] + if not isinstance(imgs, torch.Tensor): + raise TypeError( + f"data['images'] must be a torch.Tensor, got {type(imgs)}" + ) + if imgs.dim() == 4: + imgs = imgs.unsqueeze(0) # add batch dim if missing + imgs = imgs.to(self.device) + + # ── Build forward kwargs ─────────────────────────────────── + # Priority (highest → lowest): + # 1. values explicitly set in `data` + # 2. self.inference_defaults (parsed from original_config.yaml) + # 3. hard-coded fallbacks below + def _get(key, fallback): + if key in data: + return data[key] + if key in self.inference_defaults: + return self.inference_defaults[key] + return fallback + + forward_kwargs = dict( + window_size = int(_get("window_size", 32)), + overlap_size = int(_get("overlap_size", 3)), + num_iterations = int(_get("num_iterations", 1)), + sim3 = bool(_get("sim3", False)), + se3 = bool(_get("se3", False)), + reset_every = int(_get("reset_every", 0)), + turn_off_ttt = bool(_get("turn_off_ttt", False)), + turn_off_swa = bool(_get("turn_off_swa", False)), + sim3_scale_mode = str(_get("sim3_scale_mode", "median")), + ) + + conf_threshold = float(_get("conf_threshold", 0.1)) + edge_rtol = float(_get("edge_rtol", 0.03)) + + # ── Forward pass ─────────────────────────────────────────── + autocast_enabled = (self.device == "cuda") + with torch.amp.autocast("cuda", dtype=self.dtype, enabled=autocast_enabled): + raw = self.model(imgs, **forward_kwargs) + + # ── Core geometry outputs ────────────────────────────────── + results: Dict[str, Any] = {} + results["points"] = self._to_numpy(raw.get("points")) + results["local_points"] = self._to_numpy(raw.get("local_points")) + results["camera_poses"] = self._to_numpy(raw.get("camera_poses")) + results["conf"] = self._to_numpy(raw.get("conf")) + + # ── Quality mask: sigmoid(conf) > threshold AND non-depth-edge ─ + conf_tensor = raw.get("conf") + if conf_tensor is not None: + conf_prob = torch.sigmoid(conf_tensor[..., 0]) + masks = conf_prob > conf_threshold + lp = raw.get("local_points") + if lp is not None: + non_edge = ~depth_edge(lp[..., 2], rtol=edge_rtol) + masks = torch.logical_and(masks, non_edge) + results["masks"] = masks.cpu().numpy() + results["depth_map"] = self._to_numpy( + lp[..., 2] if lp is not None else None + ) + else: + # use_conf=False model: accept everything + B, N, C, H, W = imgs.shape + results["masks"] = np.ones((B, N, H, W), dtype=bool) + lp = raw.get("local_points") + results["depth_map"] = self._to_numpy(lp[..., 2]) if lp is not None else None + + # ── Optional / diagnostic outputs ───────────────────────── + _scalar_keys = ("avg_gate_scale", "attn_gate_scale", "alignment_mode") + _tensor_keys = ( + "chunk_sim3_scales", "chunk_sim3_poses", "chunk_se3_poses", + "metric", "local_camera_poses", "camera_qvec", + "overlap_prev_cam", "overlap_next_cam", + "overlap_prev_pcd", "overlap_next_pcd", "overlap_next_conf", + ) + + for key in _scalar_keys: + if key in raw and raw[key] is not None: + val = raw[key] + results[key] = val.item() if isinstance(val, torch.Tensor) else val + + for key in _tensor_keys: + if key in raw and raw[key] is not None: + results[key] = self._to_numpy(raw[key]) + + return results + \ No newline at end of file diff --git a/test/test_loger.py b/test/test_loger.py new file mode 100644 index 00000000..f0ee324e --- /dev/null +++ b/test/test_loger.py @@ -0,0 +1,54 @@ +import os +import sys +import cv2 +import numpy as np + +sys.path.append("..") + +from openworldlib.pipelines.pi3.pipeline_loger import LoGeRPipeline + +MODE = "loger" # or "loger_star" +MODEL_PATH = {"loger": "Junyi42/LoGeR", "loger_star": "Junyi42/LoGeR_star"}[MODE] +IMAGE_INPUT = "./data/test_case/test_image_case1/ref_image.png" +VIDEO_INPUT = None +OUTPUT_DIR = "output_loger" + +DATA_PATH = VIDEO_INPUT if VIDEO_INPUT is not None else IMAGE_INPUT + +pipeline = LoGeRPipeline.from_pretrained(model_path=MODEL_PATH, mode=MODE) + +if VIDEO_INPUT is not None: + result = pipeline(videos=VIDEO_INPUT, task_type="reconstruction", interval=10) +else: + result = pipeline(images=IMAGE_INPUT, task_type="reconstruction", interval=10) +result.save(OUTPUT_DIR) +print(f"Mode: {MODE}") +print(f"Input: {DATA_PATH}") +print(f"Views: {result.camera_range['num_views']}") +print(f"Camera range: {result.camera_range}") + +rendered = pipeline(task_type="render_view", camera_view=0) +rendered.save(os.path.join(OUTPUT_DIR, "render_default.png")) + +interact_frames = pipeline(task_type="render_view", interactions=["forward", "left", "camera_r"]) +interact_video_path = os.path.join(OUTPUT_DIR, "interaction_video.mp4") +interact_video = cv2.VideoWriter( + interact_video_path, cv2.VideoWriter_fourcc(*"mp4v"), 15, interact_frames[0].size, +) +for f in interact_frames: + interact_video.write(cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR)) +interact_video.release() +print(f"Interaction video saved: {interact_video_path} ({len(interact_frames)} frames)") + +frames = pipeline(task_type="render_trajectory") +video_path = os.path.join(OUTPUT_DIR, "trajectory_video.mp4") +video = cv2.VideoWriter( + video_path, + cv2.VideoWriter_fourcc(*"mp4v"), + 15, + frames[0].size, +) +for frame in frames: + video.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)) +video.release() +print(f"Trajectory video saved: {video_path}")