From 0db3f492985add8dd25c0320665075463e877515 Mon Sep 17 00:00:00 2001 From: jeevesh415 Date: Sat, 16 May 2026 20:28:26 +0530 Subject: [PATCH 1/7] Add phase-1 rigorous world-model evaluation harness and protocol --- CODE_ADDRESS_INDEX.md | 928 ++++++++++++++++++++++++++ README.md | 119 +++- docs/FRONTIER_GAP_ANALYSIS.md | 35 + docs/RIGOROUS_DEVELOPMENT_PROTOCOL.md | 33 + evaluate_world_model.py | 166 +++++ models/vjepa/planning.py | 29 +- models/vjepa/vjepa_model.py | 9 + vjepa_train.py | 13 +- 8 files changed, 1288 insertions(+), 44 deletions(-) create mode 100644 CODE_ADDRESS_INDEX.md create mode 100644 docs/FRONTIER_GAP_ANALYSIS.md create mode 100644 docs/RIGOROUS_DEVELOPMENT_PROTOCOL.md create mode 100644 evaluate_world_model.py diff --git a/CODE_ADDRESS_INDEX.md b/CODE_ADDRESS_INDEX.md new file mode 100644 index 00000000..98d5b96d --- /dev/null +++ b/CODE_ADDRESS_INDEX.md @@ -0,0 +1,928 @@ +# Code Address Index (Absolute Repository Map) + +This file is an exhaustive navigational index of the repository at generation time. + +- Root: `/workspace/HRM` +- Coverage: all project files except `.git` internals and cache folders. +- For each file: path, line count, byte size, and symbol/key anchors. + +## File Inventory + +### `.github/workflows/sync-from-upstream.yml` + +- Type: text +- Size: 664 bytes +- Lines: 1-19 + +- Address anchors: + - L1: `name` + - L2: `on` + - L3: `schedule` + - L5: `workflow_dispatch` + - L6: `jobs` + - L7: `sync` + - L8: `runs-on` + - L9: `steps` + +### `.gitignore` + +- Type: text +- Size: 3152 bytes +- Lines: 1-169 + +- First non-empty content (anchor): L1: `# WandB` + +### `.gitmodules` + +- Type: text +- Size: 364 bytes +- Lines: 1-9 + +- First non-empty content (anchor): L1: `[submodule "dataset/raw-data/ConceptARC"]` + +### `.vscode/launch.json` + +- Type: text +- Size: 778 bytes +- Lines: 1-26 + +- First non-empty content (anchor): L1: `{` + +### `.vscode/settings.json` + +- Type: text +- Size: 54 bytes +- Lines: 1-3 + +- First non-empty content (anchor): L1: `{` + +### `CODE_ADDRESS_INDEX.md` + +- Type: text +- Size: 20346 bytes +- Lines: 1-927 + +- Address anchors: + - L1: `# Code Address Index (Absolute Repository Map)` + - L9: `## File Inventory` + - L11: `### `.github/workflows/sync-from-upstream.yml`` + - L27: `### `.gitignore`` + - L35: `### `.gitmodules`` + - L43: `### `.vscode/launch.json`` + - L51: `### `.vscode/settings.json`` + - L59: `### `CODE_ADDRESS_INDEX.md`` + - L122: `### `LICENSE`` + - L130: `### `README.md`` + - L152: `### `arc_eval.ipynb`` + - L160: `### `assets/hrm.png`` + - L165: `### `assets/npyjs.js`` + - L173: `### `config/arch/hrm_v1.yaml`` + - L196: `### `config/cfg_pretrain.yaml`` + - L220: `### `config/vjepa_10b.yaml`` + - L289: `### `config/vjepa_micro.yaml`` + - L328: `### `dataset/build_arc_dataset.py`` + - L347: `### `dataset/build_maze_dataset.py`` + - L359: `### `dataset/common.py`` + - L370: `### `dataset/generate_dummy_data.py`` + - L379: `### `dataset/video_dataset.py`` + - L393: `### `evaluate.py`` + - L403: `### `models/adaptive_depth.py`` + - L417: `### `models/common.py`` + - L426: `### `models/hrm/hrm_act_v1.py`` + - L454: `### `models/hybrid_ssm.py`` + - L468: `### `models/information_bottleneck.py`` + - L483: `### `models/layers.py`` + - L511: `### `models/losses.py`` + - L527: `### `models/multimodal_grounding.py`` + - L544: `### `models/muon_optimizer.py`` + - L557: `### `models/proper_equivariance.py`` + - L575: `### `models/sparse_embedding.py`` + - L590: `### `models/spectral_conv.py`` + - L605: `### `models/topological.py`` + - L621: `### `models/ttt_layer.py`` + - L635: `### `models/uncertainty.py`` + - L650: `### `models/vjepa/flow_matching.py`` + - L669: `### `models/vjepa/gaussian_splatting.py`` + - L682: `### `models/vjepa/layers.py`` + - L703: `### `models/vjepa/losses.py`` + - L713: `### `models/vjepa/memory.py`` + - L742: `### `models/vjepa/physics_engine.py`` + - L756: `### `models/vjepa/planning.py`` + - L782: `### `models/vjepa/predictor.py`` + - L793: `### `models/vjepa/symplectic_integrator.py`` + - L807: `### `models/vjepa/utils.py`` + - L817: `### `models/vjepa/vit.py`` + - L835: `### `models/vjepa/vjepa_model.py`` + - L847: `### `pretrain.py`` + - L870: `### `puzzle_dataset.py`` + - L888: `### `puzzle_visualizer.html`` + - L896: `### `requirements.txt`` + - L904: `### `utils/functions.py`` + - L914: `### `vjepa_train.py`` + +### `LICENSE` + +- Type: text +- Size: 11357 bytes +- Lines: 1-202 + +- First non-empty content (anchor): L2: `Apache License` + +### `README.md` + +- Type: text +- Size: 3507 bytes +- Lines: 1-51 + +- Address anchors: + - L1: `# Hierarchical Reasoning Model - V-JEPA Integration (AGI Scale)` + - L5: `## Core Vision` + - L8: `## Key Architectural Pillars` + - L10: `### 1. Vision Encoder (The Eyes)` + - L15: `### 2. Physical Relativity (Lie Group Equivariance)` + - L18: `### 3. Continuous-Time Brain (Hamiltonian Neural ODEs & Predictive Coding)` + - L24: `### 4. Light & Shadow Intuition (Neural Radiance Latents)` + - L27: `### 5. Latent Planning (The Imagination)` + - L31: `### 6. Advanced Training Engine` + - L35: `## Getting Started` + - L37: `### Configuration` + - L40: `### Training` + - L43: `# or` + - L47: `## Future Multimodal Grounding` + +### `arc_eval.ipynb` + +- Type: text +- Size: 9317 bytes +- Lines: 1-252 + +- First non-empty content (anchor): L1: `{` + +### `assets/hrm.png` + +- Type: binary/non-UTF8 +- Size: 99852 bytes + +### `assets/npyjs.js` + +- Type: text +- Size: 5216 bytes +- Lines: 1-176 + +- First non-empty content (anchor): L1: `class npyjs {` + +### `config/arch/hrm_v1.yaml` + +- Type: text +- Size: 349 bytes +- Lines: 1-21 + +- Address anchors: + - L1: `name` + - L2: `loss` + - L3: `name` + - L4: `loss_type` + - L6: `halt_exploration_prob` + - L7: `halt_max_steps` + - L9: `H_cycles` + - L10: `L_cycles` + - L12: `H_layers` + - L13: `L_layers` + - L15: `hidden_size` + - L16: `num_heads` + - L17: `expansion` + - L19: `puzzle_emb_ndim` + - L21: `pos_encodings` + +### `config/cfg_pretrain.yaml` + +- Type: text +- Size: 492 bytes +- Lines: 1-31 + +- Address anchors: + - L3: `defaults` + - L7: `hydra` + - L8: `output_subdir` + - L11: `data_path` + - L14: `global_batch_size` + - L16: `epochs` + - L17: `eval_interval` + - L18: `checkpoint_every_eval` + - L20: `lr` + - L21: `lr_min_ratio` + - L22: `lr_warmup_steps` + - L25: `beta1` + - L26: `beta2` + - L27: `weight_decay` + - L28: `puzzle_emb_weight_decay` + - L31: `puzzle_emb_lr` + +### `config/vjepa_10b.yaml` + +- Type: text +- Size: 1709 bytes +- Lines: 1-85 + +- Address anchors: + - L2: `encoder` + - L3: `img_size` + - L4: `patch_size` + - L5: `in_chans` + - L6: `embed_dim` + - L7: `depth` + - L8: `num_heads` + - L9: `expansion` + - L10: `max_t` + - L11: `max_h` + - L12: `max_w` + - L14: `predictor` + - L16: `H_cycles` + - L17: `L_cycles` + - L18: `H_layers` + - L19: `L_layers` + - L20: `hidden_size` + - L21: `expansion` + - L22: `num_heads` + - L23: `pos_encodings` + - L24: `halt_max_steps` + - L25: `halt_exploration_prob` + - L26: `forward_dtype` + - L29: `use_gaussian_splatting` + - L30: `num_gaussians` + - L33: `use_flow_matching` + - L36: `use_symplectic` + - L39: `use_ttt` + - L40: `ttt_inner_lr` + - L41: `ttt_inner_steps` + - L44: `adaptive_depth` + - L45: `confidence_threshold` + - L46: `uncertainty_threshold` + - L49: `use_multimodal` + - L50: `audio_input_dim` + - L51: `tactile_input_dim` + - L54: `use_information_bottleneck` + - L55: `ib_beta` + - L57: `use_spectral_conv` + - L58: `spectral_num_filters` + - L59: `spectral_polynomial_order` + - L61: `use_topology` + - L62: `topology_filtration_steps` + - L64: `use_proper_se3` + - L65: `se3_num_frequencies` + - L66: `se3_l_max` + - L68: `use_uncertainty` + - L69: `uncertainty_mc_samples` + - L70: `uncertainty_dropout` + - L72: `training` + - L73: `batch_size` + - L74: `global_batch_size` + - L75: `lr` + - L76: `ema_momentum` + - L77: `sim_coeff` + - L78: `std_coeff` + - L79: `cov_coeff` + - L80: `weight_decay` + - L83: `optimizer` + - L84: `muon_momentum` + - L85: `muon_ns_steps` + +### `config/vjepa_micro.yaml` + +- Type: text +- Size: 555 bytes +- Lines: 1-34 + +- Address anchors: + - L2: `encoder` + - L3: `img_size` + - L4: `patch_size` + - L5: `in_chans` + - L6: `embed_dim` + - L7: `depth` + - L8: `num_heads` + - L9: `expansion` + - L10: `max_t` + - L11: `max_h` + - L12: `max_w` + - L14: `predictor` + - L15: `H_cycles` + - L16: `L_cycles` + - L17: `H_layers` + - L18: `L_layers` + - L19: `hidden_size` + - L20: `expansion` + - L21: `num_heads` + - L22: `pos_encodings` + - L23: `halt_max_steps` + - L24: `halt_exploration_prob` + - L25: `forward_dtype` + - L27: `training` + - L28: `batch_size` + - L29: `global_batch_size` + - L30: `lr` + - L31: `ema_momentum` + - L32: `sim_coeff` + - L33: `std_coeff` + - L34: `cov_coeff` + +### `dataset/build_arc_dataset.py` + +- Type: text +- Size: 10084 bytes +- Lines: 1-291 + +- Address anchors: + - L19: `class DataProcessConfig` + - L37: `class ARCPuzzle` + - L43: `def arc_grid_to_np` + - L54: `def np_grid_to_seq_translational_augment` + - L81: `def puzzle_hash` + - L83: `def _grid_hash` + - L98: `def convert_single_arc_puzzle` + - L122: `def _map_grid` + - L148: `def load_puzzles_arcagi` + - L184: `def convert_dataset` + - L286: `def main` + +### `dataset/build_maze_dataset.py` + +- Type: text +- Size: 4461 bytes +- Lines: 1-142 + +- Address anchors: + - L22: `class DataProcessConfig` + - L30: `def convert_subset` + - L89: `def _seq_to_numpy` + - L136: `def preprocess_data` + +### `dataset/common.py` + +- Type: text +- Size: 1381 bytes +- Lines: 1-51 + +- Address anchors: + - L12: `class PuzzleDatasetMetadata` + - L27: `def dihedral_transform` + - L50: `def inverse_dihedral_transform` + +### `dataset/generate_dummy_data.py` + +- Type: text +- Size: 860 bytes +- Lines: 1-29 + +- Address anchors: + - L5: `def generate_dummy_video` + +### `dataset/video_dataset.py` + +- Type: text +- Size: 4426 bytes +- Lines: 1-108 + +- Address anchors: + - L8: `class AdvancedVideoDataset` + - L13: `def __init__` + - L33: `def _get_video_stream` + - L42: `def _generate_3d_block_mask` + - L66: `def __iter__` + - L106: `def get_dataloader` + +### `evaluate.py` + +- Type: text +- Size: 2490 bytes +- Lines: 1-68 + +- Address anchors: + - L13: `class EvalConfig` + - L19: `def launch` + +### `models/adaptive_depth.py` + +- Type: text +- Size: 6649 bytes +- Lines: 1-199 + +- Address anchors: + - L24: `class AdaptiveDepthController` + - L41: `def __init__` + - L54: `def should_continue` + - L116: `class AdaptiveDepthWrapper` + - L131: `def __init__` + - L146: `def forward` + +### `models/common.py` + +- Type: text +- Size: 1216 bytes +- Lines: 1-32 + +- Address anchors: + - L7: `def trunc_normal_init_` + +### `models/hrm/hrm_act_v1.py` + +- Type: text +- Size: 12161 bytes +- Lines: 1-283 + +- Address anchors: + - L16: `class HierarchicalReasoningModel_ACTV1InnerCarry` + - L22: `class HierarchicalReasoningModel_ACTV1Carry` + - L31: `class HierarchicalReasoningModel_ACTV1Config` + - L60: `class HierarchicalReasoningModel_ACTV1Block` + - L61: `def __init__` + - L77: `def forward` + - L86: `class HierarchicalReasoningModel_ACTV1ReasoningModule` + - L87: `def __init__` + - L92: `def forward` + - L102: `class HierarchicalReasoningModel_ACTV1_Inner` + - L103: `def __init__` + - L146: `def _input_embeddings` + - L168: `def empty_carry` + - L174: `def reset_carry` + - L180: `def forward` + - L216: `class HierarchicalReasoningModel_ACTV1` + - L219: `def __init__` + - L225: `def puzzle_emb` + - L228: `def initial_carry` + - L240: `def forward` + +### `models/hybrid_ssm.py` + +- Type: text +- Size: 7530 bytes +- Lines: 1-228 + +- Address anchors: + - L26: `class SelectiveSSM` + - L41: `def __init__` + - L85: `def forward` + - L141: `class HybridSSMAttentionBlock` + - L161: `def __init__` + - L199: `def forward` + +### `models/information_bottleneck.py` + +- Type: text +- Size: 7603 bytes +- Lines: 1-227 + +- Address anchors: + - L28: `class VariationalInformationBottleneck` + - L42: `def __init__` + - L71: `def _reparameterize` + - L90: `def forward` + - L135: `class InformationBottleneckAttention` + - L150: `def __init__` + - L179: `def forward` + +### `models/layers.py` + +- Type: text +- Size: 6156 bytes +- Lines: 1-167 + +- Address anchors: + - L13: `def flash_attn_func` + - L29: `def _find_multiple` + - L33: `def rotate_half` + - L40: `def apply_rotary_pos_emb` + - L53: `class CastedLinear` + - L54: `def __init__` + - L68: `def forward` + - L72: `class CastedEmbedding` + - L73: `def __init__` + - L86: `def forward` + - L90: `class RotaryEmbedding` + - L91: `def __init__` + - L104: `def forward` + - L108: `class Attention` + - L109: `def __init__` + - L122: `def forward` + - L148: `class SwiGLU` + - L149: `def __init__` + - L156: `def forward` + - L161: `def rms_norm` + +### `models/losses.py` + +- Type: text +- Size: 3804 bytes +- Lines: 1-101 + +- Address anchors: + - L11: `def s` + - L19: `def log_stablemax` + - L24: `def stablemax_cross_entropy` + - L34: `def softmax_cross_entropy` + - L40: `class ACTLossHead` + - L41: `def __init__` + - L46: `def initial_carry` + - L49: `def forward` + +### `models/multimodal_grounding.py` + +- Type: text +- Size: 8774 bytes +- Lines: 1-258 + +- Address anchors: + - L24: `class ModalityEncoder` + - L37: `def __init__` + - L60: `def forward` + - L88: `class CrossModalAttention` + - L101: `def __init__` + - L122: `def forward` + - L158: `class MultiModalGrounding` + - L180: `def __init__` + - L213: `def forward` + +### `models/muon_optimizer.py` + +- Type: text +- Size: 6664 bytes +- Lines: 1-191 + +- Address anchors: + - L26: `class Muon` + - L44: `def __init__` + - L67: `def step` + - L137: `def _newton_schulz_orthogonalize` + - L170: `def _distributed_allreduce_grads` + +### `models/proper_equivariance.py` + +- Type: text +- Size: 11960 bytes +- Lines: 1-335 + +- Address anchors: + - L27: `class SO3Rotation` + - L37: `def axis_angle_to_matrix` + - L74: `def matrix_to_quaternion` + - L93: `class WignerDMatrices` + - L107: `def wigner_d_small` + - L170: `def rotation_to_euler` + - L187: `class ProperSE3EquivariantLayer` + - L204: `def __init__` + - L243: `def _positional_encoding` + - L259: `def forward` + +### `models/sparse_embedding.py` + +- Type: text +- Size: 4366 bytes +- Lines: 1-132 + +- Address anchors: + - L11: `class CastedSparseEmbedding` + - L12: `def __init__` + - L28: `def forward` + - L41: `class CastedSparseEmbeddingSignSGD_Distributed` + - L42: `def __init__` + - L63: `def step` + - L98: `def _sparse_emb_signsgd_dist` + +### `models/spectral_conv.py` + +- Type: text +- Size: 6117 bytes +- Lines: 1-193 + +- Address anchors: + - L25: `class GraphLaplacian` + - L34: `def __init__` + - L38: `def forward` + - L74: `class SpectralGraphConv` + - L88: `def __init__` + - L120: `def _chebyshev_polynomials` + - L159: `def forward` + +### `models/topological.py` + +- Type: text +- Size: 7232 bytes +- Lines: 1-215 + +- Address anchors: + - L25: `class DifferentiableBettiNumbers` + - L41: `def __init__` + - L68: `def _compute_distance_matrix` + - L74: `def _soft_threshold` + - L87: `def forward` + - L156: `class TopologicalAwareness` + - L171: `def __init__` + - L191: `def forward` + +### `models/ttt_layer.py` + +- Type: text +- Size: 7388 bytes +- Lines: 1-206 + +- Address anchors: + - L25: `class TTTLinear` + - L45: `def __init__` + - L74: `def forward` + - L148: `class TTTLinearWithAttention` + - L163: `def __init__` + - L191: `def forward` + +### `models/uncertainty.py` + +- Type: text +- Size: 7009 bytes +- Lines: 1-206 + +- Address anchors: + - L24: `class VariationalLinear` + - L41: `def __init__` + - L61: `def forward` + - L78: `def kl_divergence` + - L106: `class UncertaintyQuantification` + - L121: `def __init__` + - L151: `def forward` + +### `models/vjepa/flow_matching.py` + +- Type: text +- Size: 9015 bytes +- Lines: 1-281 + +- Address anchors: + - L28: `class SinusoidalTimeEmbedding` + - L31: `def __init__` + - L35: `def forward` + - L53: `class VelocityField` + - L61: `def __init__` + - L89: `def forward` + - L125: `class ConditionalFlowMatching` + - L150: `def __init__` + - L165: `def forward` + - L211: `def sample` + - L257: `def sample_rectified` + +### `models/vjepa/gaussian_splatting.py` + +- Type: text +- Size: 7036 bytes +- Lines: 1-203 + +- Address anchors: + - L28: `class LatentGaussianSplatting` + - L40: `def __init__` + - L75: `def _parse_gaussians` + - L118: `def _quaternion_to_rotation_matrix` + - L146: `def forward` + +### `models/vjepa/layers.py` + +- Type: text +- Size: 6641 bytes +- Lines: 1-157 + +- Address anchors: + - L8: `class LieGroupEquivariantLayer` + - L14: `def __init__` + - L31: `def forward` + - L50: `class LatentRayMarcher` + - L56: `def __init__` + - L71: `def forward` + - L113: `def apply_rotary_pos_emb_3d` + - L114: `def rotate_half` + - L128: `class RotaryEmbedding3D` + - L129: `def __init__` + - L141: `def _get_freqs` + - L147: `def _build_cache` + - L152: `def forward` + +### `models/vjepa/losses.py` + +- Type: text +- Size: 977 bytes +- Lines: 1-31 + +- Address anchors: + - L4: `def vicreg_loss` + - L22: `def covariance_loss` + +### `models/vjepa/memory.py` + +- Type: text +- Size: 13196 bytes +- Lines: 1-392 + +- Address anchors: + - L28: `class ResonatorNetwork` + - L49: `def __init__` + - L66: `def set_cleanup_memory` + - L70: `def cleanup` + - L96: `def resonator_step` + - L119: `def _unbind` + - L125: `def forward` + - L178: `class HolographicMemory` + - L198: `def __init__` + - L219: `def _bind_hrr` + - L225: `def _unbind_hrr` + - L231: `def _bind_fhrr` + - L243: `def _unbind_fhrr` + - L249: `def bind` + - L255: `def unbind` + - L261: `def superpose` + - L284: `def forward` + - L312: `def retrieve` + - L327: `def retrieve_with_cleanup` + - L348: `def multi_retrieve` + - L385: `def set_cleanup_memory` + +### `models/vjepa/physics_engine.py` + +- Type: text +- Size: 4774 bytes +- Lines: 1-107 + +- Address anchors: + - L6: `class HRMPhysicsODE` + - L12: `def __init__` + - L30: `def forward` + - L61: `class ContinuousTimeHRM` + - L67: `def __init__` + - L71: `def forward` + +### `models/vjepa/planning.py` + +- Type: text +- Size: 15338 bytes +- Lines: 1-458 + +- Address anchors: + - L28: `class MCTSNode` + - L49: `def __init__` + - L67: `def mean_value` + - L74: `def is_expanded` + - L79: `def effective_visits` + - L83: `def puct_score` + - L120: `class MCTS` + - L139: `def __init__` + - L160: `def _imagine_future` + - L192: `def _select` + - L222: `def _expand` + - L280: `def _backpropagate` + - L306: `def _get_action_probabilities` + - L340: `def plan` + - L393: `class LatentPlannerMCTS` + - L408: `def __init__` + - L424: `def plan` + - L442: `def plan_with_uncertainty` + +### `models/vjepa/predictor.py` + +- Type: text +- Size: 10722 bytes +- Lines: 1-260 + +- Address anchors: + - L22: `class VJEPAPredictorInner` + - L28: `def __init__` + - L156: `def forward` + +### `models/vjepa/symplectic_integrator.py` + +- Type: text +- Size: 4606 bytes +- Lines: 1-137 + +- Address anchors: + - L31: `class SymplecticEulerIntegrator` + - L48: `def __init__` + - L67: `def set_action` + - L71: `def hamiltonian` + - L87: `def forward` + - L126: `def compute_energy` + +### `models/vjepa/utils.py` + +- Type: text +- Size: 1478 bytes +- Lines: 1-44 + +- Address anchors: + - L3: `def get_block_mask` + - L21: `def apply_mask` + +### `models/vjepa/vit.py` + +- Type: text +- Size: 2960 bytes +- Lines: 1-94 + +- Address anchors: + - L9: `class PatchEmbed3D` + - L15: `def __init__` + - L20: `def forward` + - L30: `class VisionTransformerBlock` + - L31: `def __init__` + - L43: `def _forward_inner` + - L48: `def forward` + - L54: `class VisionEncoder` + - L55: `def __init__` + - L86: `def forward` + +### `models/vjepa/vjepa_model.py` + +- Type: text +- Size: 5221 bytes +- Lines: 1-127 + +- Address anchors: + - L12: `class VJEPA` + - L23: `def __init__` + - L72: `def update_target_encoder` + - L77: `def forward` + +### `pretrain.py` + +- Type: text +- Size: 15607 bytes +- Lines: 1-453 + +- Address anchors: + - L26: `class LossConfig` + - L32: `class ArchConfig` + - L39: `class PretrainConfig` + - L74: `class TrainState` + - L84: `def create_dataloader` + - L108: `def create_model` + - L162: `def cosine_schedule_with_warmup_lr_lambda` + - L172: `def init_train_state` + - L190: `def save_train_state` + - L199: `def compute_lr` + - L209: `def train_batch` + - L266: `def evaluate` + - L333: `def save_code_and_config` + - L359: `def load_synced_config` + - L381: `def launch` + +### `puzzle_dataset.py` + +- Type: text +- Size: 7980 bytes +- Lines: 1-199 + +- Address anchors: + - L14: `def _sample_batch` + - L41: `class PuzzleDatasetConfig` + - L53: `class PuzzleDataset` + - L54: `def __init__` + - L68: `def _load_metadata` + - L72: `def _lazy_load_dataset` + - L95: `def _collate_batch` + - L118: `def _iter_test` + - L151: `def _iter_train` + - L189: `def __iter__` + +### `puzzle_visualizer.html` + +- Type: text +- Size: 14119 bytes +- Lines: 1-426 + +- First non-empty content (anchor): L1: `` + +### `requirements.txt` + +- Type: text +- Size: 270 bytes +- Lines: 1-16 + +- First non-empty content (anchor): L1: `torch` + +### `utils/functions.py` + +- Type: text +- Size: 516 bytes +- Lines: 1-19 + +- Address anchors: + - L5: `def load_model_class` + - L15: `def get_model_source_path` + +### `vjepa_train.py` + +- Type: text +- Size: 7591 bytes +- Lines: 1-216 + +- Address anchors: + - L18: `def build_optimizer` + - L92: `class CombinedOptimizer` + - L97: `def __init__` + - L100: `def zero_grad` + - L104: `def step` + - L109: `def param_groups` + - L116: `def train` diff --git a/README.md b/README.md index a6b77cc2..96243230 100644 --- a/README.md +++ b/README.md @@ -1,49 +1,104 @@ -# Hierarchical Reasoning Model - V-JEPA Integration (AGI Scale) +# Hierarchical Reasoning Model (HRM) + V-JEPA -This repository hosts the advanced integration of the Hierarchical Reasoning Model (HRM) with the Video Joint-Embedding Predictive Architecture (V-JEPA), scaled to a **10 Billion parameter** architecture for deep physical world understanding. +An advanced research codebase for **continuous-time world modeling** from video, combining: +- **HRM** (hierarchical latent reasoning), +- **V-JEPA** (self-supervised predictive representation learning), +- and mathematically grounded modules for dynamics, geometry, planning, and uncertainty. -## Core Vision -Transitioning from discrete puzzle-solving to continuous-time, latent-space reasoning. The model is designed to learn **intuitive physics** (depth, shadows, object permanence, continuity) autonomously from raw video data, achieving a human-like understanding of the physical world. +--- + +## Purpose +Build a practical foundation for models that can: +1. Learn physical regularities directly from raw video, +2. Reason over future latent trajectories, +3. Support intervention-aware planning in latent space. + +This repository transitions from discrete puzzle-style reasoning to **continuous latent dynamics** with explicit architectural support for long-horizon prediction. + +## Vision +Our vision is a model that develops robust **intuitive physics** (e.g., continuity, object permanence, motion consistency, and causal effects of actions) by combining representation learning, geometric priors, and dynamics-aware objectives. + +## Goal +Deliver a scalable and analyzable training stack that can evolve from micro-scale experiments to large configurations (including 10B-class settings) while preserving: +- modularity, +- mathematical interpretability, +- and reproducible workflow. + +--- + +## Technical Architecture (Concept Map) -## Key Architectural Pillars +### 1) Spatio-Temporal Representation (Vision Encoder) +- **3D patch embedding** over `(T, H, W)` video volumes. +- **3D-RoPE** positional encoding in time-height-width coordinates. +- ViT-style latent tokenization for downstream predictive modeling. -### 1. Vision Encoder (The Eyes) -* **3D Patch Embedding**: Processes video clips as spatio-temporal volumes. -* **3D-RoPE**: 3D Rotary Positional Embeddings that natively encode Time, Height, and Width coordinates. -* **10B Scale ViT**: A massive Vision Transformer designed to capture high-density visual information. +### 2) Geometric Inductive Biases +- **Lie-group / equivariance-oriented layers** for transformation-aware latent features. +- **Stiefel-manifold style orthogonality constraints/projections** to stabilize relational geometry. +- **Proper SE(3)-inspired processing** for physically meaningful transformations. -### 2. Physical Relativity (Lie Group Equivariance) -* **Stiefel Manifold Projections**: Implements $O(D)$ complexity equivariant transformations using Cayley transforms, ensuring physical laws are relative across 10B parameter manifolds. +### 3) Continuous-Time Latent Dynamics +- **Hamiltonian-style latent dynamics** components. +- **Neural ODE adjoint** pathway (`torchdiffeq`) for memory-efficient continuous-time learning. +- **Symplectic integration path** for structure-preserving latent evolution at inference-style rollout. -### 3. Continuous-Time Brain (Hamiltonian Neural ODEs & Predictive Coding) -* **Symplectic Physics Engine (HNN)**: Uses **Hamiltonian Neural Networks** to compute the continuous-time dynamics ($dq/dt$, $dp/dt$). This guarantees absolute energy conservation and strict adherence to classical mechanics within the latent imagination space. -* **Adjoint Neural ODEs**: Uses the **Neural ODE Adjoint Method** (`torchdiffeq`) for constant-memory backpropagation, enabling infinite-depth continuous reasoning at 10B scale. -* **Top-Down Predictive Coding**: A hierarchical "handshake" where High-Level planning ($z_H$) suppresses error signals from Low-Level sensors ($z_L$), mimicking the human visual cortex. -* **Holographic Memory**: Vector Symbolic Architecture (VSA) based memory that binds and stores complex physical experiences into dense, high-dimensional holographic states. +### 4) Hierarchical Predictive Reasoning +- High/Low cycle interaction (`H_cycles`, `L_cycles`) for iterative latent refinement. +- Predictive coding flavor with top-down influence and bottom-up correction pressure. +- Adaptive compute hooks (e.g., ACT/depth controller) for confidence-aware depth. -### 4. Light & Shadow Intuition (Neural Radiance Latents) -* **Volumetric Ray-Marching**: Treats the latent space as a **Differentiable Continuous Radiance Field (NeRF)**. The model "traces" light and reflections through its imagined 3D manifold. +### 5) World Rendering and Latent Scene Composition +- **Latent Gaussian Splatting** path for explicit scene primitive aggregation. +- NeRF-inspired latent rendering concepts for geometry/appearance reasoning. -### 5. Latent Planning (The Imagination) -* **Latent MCTS**: Monte Carlo Tree Search operating entirely in latent space, allowing the model to "imagine" and evaluate thousands of future physical outcomes. -* **Action Conditioning**: Future states predicted conditioned on specific physical actions/interventions. +### 6) Latent Planning & Decision Support +- **Latent MCTS** module for action-conditioned future evaluation. +- Value estimation head for ranking latent future states. -### 6. Advanced Training Engine -* **VICReg Objective**: Variance-Covariance regularization to prevent representation collapse. -* **3D Block Masking**: Spatio-temporal masking that forces the model to infer large missing segments of the world. +### 7) Multi-Modal and Robustness Extensions +- Hooks for **audio** and **tactile/proprioceptive** grounding. +- **Uncertainty estimation**, **information bottleneck**, **topology-aware**, and **spectral** auxiliary modules. -## Getting Started +### 8) Training Stack +- **VICReg** objective (invariance + variance/covariance regularization). +- Spatio-temporal masking regime. +- Optimizer backends: **AdamW**, **Muon**, or **Hybrid Muon+AdamW**. +- EMA target encoder for stable JEPA-style targets. -### Configuration -Adjust the 10B parameter specs in `config/vjepa_10b.yaml`. +--- + +## Repository Workflow -### Training +### Configurations +- **Micro / local iteration**: `config/vjepa_micro.yaml` +- **Large-scale profile**: `config/vjepa_10b.yaml` + +### Training Entrypoint ```bash -python vjepa_train.py +python vjepa_train.py --config config/vjepa_micro.yaml +# or +python vjepa_train.py --config config/vjepa_10b.yaml ``` -## Future Multimodal Grounding -The architecture is designed to be modality-agnostic, with hooks ready for future integration of **Audio** and **Tactile (Proprioceptive)** data. +`vjepa_train.py` accepts `--config` and loads runtime behavior from YAML. +`training.epochs` can be set in YAML (defaults to `100` if omitted). + +### Practical Notes +- Place video files in `data/` for training. +- If `data/` is absent, the script attempts to create it and generate a small synthetic test video via `ffmpeg`. +- For phase-1 rigorous world-model checks, run: + `python evaluate_world_model.py --config config/vjepa_micro.yaml --seed 42` + (saves JSON manifests in `eval_runs/`). --- -*This project is dedicated to pushing the boundaries of artificial general intelligence through the lens of hierarchical physical reasoning.* + +## Roadmap Direction +- Stronger experiment tracking and benchmark reports. +- Expanded multimodal pretraining/evaluation. +- Systematic ablations on dynamics engines (ODE vs. flow matching vs. symplectic rollout). +- Better reproducibility packaging for large-scale distributed runs. + +--- + +This project is focused on pushing **hierarchical physical reasoning** toward robust, scalable world models with clear technical structure and research extensibility. diff --git a/docs/FRONTIER_GAP_ANALYSIS.md b/docs/FRONTIER_GAP_ANALYSIS.md new file mode 100644 index 00000000..c3626aeb --- /dev/null +++ b/docs/FRONTIER_GAP_ANALYSIS.md @@ -0,0 +1,35 @@ +# Frontier Capability Gap Analysis and Implementation Plan + +## Scope +This document compares: +1. Existing repository capabilities. +2. State-of-the-art (frontier labs + top academic trends) capability expectations. +3. Immediate implementation decisions. + +## Comparison Matrix + +| Area | Existing in repo | Frontier expectation | Gap | Action | +|---|---|---|---|---| +| Latent world modeling | V-JEPA-style masked latent prediction and EMA target path | Long-horizon stable latent rollouts with robust eval | Partial | Add dedicated world-model evaluation harness (next step) | +| Continuous-time dynamics | Hamiltonian/ODE/symplectic modules present | Quantitative invariance + long-horizon stability metrics | Missing benchmarks | Add metrics + ablations (next step) | +| Latent planning (MCTS) | MCTS scaffold existed with placeholder action priors | Policy-informed planning priors and uncertainty-aware scoring | Prior quality gap | Implemented policy-query action priors in MCTS | +| Uncertainty | Uncertainty module present | Planning/calibration integration | Partial | Integrate uncertainty into planner scoring (planned) | +| Multimodal grounding | Audio/tactile hooks + cross-modal attention | Curriculum and modality-drop robustness metrics | Partial | Add modality-drop ablations (planned) | +| Reproducible workflow | Configurable training entrypoint | Evaluation protocol + run manifests + acceptance gates | Partial | Add benchmark specs and run manifests (planned) | + +## Implemented in this change + +### 1) MCTS action prior upgrade +Previously, planning used a placeholder prior logits tensor. We replaced that with a learned action-prior mechanism: +- A new `policy_query_head` in `VJEPA` maps latent state to an action-space query vector. +- MCTS computes action priors from dot-product similarity between candidate actions and the learned query. +- This upgrades search from uniform/placeholder priors to model-informed priors. + +## Why this is prioritized first +Planning quality depends heavily on action priors. Replacing placeholder priors is a high-leverage improvement that directly improves practical controllable rollout search. + +## Next technical steps (ordered) +1. Add `evaluate_world_model.py` with rollout drift, action-consistency, and calibration metrics. +2. Wire uncertainty estimates into PUCT scoring (risk-aware planning). +3. Add ablation configs for dynamics engines and multimodal drop robustness. +4. Define acceptance thresholds for promotion of each advanced module. diff --git a/docs/RIGOROUS_DEVELOPMENT_PROTOCOL.md b/docs/RIGOROUS_DEVELOPMENT_PROTOCOL.md new file mode 100644 index 00000000..1c7c8fbe --- /dev/null +++ b/docs/RIGOROUS_DEVELOPMENT_PROTOCOL.md @@ -0,0 +1,33 @@ +# Rigorous Development Protocol (Phase-1) + +This protocol defines minimum rigor gates for model and planner changes. + +## Gate A — Sanity / Determinism +1. Python compile sanity for repository modules. +2. Seeded execution for evaluation scripts. +3. Basic tensor-shape and NaN safety in smoke runs. + +## Gate B — World-model Metrics +Run: + +```bash +python evaluate_world_model.py --config config/vjepa_micro.yaml --seed 42 +``` + +Required outputs: +- `rollout_drift_l2` +- `trajectory_divergence_l2` +- `max_action_prior` +- `action_prior_entropy` + +All metrics are persisted as a JSON run manifest in `eval_runs/`. + +## Gate C — Change Promotion +Any feature PR must include: +1. Before/after metric table (same config + seed). +2. Ablation switch (enable/disable path). +3. Short rationale for metric movement. + +## Notes +- This is a phase-1 lightweight protocol and will be extended with calibration + and long-horizon benchmark suites in subsequent iterations. diff --git a/evaluate_world_model.py b/evaluate_world_model.py new file mode 100644 index 00000000..ac981b1d --- /dev/null +++ b/evaluate_world_model.py @@ -0,0 +1,166 @@ +""" +World-model evaluation harness for V-JEPA/HRM. + +Phase-1 rigorous evaluation focuses on: + 1) rollout drift (latent-state drift over repeated imagination) + 2) action consistency (whether action-conditioned futures are distinct) + 3) calibration-oriented proxy metrics (uncertainty magnitude + confidence proxy) + +This script is intentionally lightweight and self-contained so it can be +used early in development before full benchmark infrastructure is added. +""" + +import argparse +import json +import os +import random +from dataclasses import dataclass, asdict +from datetime import datetime, timezone +from typing import Dict, List + +import torch +import yaml + +from models.vjepa.vjepa_model import VJEPA + + +def set_seed(seed: int) -> None: + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +@dataclass +class EvalManifest: + timestamp_utc: str + commit: str + config_path: str + seed: int + device: str + batch_size: int + rollout_steps: int + num_actions: int + metrics: Dict[str, float] + + +def get_commit_hash(default: str = "unknown") -> str: + head = os.path.join(".git", "HEAD") + if not os.path.exists(head): + return default + try: + with open(head, "r", encoding="utf-8") as f: + ref = f.read().strip() + if ref.startswith("ref: "): + ref_path = os.path.join(".git", ref.split(" ", 1)[1]) + with open(ref_path, "r", encoding="utf-8") as f: + return f.read().strip()[:12] + return ref[:12] + except Exception: + return default + + +def latent_rollout( + model: VJEPA, + state: torch.Tensor, + actions: List[torch.Tensor], +) -> List[torch.Tensor]: + states = [state] + cur = state + for action in actions: + dt = torch.ones(cur.shape[0], device=cur.device) + cur = model.predictor.physics_engine(cur, dt, action=action) + states.append(cur) + return states + + +def evaluate_metrics(model: VJEPA, device: torch.device, rollout_steps: int, num_actions: int) -> Dict[str, float]: + model.eval() + with torch.no_grad(): + dim = model.value_head[0].in_features + bs = 1 + seq_len = 16 + z0 = torch.randn(bs, seq_len, dim, device=device) + + actions_a = [torch.randn(bs, seq_len, 128, device=device) for _ in range(rollout_steps)] + actions_b = [torch.randn(bs, seq_len, 128, device=device) for _ in range(rollout_steps)] + + traj_a = latent_rollout(model, z0, actions_a) + traj_b = latent_rollout(model, z0, actions_b) + + # 1) rollout drift: average step-to-step displacement in a rollout + step_drifts = [] + for t in range(1, len(traj_a)): + step_drifts.append((traj_a[t] - traj_a[t - 1]).pow(2).mean().sqrt().item()) + rollout_drift = float(sum(step_drifts) / max(len(step_drifts), 1)) + + # 2) action consistency proxy: trajectories from distinct actions should diverge + trajectory_divergence = float((traj_a[-1] - traj_b[-1]).pow(2).mean().sqrt().item()) + + # 3) planner prior concentration (confidence proxy) + available_actions = torch.randn(num_actions, 128, device=device) + pooled = traj_a[-1].mean(dim=1) + query = model.policy_query_head(pooled).squeeze(0) + logits = torch.matmul(available_actions, query) + probs = torch.softmax(logits, dim=0) + max_prior = float(probs.max().item()) + prior_entropy = float(-(probs * (probs + 1e-9).log()).sum().item()) + + return { + "rollout_drift_l2": rollout_drift, + "trajectory_divergence_l2": trajectory_divergence, + "max_action_prior": max_prior, + "action_prior_entropy": prior_entropy, + } + + +def main() -> None: + parser = argparse.ArgumentParser(description="Evaluate V-JEPA/HRM world-model metrics") + parser.add_argument("--config", default="config/vjepa_micro.yaml") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--rollout-steps", type=int, default=8) + parser.add_argument("--num-actions", type=int, default=32) + parser.add_argument("--save-dir", default="eval_runs") + args = parser.parse_args() + + set_seed(args.seed) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + with open(args.config, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + + model = VJEPA( + config["encoder"], + config["predictor"], + config["training"]["ema_momentum"], + action_dim=128, + ).to(device) + + metrics = evaluate_metrics( + model=model, + device=device, + rollout_steps=args.rollout_steps, + num_actions=args.num_actions, + ) + + os.makedirs(args.save_dir, exist_ok=True) + manifest = EvalManifest( + timestamp_utc=datetime.now(timezone.utc).isoformat(), + commit=get_commit_hash(), + config_path=args.config, + seed=args.seed, + device=str(device), + batch_size=1, + rollout_steps=args.rollout_steps, + num_actions=args.num_actions, + metrics=metrics, + ) + + out_path = os.path.join(args.save_dir, f"world_model_eval_{manifest.commit}_{args.seed}.json") + with open(out_path, "w", encoding="utf-8") as f: + json.dump(asdict(manifest), f, indent=2) + + print("Evaluation complete.") + print(json.dumps(asdict(manifest), indent=2)) + + +if __name__ == "__main__": + main() diff --git a/models/vjepa/planning.py b/models/vjepa/planning.py index 50ab55f0..9b2f811c 100644 --- a/models/vjepa/planning.py +++ b/models/vjepa/planning.py @@ -172,7 +172,7 @@ def _imagine_future( Returns: next_state: (1, D) predicted next state. value: scalar value estimate of the next state. - policy_logits: (1, num_actions) action prior logits. + policy_query: (1, action_dim) action-prior query embedding. """ # Use the physics engine for dynamics prediction delta_t = torch.ones(state.shape[0], device=state.device) @@ -183,11 +183,11 @@ def _imagine_future( next_state.mean(dim=1) if next_state.ndim > 2 else next_state ).item() - # Estimate action priors (simple: use cosine similarity with available actions) - # In a full implementation, this would come from a policy network - policy_logits = torch.zeros(1, device=state.device) # placeholder + # Estimate action priors from a learned policy-query head. + pooled_next_state = next_state.mean(dim=1) if next_state.ndim > 2 else next_state + policy_query = self.model.policy_query_head(pooled_next_state) - return next_state, value, policy_logits + return next_state, value, policy_query def _select(self, node: MCTSNode) -> MCTSNode: """ @@ -223,7 +223,7 @@ def _expand( self, node: MCTSNode, available_actions: torch.Tensor, - policy_logits: Optional[torch.Tensor] = None, + policy_query: Optional[torch.Tensor] = None, ) -> float: """ Expand a node by creating children for available actions. @@ -234,7 +234,7 @@ def _expand( Args: node: the node to expand. available_actions: (num_actions, action_dim) action set. - policy_logits: (num_actions,) optional prior logits. + policy_query: (1, action_dim) optional action-prior query vector. Returns: Value estimate of the expanded node. @@ -248,9 +248,18 @@ def _expand( max_children = max(1, int(self.pw_c * (node.visits + 1) ** self.pw_alpha)) max_children = min(max_children, num_actions) - # Compute priors from policy logits - if policy_logits is not None and policy_logits.numel() > 0: - priors = F.softmax(policy_logits[:num_actions] / self.temperature, dim=0) + # Compute priors from policy query vector. + if policy_query is None and hasattr(self.model, "policy_query_head"): + with torch.no_grad(): + pooled_state = node.state.mean(dim=1) if node.state.ndim > 2 else node.state + policy_query = self.model.policy_query_head(pooled_state) + + if policy_query is not None and policy_query.numel() > 0: + # Similarity(action_i, query) -> prior logit + # available_actions: (num_actions, action_dim) + # policy_query: (1, action_dim) + logits = torch.matmul(available_actions, policy_query.squeeze(0)) + priors = F.softmax(logits / self.temperature, dim=0) else: # Uniform prior if no policy network priors = torch.ones(num_actions, device=available_actions.device) / num_actions diff --git a/models/vjepa/vjepa_model.py b/models/vjepa/vjepa_model.py index f11a6174..ad46d346 100644 --- a/models/vjepa/vjepa_model.py +++ b/models/vjepa/vjepa_model.py @@ -61,6 +61,15 @@ def __init__(self, nn.Linear(predictor_config["hidden_size"], 1) ) + # 5b. Policy query head for action-prior scoring in latent MCTS. + # Produces an action-space query vector that can be matched against + # candidate action vectors via dot-product similarity. + self.policy_query_head = nn.Sequential( + nn.Linear(predictor_config["hidden_size"], predictor_config["hidden_size"]), + nn.SiLU(), + nn.Linear(predictor_config["hidden_size"], action_dim) + ) + # 6. Adaptive depth controller for test-time compute scaling self.depth_controller = AdaptiveDepthController( max_depth=predictor_config.get("halt_max_steps", 8), diff --git a/vjepa_train.py b/vjepa_train.py index e90cd0c7..7fb9d7bc 100644 --- a/vjepa_train.py +++ b/vjepa_train.py @@ -1,4 +1,5 @@ import os +import argparse import yaml import torch from torch import nn @@ -169,7 +170,8 @@ def train(config_path="config/vjepa_micro.yaml"): # 5. Training Loop model.train() - for epoch in range(100): + epochs = int(config.get("training", {}).get("epochs", 100)) + for epoch in range(epochs): for i, batch in enumerate(dataloader): # Move batch to device batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} @@ -204,4 +206,11 @@ def train(config_path="config/vjepa_micro.yaml"): # wandb.log({"loss": loss.item()}) if __name__ == "__main__": - train() + parser = argparse.ArgumentParser(description="Train V-JEPA/HRM model") + parser.add_argument( + "--config", + default="config/vjepa_micro.yaml", + help="Path to YAML config file (e.g., config/vjepa_micro.yaml or config/vjepa_10b.yaml)", + ) + args = parser.parse_args() + train(args.config) From 07b8481c0ef681125539cfe147d73c74662daa55 Mon Sep 17 00:00:00 2001 From: jeevesh415 Date: Sat, 16 May 2026 21:39:12 +0530 Subject: [PATCH 2/7] Add perception evaluation track and harden training bootstrap fallback --- README.md | 9 +++ docs/HUMAN_VISION_EXECUTION_EVAL_SPEC.md | 48 +++++++++++++ evaluate_perception.py | 90 ++++++++++++++++++++++++ vjepa_train.py | 13 ++-- 4 files changed, 156 insertions(+), 4 deletions(-) create mode 100644 docs/HUMAN_VISION_EXECUTION_EVAL_SPEC.md create mode 100644 evaluate_perception.py diff --git a/README.md b/README.md index 96243230..d4676035 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,15 @@ python vjepa_train.py --config config/vjepa_10b.yaml - For phase-1 rigorous world-model checks, run: `python evaluate_world_model.py --config config/vjepa_micro.yaml --seed 42` (saves JSON manifests in `eval_runs/`). +- For perception robustness checks (color/shadow/noise/shift), run: + `python evaluate_perception.py --config config/vjepa_micro.yaml --seed 42` + (saves JSON manifests in `eval_runs/`). + +### Final Execution Checklist (Do This) +1. `python -m compileall -q .` +2. `python evaluate_world_model.py --config config/vjepa_micro.yaml --seed 42` +3. `python evaluate_perception.py --config config/vjepa_micro.yaml --seed 42` +4. `python vjepa_train.py --config config/vjepa_micro.yaml` (with real videos in `data/`, or with `ffmpeg` installed) --- diff --git a/docs/HUMAN_VISION_EXECUTION_EVAL_SPEC.md b/docs/HUMAN_VISION_EXECUTION_EVAL_SPEC.md new file mode 100644 index 00000000..21141f88 --- /dev/null +++ b/docs/HUMAN_VISION_EXECUTION_EVAL_SPEC.md @@ -0,0 +1,48 @@ +# Human-Vision + Execution Evaluation Spec (Initial) + +This document translates project purpose into executable evaluation tracks. + +## Purpose Alignment +Target capabilities: +- color and illumination robustness +- depth/geometry continuity +- shadow/reflectance stability +- action-conditioned future consistency +- long-horizon cognitive execution + +## Track A: Perception Robustness (implemented baseline) +Command: +```bash +python evaluate_perception.py --config config/vjepa_micro.yaml --seed 42 +``` + +Outputs: +- `color_jitter_latent_l2` +- `brightness_shadow_latent_l2` +- `gaussian_noise_latent_l2` +- `spatial_shift_latent_l2` + +Lower is better (more invariant latent representations). + +## Track B: World-Model Dynamics (implemented baseline) +Command: +```bash +python evaluate_world_model.py --config config/vjepa_micro.yaml --seed 42 +``` + +Outputs: +- `rollout_drift_l2` +- `trajectory_divergence_l2` +- `max_action_prior` +- `action_prior_entropy` + +## Track C: Execution/Cognition (next) +- goal-conditioned planning success@k +- action counterfactual consistency +- uncertainty-aware risk-return tradeoff + +## Promotion Rule (Phase-1/2) +A change is promoted only if: +1. no regression in compile/smoke execution, +2. no major degradation in Track A/B metrics for same seed/config, +3. rationale + ablation switch is documented. diff --git a/evaluate_perception.py b/evaluate_perception.py new file mode 100644 index 00000000..01c382a5 --- /dev/null +++ b/evaluate_perception.py @@ -0,0 +1,90 @@ +""" +Phase-2 perception robustness evaluation for HRM + V-JEPA. + +Focus: +- color robustness +- brightness/shadow robustness +- noise robustness +- geometric perturbation robustness + +This uses latent consistency between original and perturbed clips as an +early proxy for perceptual invariance before task-specific benchmarks. +""" + +import argparse +import json +import os +from datetime import datetime, timezone +from typing import Dict + +import torch +import yaml + +from models.vjepa.vjepa_model import VJEPA + + +def apply_perturbation(video: torch.Tensor, mode: str) -> torch.Tensor: + if mode == "color_jitter": + scale = torch.tensor([1.1, 0.9, 1.05], device=video.device).view(1, 1, 3, 1, 1) + return (video * scale).clamp(-3.0, 3.0) + if mode == "brightness_shadow": + return (video * 0.7).clamp(-3.0, 3.0) + if mode == "gaussian_noise": + return video + 0.05 * torch.randn_like(video) + if mode == "spatial_shift": + return torch.roll(video, shifts=2, dims=-1) + raise ValueError(f"Unknown perturbation mode: {mode}") + + +def latent_consistency(model: VJEPA, video: torch.Tensor, perturbed: torch.Tensor) -> float: + with torch.no_grad(): + z_ref = model.context_encoder(video) + z_alt = model.context_encoder(perturbed) + return float((z_ref - z_alt).pow(2).mean().sqrt().item()) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--config", default="config/vjepa_micro.yaml") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--save-dir", default="eval_runs") + args = parser.parse_args() + + torch.manual_seed(args.seed) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + with open(args.config, "r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) + + model = VJEPA( + cfg["encoder"], + cfg["predictor"], + cfg["training"]["ema_momentum"], + action_dim=128, + ).to(device).eval() + + # synthetic clip for deterministic smoke-evaluation + bs, t, c, h, w = 1, cfg["encoder"]["max_t"] * cfg["encoder"]["patch_size"][0], 3, cfg["encoder"]["img_size"], cfg["encoder"]["img_size"] + video = torch.randn(bs, t, c, h, w, device=device) + + metrics: Dict[str, float] = {} + for mode in ["color_jitter", "brightness_shadow", "gaussian_noise", "spatial_shift"]: + pert = apply_perturbation(video, mode) + metrics[f"{mode}_latent_l2"] = latent_consistency(model, video, pert) + + os.makedirs(args.save_dir, exist_ok=True) + out = { + "timestamp_utc": datetime.now(timezone.utc).isoformat(), + "config_path": args.config, + "seed": args.seed, + "device": str(device), + "metrics": metrics, + } + path = os.path.join(args.save_dir, f"perception_eval_seed_{args.seed}.json") + with open(path, "w", encoding="utf-8") as f: + json.dump(out, f, indent=2) + print(json.dumps(out, indent=2)) + + +if __name__ == "__main__": + main() + diff --git a/vjepa_train.py b/vjepa_train.py index 7fb9d7bc..927b87b5 100644 --- a/vjepa_train.py +++ b/vjepa_train.py @@ -1,5 +1,6 @@ import os import argparse +import shutil import yaml import torch from torch import nn @@ -149,10 +150,14 @@ def train(config_path="config/vjepa_micro.yaml"): os.makedirs(video_dir) print(f"Created directory {video_dir}. Please add videos here.") import subprocess - subprocess.run([ - 'ffmpeg', '-f', 'lavfi', '-i', 'testsrc=duration=5:size=224x224:rate=15', - os.path.join(video_dir, 'test_video.mp4'), '-y' - ], capture_output=True) + ffmpeg_bin = shutil.which("ffmpeg") + if ffmpeg_bin is None: + print("ffmpeg not found; skipping synthetic video generation. Add videos manually to data/.") + else: + subprocess.run([ + ffmpeg_bin, '-f', 'lavfi', '-i', 'testsrc=duration=5:size=224x224:rate=15', + os.path.join(video_dir, 'test_video.mp4'), '-y' + ], capture_output=True, check=False) video_paths = [os.path.join(video_dir, f) for f in os.listdir(video_dir) if f.endswith(('.mp4', '.avi', '.mov'))] if not video_paths: From 52e879b0ad9829b6958e247c80298738ee03e3bd Mon Sep 17 00:00:00 2001 From: jeevesh415 Date: Sun, 17 May 2026 16:02:13 +0530 Subject: [PATCH 3/7] Unify naming under Visual Execution Model framework --- README.md | 19 +++++++++---------- models/vjepa/vjepa_model.py | 7 ++++++- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index d4676035..1ed8f2a5 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,8 @@ -# Hierarchical Reasoning Model (HRM) + V-JEPA +# Visual Execution Model (VEM) -An advanced research codebase for **continuous-time world modeling** from video, combining: -- **HRM** (hierarchical latent reasoning), -- **V-JEPA** (self-supervised predictive representation learning), -- and mathematically grounded modules for dynamics, geometry, planning, and uncertainty. +A single integrated framework for **continuous-time world modeling** from video. + +Visual Execution Model (VEM) unifies hierarchical reasoning, predictive representation learning, dynamics, geometry, planning, uncertainty, and multimodal grounding inside one model stack (not separate models). --- @@ -13,7 +12,7 @@ Build a practical foundation for models that can: 2. Reason over future latent trajectories, 3. Support intervention-aware planning in latent space. -This repository transitions from discrete puzzle-style reasoning to **continuous latent dynamics** with explicit architectural support for long-horizon prediction. +This repository is organized as one unified model pipeline with scalable sizes and optional modules, so every capability is part of the same framework and execution graph. ## Vision Our vision is a model that develops robust **intuitive physics** (e.g., continuity, object permanence, motion consistency, and causal effects of actions) by combining representation learning, geometric priors, and dynamics-aware objectives. @@ -68,11 +67,11 @@ Deliver a scalable and analyzable training stack that can evolve from micro-scal --- -## Repository Workflow +## Repository Workflow (Single Framework) ### Configurations -- **Micro / local iteration**: `config/vjepa_micro.yaml` -- **Large-scale profile**: `config/vjepa_10b.yaml` +- **Micro scale profile (same model, small size)**: `config/vjepa_micro.yaml` +- **Large scale profile (same model, 10B-class size target)**: `config/vjepa_10b.yaml` ### Training Entrypoint ```bash @@ -81,7 +80,7 @@ python vjepa_train.py --config config/vjepa_micro.yaml python vjepa_train.py --config config/vjepa_10b.yaml ``` -`vjepa_train.py` accepts `--config` and loads runtime behavior from YAML. +`vjepa_train.py` accepts `--config` and loads runtime behavior from YAML. Both configs run the same Visual Execution Model framework at different scales. `training.epochs` can be set in YAML (defaults to `100` if omitted). ### Practical Notes diff --git a/models/vjepa/vjepa_model.py b/models/vjepa/vjepa_model.py index ad46d346..1ca0e2fd 100644 --- a/models/vjepa/vjepa_model.py +++ b/models/vjepa/vjepa_model.py @@ -11,7 +11,7 @@ class VJEPA(nn.Module): """ - Unified V-JEPA Model with HRM-ODE Predictor and Holographic Memory. + Visual Execution Model (VEM): unified single-framework model. Designed for 10B parameter physical world modeling. Enhancements over base: @@ -134,3 +134,8 @@ def forward(self, batch: Dict[str, torch.Tensor]): "all_context": all_latents, "value": value } + + +class VisualExecutionModel(VJEPA): + """Backward-compatible alias for the unified Visual Execution Model name.""" + pass From 9429fdfd3a64db81b6b02b2ff5f4e42fc60551c1 Mon Sep 17 00:00:00 2001 From: jeevesh415 Date: Sun, 17 May 2026 16:20:17 +0530 Subject: [PATCH 4/7] Remove puzzle-era components and harden VEM integrations --- CODE_ADDRESS_INDEX.md | 979 ++-------------------------------- arc_eval.ipynb | 252 --------- check_integrations.py | 56 ++ config/arch/hrm_v1.yaml | 21 - config/cfg_pretrain.yaml | 31 -- dataset/build_arc_dataset.py | 291 ---------- dataset/build_maze_dataset.py | 142 ----- evaluate.py | 68 --- models/adaptive_depth.py | 6 +- models/hrm/hrm_act_v1.py | 283 ---------- models/sparse_embedding.py | 132 ----- models/topological.py | 8 +- pretrain.py | 453 ---------------- puzzle_dataset.py | 199 ------- puzzle_visualizer.html | 426 --------------- 15 files changed, 112 insertions(+), 3235 deletions(-) delete mode 100644 arc_eval.ipynb create mode 100644 check_integrations.py delete mode 100644 config/arch/hrm_v1.yaml delete mode 100644 config/cfg_pretrain.yaml delete mode 100644 dataset/build_arc_dataset.py delete mode 100644 dataset/build_maze_dataset.py delete mode 100644 evaluate.py delete mode 100644 models/hrm/hrm_act_v1.py delete mode 100644 models/sparse_embedding.py delete mode 100644 pretrain.py delete mode 100644 puzzle_dataset.py delete mode 100644 puzzle_visualizer.html diff --git a/CODE_ADDRESS_INDEX.md b/CODE_ADDRESS_INDEX.md index 98d5b96d..69d97f5f 100644 --- a/CODE_ADDRESS_INDEX.md +++ b/CODE_ADDRESS_INDEX.md @@ -1,928 +1,51 @@ -# Code Address Index (Absolute Repository Map) - -This file is an exhaustive navigational index of the repository at generation time. - -- Root: `/workspace/HRM` -- Coverage: all project files except `.git` internals and cache folders. -- For each file: path, line count, byte size, and symbol/key anchors. - -## File Inventory - -### `.github/workflows/sync-from-upstream.yml` - -- Type: text -- Size: 664 bytes -- Lines: 1-19 - -- Address anchors: - - L1: `name` - - L2: `on` - - L3: `schedule` - - L5: `workflow_dispatch` - - L6: `jobs` - - L7: `sync` - - L8: `runs-on` - - L9: `steps` - -### `.gitignore` - -- Type: text -- Size: 3152 bytes -- Lines: 1-169 - -- First non-empty content (anchor): L1: `# WandB` - -### `.gitmodules` - -- Type: text -- Size: 364 bytes -- Lines: 1-9 - -- First non-empty content (anchor): L1: `[submodule "dataset/raw-data/ConceptARC"]` - -### `.vscode/launch.json` - -- Type: text -- Size: 778 bytes -- Lines: 1-26 - -- First non-empty content (anchor): L1: `{` - -### `.vscode/settings.json` - -- Type: text -- Size: 54 bytes -- Lines: 1-3 - -- First non-empty content (anchor): L1: `{` - -### `CODE_ADDRESS_INDEX.md` - -- Type: text -- Size: 20346 bytes -- Lines: 1-927 - -- Address anchors: - - L1: `# Code Address Index (Absolute Repository Map)` - - L9: `## File Inventory` - - L11: `### `.github/workflows/sync-from-upstream.yml`` - - L27: `### `.gitignore`` - - L35: `### `.gitmodules`` - - L43: `### `.vscode/launch.json`` - - L51: `### `.vscode/settings.json`` - - L59: `### `CODE_ADDRESS_INDEX.md`` - - L122: `### `LICENSE`` - - L130: `### `README.md`` - - L152: `### `arc_eval.ipynb`` - - L160: `### `assets/hrm.png`` - - L165: `### `assets/npyjs.js`` - - L173: `### `config/arch/hrm_v1.yaml`` - - L196: `### `config/cfg_pretrain.yaml`` - - L220: `### `config/vjepa_10b.yaml`` - - L289: `### `config/vjepa_micro.yaml`` - - L328: `### `dataset/build_arc_dataset.py`` - - L347: `### `dataset/build_maze_dataset.py`` - - L359: `### `dataset/common.py`` - - L370: `### `dataset/generate_dummy_data.py`` - - L379: `### `dataset/video_dataset.py`` - - L393: `### `evaluate.py`` - - L403: `### `models/adaptive_depth.py`` - - L417: `### `models/common.py`` - - L426: `### `models/hrm/hrm_act_v1.py`` - - L454: `### `models/hybrid_ssm.py`` - - L468: `### `models/information_bottleneck.py`` - - L483: `### `models/layers.py`` - - L511: `### `models/losses.py`` - - L527: `### `models/multimodal_grounding.py`` - - L544: `### `models/muon_optimizer.py`` - - L557: `### `models/proper_equivariance.py`` - - L575: `### `models/sparse_embedding.py`` - - L590: `### `models/spectral_conv.py`` - - L605: `### `models/topological.py`` - - L621: `### `models/ttt_layer.py`` - - L635: `### `models/uncertainty.py`` - - L650: `### `models/vjepa/flow_matching.py`` - - L669: `### `models/vjepa/gaussian_splatting.py`` - - L682: `### `models/vjepa/layers.py`` - - L703: `### `models/vjepa/losses.py`` - - L713: `### `models/vjepa/memory.py`` - - L742: `### `models/vjepa/physics_engine.py`` - - L756: `### `models/vjepa/planning.py`` - - L782: `### `models/vjepa/predictor.py`` - - L793: `### `models/vjepa/symplectic_integrator.py`` - - L807: `### `models/vjepa/utils.py`` - - L817: `### `models/vjepa/vit.py`` - - L835: `### `models/vjepa/vjepa_model.py`` - - L847: `### `pretrain.py`` - - L870: `### `puzzle_dataset.py`` - - L888: `### `puzzle_visualizer.html`` - - L896: `### `requirements.txt`` - - L904: `### `utils/functions.py`` - - L914: `### `vjepa_train.py`` - -### `LICENSE` - -- Type: text -- Size: 11357 bytes -- Lines: 1-202 - -- First non-empty content (anchor): L2: `Apache License` - -### `README.md` - -- Type: text -- Size: 3507 bytes -- Lines: 1-51 - -- Address anchors: - - L1: `# Hierarchical Reasoning Model - V-JEPA Integration (AGI Scale)` - - L5: `## Core Vision` - - L8: `## Key Architectural Pillars` - - L10: `### 1. Vision Encoder (The Eyes)` - - L15: `### 2. Physical Relativity (Lie Group Equivariance)` - - L18: `### 3. Continuous-Time Brain (Hamiltonian Neural ODEs & Predictive Coding)` - - L24: `### 4. Light & Shadow Intuition (Neural Radiance Latents)` - - L27: `### 5. Latent Planning (The Imagination)` - - L31: `### 6. Advanced Training Engine` - - L35: `## Getting Started` - - L37: `### Configuration` - - L40: `### Training` - - L43: `# or` - - L47: `## Future Multimodal Grounding` - -### `arc_eval.ipynb` - -- Type: text -- Size: 9317 bytes -- Lines: 1-252 - -- First non-empty content (anchor): L1: `{` - -### `assets/hrm.png` - -- Type: binary/non-UTF8 -- Size: 99852 bytes - -### `assets/npyjs.js` - -- Type: text -- Size: 5216 bytes -- Lines: 1-176 - -- First non-empty content (anchor): L1: `class npyjs {` - -### `config/arch/hrm_v1.yaml` - -- Type: text -- Size: 349 bytes -- Lines: 1-21 - -- Address anchors: - - L1: `name` - - L2: `loss` - - L3: `name` - - L4: `loss_type` - - L6: `halt_exploration_prob` - - L7: `halt_max_steps` - - L9: `H_cycles` - - L10: `L_cycles` - - L12: `H_layers` - - L13: `L_layers` - - L15: `hidden_size` - - L16: `num_heads` - - L17: `expansion` - - L19: `puzzle_emb_ndim` - - L21: `pos_encodings` - -### `config/cfg_pretrain.yaml` - -- Type: text -- Size: 492 bytes -- Lines: 1-31 - -- Address anchors: - - L3: `defaults` - - L7: `hydra` - - L8: `output_subdir` - - L11: `data_path` - - L14: `global_batch_size` - - L16: `epochs` - - L17: `eval_interval` - - L18: `checkpoint_every_eval` - - L20: `lr` - - L21: `lr_min_ratio` - - L22: `lr_warmup_steps` - - L25: `beta1` - - L26: `beta2` - - L27: `weight_decay` - - L28: `puzzle_emb_weight_decay` - - L31: `puzzle_emb_lr` - -### `config/vjepa_10b.yaml` - -- Type: text -- Size: 1709 bytes -- Lines: 1-85 - -- Address anchors: - - L2: `encoder` - - L3: `img_size` - - L4: `patch_size` - - L5: `in_chans` - - L6: `embed_dim` - - L7: `depth` - - L8: `num_heads` - - L9: `expansion` - - L10: `max_t` - - L11: `max_h` - - L12: `max_w` - - L14: `predictor` - - L16: `H_cycles` - - L17: `L_cycles` - - L18: `H_layers` - - L19: `L_layers` - - L20: `hidden_size` - - L21: `expansion` - - L22: `num_heads` - - L23: `pos_encodings` - - L24: `halt_max_steps` - - L25: `halt_exploration_prob` - - L26: `forward_dtype` - - L29: `use_gaussian_splatting` - - L30: `num_gaussians` - - L33: `use_flow_matching` - - L36: `use_symplectic` - - L39: `use_ttt` - - L40: `ttt_inner_lr` - - L41: `ttt_inner_steps` - - L44: `adaptive_depth` - - L45: `confidence_threshold` - - L46: `uncertainty_threshold` - - L49: `use_multimodal` - - L50: `audio_input_dim` - - L51: `tactile_input_dim` - - L54: `use_information_bottleneck` - - L55: `ib_beta` - - L57: `use_spectral_conv` - - L58: `spectral_num_filters` - - L59: `spectral_polynomial_order` - - L61: `use_topology` - - L62: `topology_filtration_steps` - - L64: `use_proper_se3` - - L65: `se3_num_frequencies` - - L66: `se3_l_max` - - L68: `use_uncertainty` - - L69: `uncertainty_mc_samples` - - L70: `uncertainty_dropout` - - L72: `training` - - L73: `batch_size` - - L74: `global_batch_size` - - L75: `lr` - - L76: `ema_momentum` - - L77: `sim_coeff` - - L78: `std_coeff` - - L79: `cov_coeff` - - L80: `weight_decay` - - L83: `optimizer` - - L84: `muon_momentum` - - L85: `muon_ns_steps` - -### `config/vjepa_micro.yaml` - -- Type: text -- Size: 555 bytes -- Lines: 1-34 - -- Address anchors: - - L2: `encoder` - - L3: `img_size` - - L4: `patch_size` - - L5: `in_chans` - - L6: `embed_dim` - - L7: `depth` - - L8: `num_heads` - - L9: `expansion` - - L10: `max_t` - - L11: `max_h` - - L12: `max_w` - - L14: `predictor` - - L15: `H_cycles` - - L16: `L_cycles` - - L17: `H_layers` - - L18: `L_layers` - - L19: `hidden_size` - - L20: `expansion` - - L21: `num_heads` - - L22: `pos_encodings` - - L23: `halt_max_steps` - - L24: `halt_exploration_prob` - - L25: `forward_dtype` - - L27: `training` - - L28: `batch_size` - - L29: `global_batch_size` - - L30: `lr` - - L31: `ema_momentum` - - L32: `sim_coeff` - - L33: `std_coeff` - - L34: `cov_coeff` - -### `dataset/build_arc_dataset.py` - -- Type: text -- Size: 10084 bytes -- Lines: 1-291 - -- Address anchors: - - L19: `class DataProcessConfig` - - L37: `class ARCPuzzle` - - L43: `def arc_grid_to_np` - - L54: `def np_grid_to_seq_translational_augment` - - L81: `def puzzle_hash` - - L83: `def _grid_hash` - - L98: `def convert_single_arc_puzzle` - - L122: `def _map_grid` - - L148: `def load_puzzles_arcagi` - - L184: `def convert_dataset` - - L286: `def main` - -### `dataset/build_maze_dataset.py` - -- Type: text -- Size: 4461 bytes -- Lines: 1-142 - -- Address anchors: - - L22: `class DataProcessConfig` - - L30: `def convert_subset` - - L89: `def _seq_to_numpy` - - L136: `def preprocess_data` - -### `dataset/common.py` - -- Type: text -- Size: 1381 bytes -- Lines: 1-51 - -- Address anchors: - - L12: `class PuzzleDatasetMetadata` - - L27: `def dihedral_transform` - - L50: `def inverse_dihedral_transform` - -### `dataset/generate_dummy_data.py` - -- Type: text -- Size: 860 bytes -- Lines: 1-29 - -- Address anchors: - - L5: `def generate_dummy_video` - -### `dataset/video_dataset.py` - -- Type: text -- Size: 4426 bytes -- Lines: 1-108 - -- Address anchors: - - L8: `class AdvancedVideoDataset` - - L13: `def __init__` - - L33: `def _get_video_stream` - - L42: `def _generate_3d_block_mask` - - L66: `def __iter__` - - L106: `def get_dataloader` - -### `evaluate.py` - -- Type: text -- Size: 2490 bytes -- Lines: 1-68 - -- Address anchors: - - L13: `class EvalConfig` - - L19: `def launch` - -### `models/adaptive_depth.py` - -- Type: text -- Size: 6649 bytes -- Lines: 1-199 - -- Address anchors: - - L24: `class AdaptiveDepthController` - - L41: `def __init__` - - L54: `def should_continue` - - L116: `class AdaptiveDepthWrapper` - - L131: `def __init__` - - L146: `def forward` - -### `models/common.py` - -- Type: text -- Size: 1216 bytes -- Lines: 1-32 - -- Address anchors: - - L7: `def trunc_normal_init_` - -### `models/hrm/hrm_act_v1.py` - -- Type: text -- Size: 12161 bytes -- Lines: 1-283 - -- Address anchors: - - L16: `class HierarchicalReasoningModel_ACTV1InnerCarry` - - L22: `class HierarchicalReasoningModel_ACTV1Carry` - - L31: `class HierarchicalReasoningModel_ACTV1Config` - - L60: `class HierarchicalReasoningModel_ACTV1Block` - - L61: `def __init__` - - L77: `def forward` - - L86: `class HierarchicalReasoningModel_ACTV1ReasoningModule` - - L87: `def __init__` - - L92: `def forward` - - L102: `class HierarchicalReasoningModel_ACTV1_Inner` - - L103: `def __init__` - - L146: `def _input_embeddings` - - L168: `def empty_carry` - - L174: `def reset_carry` - - L180: `def forward` - - L216: `class HierarchicalReasoningModel_ACTV1` - - L219: `def __init__` - - L225: `def puzzle_emb` - - L228: `def initial_carry` - - L240: `def forward` - -### `models/hybrid_ssm.py` - -- Type: text -- Size: 7530 bytes -- Lines: 1-228 - -- Address anchors: - - L26: `class SelectiveSSM` - - L41: `def __init__` - - L85: `def forward` - - L141: `class HybridSSMAttentionBlock` - - L161: `def __init__` - - L199: `def forward` - -### `models/information_bottleneck.py` - -- Type: text -- Size: 7603 bytes -- Lines: 1-227 - -- Address anchors: - - L28: `class VariationalInformationBottleneck` - - L42: `def __init__` - - L71: `def _reparameterize` - - L90: `def forward` - - L135: `class InformationBottleneckAttention` - - L150: `def __init__` - - L179: `def forward` - -### `models/layers.py` - -- Type: text -- Size: 6156 bytes -- Lines: 1-167 - -- Address anchors: - - L13: `def flash_attn_func` - - L29: `def _find_multiple` - - L33: `def rotate_half` - - L40: `def apply_rotary_pos_emb` - - L53: `class CastedLinear` - - L54: `def __init__` - - L68: `def forward` - - L72: `class CastedEmbedding` - - L73: `def __init__` - - L86: `def forward` - - L90: `class RotaryEmbedding` - - L91: `def __init__` - - L104: `def forward` - - L108: `class Attention` - - L109: `def __init__` - - L122: `def forward` - - L148: `class SwiGLU` - - L149: `def __init__` - - L156: `def forward` - - L161: `def rms_norm` - -### `models/losses.py` - -- Type: text -- Size: 3804 bytes -- Lines: 1-101 - -- Address anchors: - - L11: `def s` - - L19: `def log_stablemax` - - L24: `def stablemax_cross_entropy` - - L34: `def softmax_cross_entropy` - - L40: `class ACTLossHead` - - L41: `def __init__` - - L46: `def initial_carry` - - L49: `def forward` - -### `models/multimodal_grounding.py` - -- Type: text -- Size: 8774 bytes -- Lines: 1-258 - -- Address anchors: - - L24: `class ModalityEncoder` - - L37: `def __init__` - - L60: `def forward` - - L88: `class CrossModalAttention` - - L101: `def __init__` - - L122: `def forward` - - L158: `class MultiModalGrounding` - - L180: `def __init__` - - L213: `def forward` - -### `models/muon_optimizer.py` - -- Type: text -- Size: 6664 bytes -- Lines: 1-191 - -- Address anchors: - - L26: `class Muon` - - L44: `def __init__` - - L67: `def step` - - L137: `def _newton_schulz_orthogonalize` - - L170: `def _distributed_allreduce_grads` - -### `models/proper_equivariance.py` - -- Type: text -- Size: 11960 bytes -- Lines: 1-335 - -- Address anchors: - - L27: `class SO3Rotation` - - L37: `def axis_angle_to_matrix` - - L74: `def matrix_to_quaternion` - - L93: `class WignerDMatrices` - - L107: `def wigner_d_small` - - L170: `def rotation_to_euler` - - L187: `class ProperSE3EquivariantLayer` - - L204: `def __init__` - - L243: `def _positional_encoding` - - L259: `def forward` - -### `models/sparse_embedding.py` - -- Type: text -- Size: 4366 bytes -- Lines: 1-132 - -- Address anchors: - - L11: `class CastedSparseEmbedding` - - L12: `def __init__` - - L28: `def forward` - - L41: `class CastedSparseEmbeddingSignSGD_Distributed` - - L42: `def __init__` - - L63: `def step` - - L98: `def _sparse_emb_signsgd_dist` - -### `models/spectral_conv.py` - -- Type: text -- Size: 6117 bytes -- Lines: 1-193 - -- Address anchors: - - L25: `class GraphLaplacian` - - L34: `def __init__` - - L38: `def forward` - - L74: `class SpectralGraphConv` - - L88: `def __init__` - - L120: `def _chebyshev_polynomials` - - L159: `def forward` - -### `models/topological.py` - -- Type: text -- Size: 7232 bytes -- Lines: 1-215 - -- Address anchors: - - L25: `class DifferentiableBettiNumbers` - - L41: `def __init__` - - L68: `def _compute_distance_matrix` - - L74: `def _soft_threshold` - - L87: `def forward` - - L156: `class TopologicalAwareness` - - L171: `def __init__` - - L191: `def forward` - -### `models/ttt_layer.py` - -- Type: text -- Size: 7388 bytes -- Lines: 1-206 - -- Address anchors: - - L25: `class TTTLinear` - - L45: `def __init__` - - L74: `def forward` - - L148: `class TTTLinearWithAttention` - - L163: `def __init__` - - L191: `def forward` - -### `models/uncertainty.py` - -- Type: text -- Size: 7009 bytes -- Lines: 1-206 - -- Address anchors: - - L24: `class VariationalLinear` - - L41: `def __init__` - - L61: `def forward` - - L78: `def kl_divergence` - - L106: `class UncertaintyQuantification` - - L121: `def __init__` - - L151: `def forward` - -### `models/vjepa/flow_matching.py` - -- Type: text -- Size: 9015 bytes -- Lines: 1-281 - -- Address anchors: - - L28: `class SinusoidalTimeEmbedding` - - L31: `def __init__` - - L35: `def forward` - - L53: `class VelocityField` - - L61: `def __init__` - - L89: `def forward` - - L125: `class ConditionalFlowMatching` - - L150: `def __init__` - - L165: `def forward` - - L211: `def sample` - - L257: `def sample_rectified` - -### `models/vjepa/gaussian_splatting.py` - -- Type: text -- Size: 7036 bytes -- Lines: 1-203 - -- Address anchors: - - L28: `class LatentGaussianSplatting` - - L40: `def __init__` - - L75: `def _parse_gaussians` - - L118: `def _quaternion_to_rotation_matrix` - - L146: `def forward` - -### `models/vjepa/layers.py` - -- Type: text -- Size: 6641 bytes -- Lines: 1-157 - -- Address anchors: - - L8: `class LieGroupEquivariantLayer` - - L14: `def __init__` - - L31: `def forward` - - L50: `class LatentRayMarcher` - - L56: `def __init__` - - L71: `def forward` - - L113: `def apply_rotary_pos_emb_3d` - - L114: `def rotate_half` - - L128: `class RotaryEmbedding3D` - - L129: `def __init__` - - L141: `def _get_freqs` - - L147: `def _build_cache` - - L152: `def forward` - -### `models/vjepa/losses.py` - -- Type: text -- Size: 977 bytes -- Lines: 1-31 - -- Address anchors: - - L4: `def vicreg_loss` - - L22: `def covariance_loss` - -### `models/vjepa/memory.py` - -- Type: text -- Size: 13196 bytes -- Lines: 1-392 - -- Address anchors: - - L28: `class ResonatorNetwork` - - L49: `def __init__` - - L66: `def set_cleanup_memory` - - L70: `def cleanup` - - L96: `def resonator_step` - - L119: `def _unbind` - - L125: `def forward` - - L178: `class HolographicMemory` - - L198: `def __init__` - - L219: `def _bind_hrr` - - L225: `def _unbind_hrr` - - L231: `def _bind_fhrr` - - L243: `def _unbind_fhrr` - - L249: `def bind` - - L255: `def unbind` - - L261: `def superpose` - - L284: `def forward` - - L312: `def retrieve` - - L327: `def retrieve_with_cleanup` - - L348: `def multi_retrieve` - - L385: `def set_cleanup_memory` - -### `models/vjepa/physics_engine.py` - -- Type: text -- Size: 4774 bytes -- Lines: 1-107 - -- Address anchors: - - L6: `class HRMPhysicsODE` - - L12: `def __init__` - - L30: `def forward` - - L61: `class ContinuousTimeHRM` - - L67: `def __init__` - - L71: `def forward` - -### `models/vjepa/planning.py` - -- Type: text -- Size: 15338 bytes -- Lines: 1-458 - -- Address anchors: - - L28: `class MCTSNode` - - L49: `def __init__` - - L67: `def mean_value` - - L74: `def is_expanded` - - L79: `def effective_visits` - - L83: `def puct_score` - - L120: `class MCTS` - - L139: `def __init__` - - L160: `def _imagine_future` - - L192: `def _select` - - L222: `def _expand` - - L280: `def _backpropagate` - - L306: `def _get_action_probabilities` - - L340: `def plan` - - L393: `class LatentPlannerMCTS` - - L408: `def __init__` - - L424: `def plan` - - L442: `def plan_with_uncertainty` - -### `models/vjepa/predictor.py` - -- Type: text -- Size: 10722 bytes -- Lines: 1-260 - -- Address anchors: - - L22: `class VJEPAPredictorInner` - - L28: `def __init__` - - L156: `def forward` - -### `models/vjepa/symplectic_integrator.py` - -- Type: text -- Size: 4606 bytes -- Lines: 1-137 - -- Address anchors: - - L31: `class SymplecticEulerIntegrator` - - L48: `def __init__` - - L67: `def set_action` - - L71: `def hamiltonian` - - L87: `def forward` - - L126: `def compute_energy` - -### `models/vjepa/utils.py` - -- Type: text -- Size: 1478 bytes -- Lines: 1-44 - -- Address anchors: - - L3: `def get_block_mask` - - L21: `def apply_mask` - -### `models/vjepa/vit.py` - -- Type: text -- Size: 2960 bytes -- Lines: 1-94 - -- Address anchors: - - L9: `class PatchEmbed3D` - - L15: `def __init__` - - L20: `def forward` - - L30: `class VisionTransformerBlock` - - L31: `def __init__` - - L43: `def _forward_inner` - - L48: `def forward` - - L54: `class VisionEncoder` - - L55: `def __init__` - - L86: `def forward` - -### `models/vjepa/vjepa_model.py` - -- Type: text -- Size: 5221 bytes -- Lines: 1-127 - -- Address anchors: - - L12: `class VJEPA` - - L23: `def __init__` - - L72: `def update_target_encoder` - - L77: `def forward` - -### `pretrain.py` - -- Type: text -- Size: 15607 bytes -- Lines: 1-453 - -- Address anchors: - - L26: `class LossConfig` - - L32: `class ArchConfig` - - L39: `class PretrainConfig` - - L74: `class TrainState` - - L84: `def create_dataloader` - - L108: `def create_model` - - L162: `def cosine_schedule_with_warmup_lr_lambda` - - L172: `def init_train_state` - - L190: `def save_train_state` - - L199: `def compute_lr` - - L209: `def train_batch` - - L266: `def evaluate` - - L333: `def save_code_and_config` - - L359: `def load_synced_config` - - L381: `def launch` - -### `puzzle_dataset.py` - -- Type: text -- Size: 7980 bytes -- Lines: 1-199 - -- Address anchors: - - L14: `def _sample_batch` - - L41: `class PuzzleDatasetConfig` - - L53: `class PuzzleDataset` - - L54: `def __init__` - - L68: `def _load_metadata` - - L72: `def _lazy_load_dataset` - - L95: `def _collate_batch` - - L118: `def _iter_test` - - L151: `def _iter_train` - - L189: `def __iter__` - -### `puzzle_visualizer.html` - -- Type: text -- Size: 14119 bytes -- Lines: 1-426 - -- First non-empty content (anchor): L1: `` - -### `requirements.txt` - -- Type: text -- Size: 270 bytes -- Lines: 1-16 - -- First non-empty content (anchor): L1: `torch` - -### `utils/functions.py` - -- Type: text -- Size: 516 bytes -- Lines: 1-19 - -- Address anchors: - - L5: `def load_model_class` - - L15: `def get_model_source_path` - -### `vjepa_train.py` - -- Type: text -- Size: 7591 bytes -- Lines: 1-216 - -- Address anchors: - - L18: `def build_optimizer` - - L92: `class CombinedOptimizer` - - L97: `def __init__` - - L100: `def zero_grad` - - L104: `def step` - - L109: `def param_groups` - - L116: `def train` +# CODE ADDRESS INDEX + +Auto-generated repository file map. + +- `.github/workflows/sync-from-upstream.yml` | text | lines: 19 +- `.gitignore` | text | lines: 169 +- `.gitmodules` | text | lines: 9 +- `.vscode/launch.json` | text | lines: 26 +- `.vscode/settings.json` | text | lines: 3 +- `LICENSE` | text | lines: 202 +- `README.md` | text | lines: 112 +- `assets/hrm.png` | binary | lines: 0 +- `assets/npyjs.js` | text | lines: 176 +- `config/vjepa_10b.yaml` | text | lines: 85 +- `config/vjepa_micro.yaml` | text | lines: 34 +- `dataset/common.py` | text | lines: 51 +- `dataset/generate_dummy_data.py` | text | lines: 29 +- `dataset/video_dataset.py` | text | lines: 108 +- `docs/FRONTIER_GAP_ANALYSIS.md` | text | lines: 35 +- `docs/HUMAN_VISION_EXECUTION_EVAL_SPEC.md` | text | lines: 48 +- `docs/RIGOROUS_DEVELOPMENT_PROTOCOL.md` | text | lines: 33 +- `evaluate_perception.py` | text | lines: 90 +- `evaluate_world_model.py` | text | lines: 166 +- `models/adaptive_depth.py` | text | lines: 199 +- `models/common.py` | text | lines: 32 +- `models/hybrid_ssm.py` | text | lines: 228 +- `models/information_bottleneck.py` | text | lines: 227 +- `models/layers.py` | text | lines: 167 +- `models/losses.py` | text | lines: 101 +- `models/multimodal_grounding.py` | text | lines: 258 +- `models/muon_optimizer.py` | text | lines: 191 +- `models/proper_equivariance.py` | text | lines: 335 +- `models/spectral_conv.py` | text | lines: 193 +- `models/topological.py` | text | lines: 215 +- `models/ttt_layer.py` | text | lines: 206 +- `models/uncertainty.py` | text | lines: 206 +- `models/vjepa/flow_matching.py` | text | lines: 281 +- `models/vjepa/gaussian_splatting.py` | text | lines: 203 +- `models/vjepa/layers.py` | text | lines: 157 +- `models/vjepa/losses.py` | text | lines: 31 +- `models/vjepa/memory.py` | text | lines: 392 +- `models/vjepa/physics_engine.py` | text | lines: 107 +- `models/vjepa/planning.py` | text | lines: 467 +- `models/vjepa/predictor.py` | text | lines: 260 +- `models/vjepa/symplectic_integrator.py` | text | lines: 137 +- `models/vjepa/utils.py` | text | lines: 44 +- `models/vjepa/vit.py` | text | lines: 94 +- `models/vjepa/vjepa_model.py` | text | lines: 141 +- `requirements.txt` | text | lines: 16 +- `utils/functions.py` | text | lines: 19 +- `vjepa_train.py` | text | lines: 221 diff --git a/arc_eval.ipynb b/arc_eval.ipynb deleted file mode 100644 index b2786b8a..00000000 --- a/arc_eval.ipynb +++ /dev/null @@ -1,252 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import json\n", - "from glob import glob\n", - "import hashlib\n", - "import matplotlib.pyplot as plt\n", - "import matplotlib.colors as mcolors\n", - "\n", - "import torch\n", - "import torch.nn.functional as F\n", - "import numpy as np\n", - "from numba import njit\n", - "\n", - "from dataset.common import inverse_dihedral_transform\n", - "\n", - "\n", - "DATASET_PATH = \"data/arc-aug-1000\" # ARC-1\n", - "# DATASET_PATH = \"data/arc-2-aug-1000\" # ARC-2\n", - "\n", - "CHECKPOINT_PATH = \"checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456\"\n", - "\n", - "\n", - "PAD_PUZZLE_IDENTIFIER = 0\n", - "\n", - "# Visualization\n", - "ARC_COLOR_MAP = mcolors.ListedColormap([\n", - " \"#000000\", # symbol_0: black\n", - " \"#0074D9\", # symbol_1: blue\n", - " \"#FF4136\", # symbol_2: red\n", - " \"#2ECC40\", # symbol_3: green\n", - " \"#FFDC00\", # symbol_4: yellow\n", - " \"#AAAAAA\", # symbol_5: grey\n", - " \"#F012BE\", # symbol_6: fuschia\n", - " \"#FF851B\", # symbol_7: orange\n", - " \"#7FDBFF\", # symbol_8: teal\n", - " \"#870C25\" # symbol_9: brown\n", - "])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):\n", - " # Load puzzle identifiers\n", - " with open(os.path.join(dataset_path, \"identifiers.json\"), \"r\") as f:\n", - " identifier_map = json.load(f)\n", - " \n", - " # Load preds\n", - " all_preds = {}\n", - " for filename in glob(f\"{checkpoint_path}_all_preds.*\"):\n", - " preds = torch.load(filename)\n", - " for k, v in preds.items():\n", - " all_preds.setdefault(k, [])\n", - " all_preds[k].append(v)\n", - " \n", - " del preds\n", - "\n", - " all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}\n", - " \n", - " # Remove paddings\n", - " mask = all_preds[\"puzzle_identifiers\"] != PAD_PUZZLE_IDENTIFIER\n", - " all_preds = {k: v[mask] for k, v in all_preds.items()}\n", - "\n", - " return identifier_map, all_preds\n", - "\n", - "\n", - "def inverse_aug(name: str, grid: np.ndarray):\n", - " if \"_\" not in name:\n", - " return grid\n", - "\n", - " trans_id, perm = name.split(\"_\")[-2:]\n", - " trans_id = int(trans_id[1:]) # Remove \"t\" letter\n", - " inv_perm = np.argsort(list(perm))\n", - " \n", - " return inv_perm[inverse_dihedral_transform(grid, trans_id)]\n", - "\n", - "\n", - "def grid_hash(grid: np.ndarray):\n", - " return hash((grid.tobytes(), grid.shape))\n", - "\n", - "\n", - "@njit\n", - "def crop(grid: np.ndarray):\n", - " # Find maximum-sized rectangle without any EOS token inside.\n", - " grid = grid.reshape(30, 30)\n", - "\n", - " max_area = 0\n", - " max_size = (0, 0)\n", - " nr, nc = grid.shape\n", - " \n", - " num_c = nc\n", - " for num_r in range(1, nr + 1):\n", - " # Scan for maximum c\n", - " for c in range(1, num_c + 1):\n", - " x = grid[num_r - 1, c - 1]\n", - " if (x < 2) | (x > 11):\n", - " num_c = c - 1\n", - " break\n", - " \n", - " area = num_r * num_c\n", - " if area > max_area:\n", - " max_area = area\n", - " max_size = (num_r, num_c)\n", - "\n", - " return grid[:max_size[0], :max_size[1]] - 2\n", - "\n", - "\n", - "def test(visualize, Ks=[1, 2, 10, 100, 1000]):\n", - " identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)\n", - " \n", - " global_hmap = {}\n", - " \n", - " # Get puzzles and corresponding answers\n", - " puzzle_labels = {}\n", - " for identifier, input, label in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], all_preds[\"labels\"]):\n", - " name = identifier_map[identifier]\n", - " if \"_\" not in name: # Not-augmented\n", - " puzzle_labels.setdefault(name, {})\n", - " \n", - " input = crop(input.numpy())\n", - " label = crop(label.numpy())\n", - "\n", - " input_hash = grid_hash(input)\n", - " label_hash = grid_hash(label)\n", - "\n", - " global_hmap[input_hash] = input\n", - " global_hmap[label_hash] = label\n", - "\n", - " assert input_hash not in puzzle_labels[name]\n", - " puzzle_labels[name][input_hash] = label_hash\n", - " \n", - " print (\"Number of puzzles\", len(puzzle_labels))\n", - " \n", - " # Argmax prediction\n", - " preds = all_preds[\"logits\"].argmax(-1)\n", - "\n", - " # Collate\n", - " pred_answers = {}\n", - " for identifier, input, pred, q in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], preds, all_preds[\"q_halt_logits\"].sigmoid()):\n", - " name = identifier_map[identifier]\n", - " orig_name = name.split(\"_\")[0]\n", - " \n", - " input = input.numpy()\n", - " input_hash = grid_hash(inverse_aug(name, crop(input)))\n", - " assert input_hash in puzzle_labels[orig_name]\n", - " \n", - " pred = inverse_aug(name, crop(pred.numpy()))\n", - " pred_hash = grid_hash(pred)\n", - " global_hmap[pred_hash] = pred\n", - " \n", - " pred_answers.setdefault(orig_name, {})\n", - " pred_answers[orig_name].setdefault(input_hash, [])\n", - " pred_answers[orig_name][input_hash].append((pred_hash, q.item()))\n", - "\n", - " # test-1\n", - " if visualize:\n", - " num_figs = sum(len(tests) for name, tests in puzzle_labels.items())\n", - " fig, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4))\n", - " \n", - " fig_id = 0\n", - " \n", - " correct = [0 for _ in range(len(Ks))]\n", - " for name, tests in puzzle_labels.items():\n", - " num_test_correct = [0 for _ in range(len(Ks))]\n", - " for input_hash, label_hash in tests.items():\n", - " p = pred_answers[name][input_hash]\n", - " p_map = {}\n", - " \n", - " for h, q in p:\n", - " p_map.setdefault(h, [0, 0])\n", - " p_map[h][0] += 1\n", - " p_map[h][1] += q\n", - " \n", - " for h, stats in p_map.items():\n", - " stats[1] /= stats[0]\n", - " \n", - " p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)\n", - "\n", - " # 2-vote\n", - " for i, k in enumerate(Ks):\n", - " ok = False\n", - " for h, stats in p_map[:k]:\n", - " ok |= h == label_hash\n", - " \n", - " num_test_correct[i] += ok\n", - "\n", - " if visualize:\n", - " # Show input and ground truth\n", - " axes[fig_id, 0].imshow(global_hmap[input_hash], cmap=ARC_COLOR_MAP)\n", - " axes[fig_id, 0].set_title(f\"{name}\\nInput\")\n", - " axes[fig_id, 0].axis('off')\n", - " \n", - " axes[fig_id, 1].imshow(global_hmap[label_hash], cmap=ARC_COLOR_MAP)\n", - " axes[fig_id, 1].set_title(f\"{name}\\nAnswer\")\n", - " axes[fig_id, 1].axis('off')\n", - " \n", - " trial_id = 2\n", - " for h, stats in p_map[:2]:\n", - " ans = global_hmap[h]\n", - " \n", - " axes[fig_id, trial_id].imshow(ans, cmap=ARC_COLOR_MAP)\n", - " axes[fig_id, trial_id].set_title(f\"{name}\\nTrial {trial_id}\")\n", - " axes[fig_id, trial_id].axis('off')\n", - " \n", - " trial_id += 1\n", - " \n", - " fig_id += 1\n", - " \n", - " # Total correctness\n", - " for i in range(len(Ks)):\n", - " correct[i] += num_test_correct[i] == len(tests)\n", - "\n", - " for i, k in enumerate(Ks):\n", - " print (f\"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%\")\n", - "\n", - "\n", - "test(visualize=False)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.10" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/check_integrations.py b/check_integrations.py new file mode 100644 index 00000000..9d8363f7 --- /dev/null +++ b/check_integrations.py @@ -0,0 +1,56 @@ +import torch +import yaml + +from models.vjepa.vjepa_model import VisualExecutionModel +from models.vjepa.planning import MCTS + + +def load_config(path: str): + with open(path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) + + +def main(): + cfg = load_config("config/vjepa_micro.yaml") + model = VisualExecutionModel( + encoder_config=cfg["encoder"], + predictor_config=cfg["predictor"], + ema_momentum=cfg["training"].get("ema_momentum", 0.996), + action_dim=cfg.get("action_dim", 128), + ) + model.eval() + + bsz = 1 + t = cfg["encoder"].get("max_t", 8) + h = cfg["encoder"].get("img_size", 64) + w = cfg["encoder"].get("img_size", 64) + video = torch.randn(bsz, t, 3, h, w) + + pt, ph, pw = cfg["encoder"]["patch_size"] + seq_len = (t // pt) * (h // ph) * (w // pw) + num_mask = max(1, seq_len // 4) + mask = torch.randperm(seq_len)[:num_mask] + + batch = { + "video": video, + "mask": mask, + "delta_t": torch.ones(bsz, 1), + "action": torch.randn(bsz, cfg.get("action_dim", 128)), + } + + out = model(batch) + assert "predicted" in out and "target" in out and "value" in out + + mcts = MCTS(model=model, n_simulations=4) + root_state = out["all_context"].mean(dim=1) + actions = torch.randn(8, cfg.get("action_dim", 128)) + chosen = mcts.plan(root_state, actions) + if isinstance(chosen, tuple): + chosen = chosen[0] + assert chosen.shape[-1] == cfg.get("action_dim", 128) + + print("Integration check passed: model forward + MCTS planning wired correctly.") + + +if __name__ == "__main__": + main() diff --git a/config/arch/hrm_v1.yaml b/config/arch/hrm_v1.yaml deleted file mode 100644 index a5646b89..00000000 --- a/config/arch/hrm_v1.yaml +++ /dev/null @@ -1,21 +0,0 @@ -name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1 -loss: - name: losses@ACTLossHead - loss_type: stablemax_cross_entropy - -halt_exploration_prob: 0.1 -halt_max_steps: 16 - -H_cycles: 2 -L_cycles: 2 - -H_layers: 4 -L_layers: 4 - -hidden_size: 512 -num_heads: 8 # min(2, hidden_size // 64) -expansion: 4 - -puzzle_emb_ndim: ${.hidden_size} - -pos_encodings: rope diff --git a/config/cfg_pretrain.yaml b/config/cfg_pretrain.yaml deleted file mode 100644 index 51c55a07..00000000 --- a/config/cfg_pretrain.yaml +++ /dev/null @@ -1,31 +0,0 @@ -# ARC training config - -defaults: - - arch: hrm_v1 - - _self_ - -hydra: - output_subdir: null - -# Data path -data_path: data/arc-aug-1000 - -# Hyperparams - Training -global_batch_size: 768 - -epochs: 100000 -eval_interval: 10000 -checkpoint_every_eval: True - -lr: 1e-4 -lr_min_ratio: 1.0 -lr_warmup_steps: 2000 - -# Standard hyperparameter settings for LM, as used in Llama -beta1: 0.9 -beta2: 0.95 -weight_decay: 0.1 -puzzle_emb_weight_decay: 0.1 - -# Hyperparams - Puzzle embeddings training -puzzle_emb_lr: 1e-2 diff --git a/dataset/build_arc_dataset.py b/dataset/build_arc_dataset.py deleted file mode 100644 index 2da5703e..00000000 --- a/dataset/build_arc_dataset.py +++ /dev/null @@ -1,291 +0,0 @@ -from typing import List, Optional, Tuple, Dict -from dataclasses import dataclass -from pathlib import Path -import os -import json -import hashlib -import numpy as np -from glob import glob - -from argdantic import ArgParser -from pydantic import BaseModel - -from common import PuzzleDatasetMetadata, dihedral_transform - - -cli = ArgParser() - - -class DataProcessConfig(BaseModel): - # ARC-1 - dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI/data", "dataset/raw-data/ConceptARC/corpus"] - output_dir: str = "data/arc-aug-1000" - - # ARC-2 - # dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI-2/data"] - # output_dir: str = "data/arc-2-aug-1000" - - seed: int = 42 - num_aug: int = 1000 - - -ARCMaxGridSize = 30 -ARCAugmentRetriesFactor = 5 - - -@dataclass -class ARCPuzzle: - id: str - - examples: List[Tuple[np.ndarray, np.ndarray]] - - -def arc_grid_to_np(grid: List[List[int]]): - arr = np.array(grid) - - # Shape check - assert arr.ndim == 2 - assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize - # Element check - assert np.all((arr >= 0) & (arr <= 9)) - return arr.astype(np.uint8) - - -def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool): - # PAD: 0, : 1, digits: 2 ... 11 - # Compute random top-left pad - if do_translation: - pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1) - pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1) - else: - pad_r = pad_c = 0 - - # Pad grid - result = [] - for grid in [inp, out]: - nrow, ncol = grid.shape - grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0) - - # Add - eos_row, eos_col = pad_r + nrow, pad_c + ncol - if eos_row < ARCMaxGridSize: - grid[eos_row, pad_c:eos_col] = 1 - if eos_col < ARCMaxGridSize: - grid[pad_r:eos_row, eos_col] = 1 - - result.append(grid.flatten()) - - return result - - -def puzzle_hash(puzzle: dict): - # Hash the puzzle for checking equivalence - def _grid_hash(grid: np.ndarray): - buffer = [x.to_bytes(1) for x in grid.shape] - buffer.append(grid.tobytes()) - - return hashlib.sha256(b"".join(buffer)).hexdigest() - - hashes = [] - for example_type, example in puzzle.items(): - for input, label in example.examples: - hashes.append(f"{_grid_hash(input)}|{_grid_hash(label)}") - - hashes.sort() - return hashlib.sha256("|".join(hashes).encode()).hexdigest() - - -def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]): - # Remove "name" - name = puzzle.pop("name", default_name) - - # Convert - dests = set(dest_mapping.values()) - converted = {dest: ARCPuzzle(name, []) for dest in dests} - for example_type, examples in puzzle.items(): - dest = dest_mapping[example_type] - converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples]) - - group = [converted] - - # Augment - if aug_count > 0: - hashes = {puzzle_hash(converted)} - - for _trial in range(ARCAugmentRetriesFactor * aug_count): - # Augment plan - trans_id = np.random.randint(0, 8) - mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))]) # Permute colors, Excluding "0" (black) - - aug_repr = f"t{trans_id}_{''.join(str(x) for x in mapping)}" - - def _map_grid(grid: np.ndarray): - return dihedral_transform(mapping[grid], trans_id) - - # Check duplicate - augmented = {dest: ARCPuzzle(f"{puzzle.id}_{aug_repr}", [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()} - h = puzzle_hash(augmented) - if h not in hashes: - hashes.add(h) - group.append(augmented) - - if len(group) >= aug_count + 1: - break - - if len(group) < aug_count + 1: - print (f"[Puzzle {name}] augmentation not full, only {len(group)}") - - # Append - for dest in dests: - # Convert the examples - dest_split, dest_set = dest - - results.setdefault(dest_split, {}) - results[dest_split].setdefault(dest_set, []) - results[dest_split][dest_set].append([converted[dest] for converted in group]) - - -def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataProcessConfig): - train_examples_dest = ("train", "all") - test_examples_map = { - "evaluation": [(1.0, ("test", "all"))], - "_default": [(1.0, ("train", "all"))] - } - - total_puzzles = 0 - for subdir in os.scandir(dataset_path): - if subdir.is_dir(): - # Load all puzzles in this directory - puzzles = [] - for filename in glob(os.path.join(subdir.path, "*.json")): - with open(filename, "r") as f: - puzzles.append((Path(filename).stem, json.load(f))) - - # Shuffle puzzles - np.random.shuffle(puzzles) - - # Assign by fraction - for idx, (default_name, puzzle) in enumerate(puzzles): - fraction = idx / len(puzzles) - test_examples_dest = None - for f, dest in test_examples_map.get(subdir.name, test_examples_map["_default"]): - if fraction < f: - test_examples_dest = dest - break - - assert test_examples_dest is not None - - convert_single_arc_puzzle(results, default_name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest}) - total_puzzles += 1 - - print (f"[{dataset_path}] total puzzles: {total_puzzles}") - - -def convert_dataset(config: DataProcessConfig): - np.random.seed(config.seed) - - # Read dataset - data = {} - for dataset_dir in config.dataset_dirs: - load_puzzles_arcagi(data, dataset_dir, config) - - # Map global puzzle identifiers - num_identifiers = 1 # 0 is blank - identifier_map = {} - for split_name, split in data.items(): - for subset_name, subset in split.items(): - for group in subset: - for puzzle in group: - if puzzle.id not in identifier_map: - identifier_map[puzzle.id] = num_identifiers - num_identifiers += 1 - - print (f"Total puzzle IDs (including ): {num_identifiers}") - - # Save - for split_name, split in data.items(): - os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True) - - # Translational augmentations - enable_translational_augment = split_name == "train" - - # Statistics - total_examples = 0 - total_puzzles = 0 - total_groups = 0 - - for subset_name, subset in split.items(): - # Construct subset - results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]} - results["puzzle_indices"].append(0) - results["group_indices"].append(0) - - example_id = 0 - puzzle_id = 0 - - for group in subset: - for puzzle in group: - # Push puzzle - no_aug_id = np.random.randint(0, len(puzzle.examples)) - for _idx_ex, (inp, out) in enumerate(puzzle.examples): - inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id) - - results["inputs"].append(inp) - results["labels"].append(out) - example_id += 1 - - total_examples += 1 - - results["puzzle_indices"].append(example_id) - results["puzzle_identifiers"].append(identifier_map[puzzle.id]) - - puzzle_id += 1 - - total_puzzles += 1 - - # Push group - results["group_indices"].append(puzzle_id) - total_groups += 1 - - for k, v in results.items(): - if k in {"inputs", "labels"}: - v = np.stack(v, 0) - else: - v = np.array(v, dtype=np.int32) - - np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v) - - # Metadata - metadata = PuzzleDatasetMetadata( - seq_len=ARCMaxGridSize * ARCMaxGridSize, - vocab_size=10 + 2, # PAD + EOS + "0" ... "9" - - pad_id=0, - ignore_label_id=0, - - blank_identifier_id=0, - num_puzzle_identifiers=num_identifiers, - - total_groups=total_groups, - mean_puzzle_examples=total_examples / total_puzzles, - sets=list(split.keys()) - ) - - # Save metadata as JSON. - with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f: - json.dump(metadata.model_dump(), f) - - # Save IDs mapping - with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f: - ids_mapping = {v: k for k, v in identifier_map.items()} - - json.dump([ids_mapping.get(i, "") for i in range(num_identifiers)], f) - - -@cli.command(singleton=True) -def main(config: DataProcessConfig): - convert_dataset(config) - - -if __name__ == "__main__": - cli() diff --git a/dataset/build_maze_dataset.py b/dataset/build_maze_dataset.py deleted file mode 100644 index a9367f38..00000000 --- a/dataset/build_maze_dataset.py +++ /dev/null @@ -1,142 +0,0 @@ -from typing import Optional -import math -import os -import csv -import json -import numpy as np - -from argdantic import ArgParser -from pydantic import BaseModel -from tqdm import tqdm -from huggingface_hub import hf_hub_download - -from common import PuzzleDatasetMetadata, dihedral_transform - - -CHARSET = "# SGo" - - -cli = ArgParser() - - -class DataProcessConfig(BaseModel): - source_repo: str = "sapientinc/maze-30x30-hard-1k" - output_dir: str = "data/maze-30x30-hard-1k" - - subsample_size: Optional[int] = None - aug: bool = False - - -def convert_subset(set_name: str, config: DataProcessConfig): - # Read CSV - all_chars = set() - grid_size = None - inputs = [] - labels = [] - - with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: # type: ignore - reader = csv.reader(csvfile) - next(reader) # Skip header - for source, q, a, rating in reader: - all_chars.update(q) - all_chars.update(a) - - if grid_size is None: - n = int(len(q) ** 0.5) - grid_size = (n, n) - - inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size)) - labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size)) - - # If subsample_size is specified for the training set, - # randomly sample the desired number of examples. - if set_name == "train" and config.subsample_size is not None: - total_samples = len(inputs) - if config.subsample_size < total_samples: - indices = np.random.choice(total_samples, size=config.subsample_size, replace=False) - inputs = [inputs[i] for i in indices] - labels = [labels[i] for i in indices] - - # Generate dataset - results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]} - puzzle_id = 0 - example_id = 0 - - results["puzzle_indices"].append(0) - results["group_indices"].append(0) - - for inp, out in zip(tqdm(inputs), labels): - # Dihedral transformations for augmentation - for aug_idx in range(8 if (set_name == "train" and config.aug) else 1): - results["inputs"].append(dihedral_transform(inp, aug_idx)) - results["labels"].append(dihedral_transform(out, aug_idx)) - example_id += 1 - puzzle_id += 1 - - results["puzzle_indices"].append(example_id) - results["puzzle_identifiers"].append(0) - - # Push group - results["group_indices"].append(puzzle_id) - - # Char mappings - assert len(all_chars - set(CHARSET)) == 0 - - char2id = np.zeros(256, np.uint8) - char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1 - - # To Numpy - def _seq_to_numpy(seq): - arr = np.vstack([char2id[s.reshape(-1)] for s in seq]) - - return arr - - results = { - "inputs": _seq_to_numpy(results["inputs"]), - "labels": _seq_to_numpy(results["labels"]), - - "group_indices": np.array(results["group_indices"], dtype=np.int32), - "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32), - "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32), - } - - # Metadata - metadata = PuzzleDatasetMetadata( - seq_len=int(math.prod(grid_size)), # type: ignore - vocab_size=len(CHARSET) + 1, # PAD + Charset - - pad_id=0, - ignore_label_id=0, - - blank_identifier_id=0, - num_puzzle_identifiers=1, - - total_groups=len(results["group_indices"]) - 1, - mean_puzzle_examples=1, - sets=["all"] - ) - - # Save metadata as JSON. - save_dir = os.path.join(config.output_dir, set_name) - os.makedirs(save_dir, exist_ok=True) - - with open(os.path.join(save_dir, "dataset.json"), "w") as f: - json.dump(metadata.model_dump(), f) - - # Save data - for k, v in results.items(): - np.save(os.path.join(save_dir, f"all__{k}.npy"), v) - - # Save IDs mapping (for visualization only) - with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f: - json.dump([""], f) - - -@cli.command(singleton=True) -def preprocess_data(config: DataProcessConfig): - convert_subset("train", config) - convert_subset("test", config) - - -if __name__ == "__main__": - cli() diff --git a/evaluate.py b/evaluate.py deleted file mode 100644 index 71ee7530..00000000 --- a/evaluate.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import List -import yaml -import os - -import torch -import torch.distributed as dist - -import pydantic -from omegaconf import OmegaConf -from pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader - - -class EvalConfig(pydantic.BaseModel): - checkpoint: str - - save_outputs: List[str] = ["inputs", "labels", "puzzle_identifiers", "logits", "q_halt_logits", "q_continue_logits"] - - -def launch(): - eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli())) # type: ignore - - RANK = 0 - WORLD_SIZE = 1 - # Initialize distributed training if in distributed environment (e.g. torchrun) - if "LOCAL_RANK" in os.environ: - # Initialize distributed, default device and dtype - dist.init_process_group(backend="nccl") - - RANK = dist.get_rank() - WORLD_SIZE = dist.get_world_size() - - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - - with open(os.path.join(os.path.dirname(eval_cfg.checkpoint), "all_config.yaml"), "r") as f: - config = PretrainConfig(**yaml.safe_load(f)) - - config.eval_save_outputs = eval_cfg.save_outputs - config.checkpoint_path = os.path.dirname(eval_cfg.checkpoint) - - # Dataloader - train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) - eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) - - # Models - train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE) - # Try unwrap torch.compile - try: - train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location="cuda"), assign=True) - except: - train_state.model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in torch.load(eval_cfg.checkpoint, map_location="cuda").items()}, assign=True) - - train_state.step = 0 - ckpt_filename = os.path.basename(eval_cfg.checkpoint) - if ckpt_filename.startswith("step_"): - train_state.step = int(ckpt_filename.removeprefix("step_")) - - # Evaluate - print ("Starting evaluation") - - train_state.model.eval() - metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE) - - if metrics is not None: - print (metrics) - - -if __name__ == "__main__": - launch() diff --git a/models/adaptive_depth.py b/models/adaptive_depth.py index caddaf7c..600d43c6 100644 --- a/models/adaptive_depth.py +++ b/models/adaptive_depth.py @@ -154,8 +154,6 @@ def forward( Runs the model iteratively, checking confidence at each step. Halts early for confident samples, continues for uncertain ones. """ - from models.hrm.hrm_act_v1 import HierarchicalReasoningModel_ACTV1Carry - # Initialize new_inner_carry = self.model.inner.reset_carry(carry.halted, carry.inner_carry) new_steps = torch.where(carry.halted, 0, carry.steps) @@ -192,8 +190,6 @@ def forward( outputs["depth_info"] = depth_info return ( - HierarchicalReasoningModel_ACTV1Carry( - new_inner_carry, new_steps, halted, new_current_data - ), + type(carry)(new_inner_carry, new_steps, halted, new_current_data), outputs, ) diff --git a/models/hrm/hrm_act_v1.py b/models/hrm/hrm_act_v1.py deleted file mode 100644 index e91c7d1a..00000000 --- a/models/hrm/hrm_act_v1.py +++ /dev/null @@ -1,283 +0,0 @@ -from typing import Tuple, List, Dict, Optional -from dataclasses import dataclass -import math - -import torch -import torch.nn.functional as F -from torch import nn -from pydantic import BaseModel - -from models.common import trunc_normal_init_ -from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear -from models.sparse_embedding import CastedSparseEmbedding - - -@dataclass -class HierarchicalReasoningModel_ACTV1InnerCarry: - z_H: torch.Tensor - z_L: torch.Tensor - - -@dataclass -class HierarchicalReasoningModel_ACTV1Carry: - inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry - - steps: torch.Tensor - halted: torch.Tensor - - current_data: Dict[str, torch.Tensor] - - -class HierarchicalReasoningModel_ACTV1Config(BaseModel): - batch_size: int - seq_len: int - puzzle_emb_ndim: int = 0 - num_puzzle_identifiers: int - vocab_size: int - - H_cycles: int - L_cycles: int - - H_layers: int - L_layers: int - - # Transformer config - hidden_size: int - expansion: float - num_heads: int - pos_encodings: str - - rms_norm_eps: float = 1e-5 - rope_theta: float = 10000.0 - - # Halting Q-learning config - halt_max_steps: int - halt_exploration_prob: float - - forward_dtype: str = "bfloat16" - - -class HierarchicalReasoningModel_ACTV1Block(nn.Module): - def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None: - super().__init__() - - self.self_attn = Attention( - hidden_size=config.hidden_size, - head_dim=config.hidden_size // config.num_heads, - num_heads=config.num_heads, - num_key_value_heads=config.num_heads, - causal=False - ) - self.mlp = SwiGLU( - hidden_size=config.hidden_size, - expansion=config.expansion, - ) - self.norm_eps = config.rms_norm_eps - - def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: - # Post Norm - # Self Attention - hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps) - # Fully Connected - hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps) - return hidden_states - - -class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module): - def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]): - super().__init__() - - self.layers = torch.nn.ModuleList(layers) - - def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor: - # Input injection (add) - hidden_states = hidden_states + input_injection - # Layers - for layer in self.layers: - hidden_states = layer(hidden_states=hidden_states, **kwargs) - - return hidden_states - - -class HierarchicalReasoningModel_ACTV1_Inner(nn.Module): - def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None: - super().__init__() - self.config = config - self.forward_dtype = getattr(torch, self.config.forward_dtype) - - # I/O - self.embed_scale = math.sqrt(self.config.hidden_size) - embed_init_std = 1.0 / self.embed_scale - - self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) - self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False) - self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True) - - self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div - if self.config.puzzle_emb_ndim > 0: - # Zero init puzzle embeddings - self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, - batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype) - - # LM Blocks - if self.config.pos_encodings == "rope": - self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, - max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, - base=self.config.rope_theta) - elif self.config.pos_encodings == "learned": - self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) - else: - raise NotImplementedError() - - # Reasoning Layers - self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)]) - self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)]) - - # Initial states - self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) - self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) - - # Q head special init - # Init Q to (almost) zero for faster learning during bootstrapping - with torch.no_grad(): - self.q_head.weight.zero_() - self.q_head.bias.fill_(-5) # type: ignore - - def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor): - # Token embedding - embedding = self.embed_tokens(input.to(torch.int32)) - - # Puzzle embeddings - if self.config.puzzle_emb_ndim > 0: - puzzle_embedding = self.puzzle_emb(puzzle_identifiers) - - pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1] - if pad_count > 0: - puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count)) - - embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2) - - # Position embeddings - if self.config.pos_encodings == "learned": - # scale by 1/sqrt(2) to maintain forward variance - embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype)) - - # Scale - return self.embed_scale * embedding - - def empty_carry(self, batch_size: int): - return HierarchicalReasoningModel_ACTV1InnerCarry( - z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), - z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), - ) - - def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry): - return HierarchicalReasoningModel_ACTV1InnerCarry( - z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H), - z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L), - ) - - def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - seq_info = dict( - cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None, - ) - - # Input encoding - input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) - - # Forward iterations - with torch.no_grad(): - z_H, z_L = carry.z_H, carry.z_L - - for _H_step in range(self.config.H_cycles): - for _L_step in range(self.config.L_cycles): - if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)): - z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) - - if not (_H_step == self.config.H_cycles - 1): - z_H = self.H_level(z_H, z_L, **seq_info) - - assert not z_H.requires_grad and not z_L.requires_grad - - # 1-step grad - z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) - z_H = self.H_level(z_H, z_L, **seq_info) - - # LM Outputs - new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad - output = self.lm_head(z_H)[:, self.puzzle_emb_len:] - - # Q head - q_logits = self.q_head(z_H[:, 0]).to(torch.float32) - - return new_carry, output, (q_logits[..., 0], q_logits[..., 1]) - - -class HierarchicalReasoningModel_ACTV1(nn.Module): - """ACT wrapper.""" - - def __init__(self, config_dict: dict): - super().__init__() - self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict) - self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config) - - @property - def puzzle_emb(self): - return self.inner.puzzle_emb - - def initial_carry(self, batch: Dict[str, torch.Tensor]): - batch_size = batch["inputs"].shape[0] - - return HierarchicalReasoningModel_ACTV1Carry( - inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted. - - steps=torch.zeros((batch_size, ), dtype=torch.int32), - halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted - - current_data={k: torch.empty_like(v) for k, v in batch.items()} - ) - - def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]: - # Update data, carry (removing halted sequences) - new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry) - - new_steps = torch.where(carry.halted, 0, carry.steps) - - new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()} - - # Forward inner model - new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data) - - outputs = { - "logits": logits, - "q_halt_logits": q_halt_logits, - "q_continue_logits": q_continue_logits - } - - with torch.no_grad(): - # Step - new_steps = new_steps + 1 - is_last_step = new_steps >= self.config.halt_max_steps - - halted = is_last_step - - # if training, and ACT is enabled - if self.training and (self.config.halt_max_steps > 1): - # Halt signal - # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes - halted = halted | (q_halt_logits > q_continue_logits) - - # Exploration - min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1) - - halted = halted & (new_steps >= min_halt_steps) - - # Compute target Q - # NOTE: No replay buffer and target networks for computing target Q-value. - # As batch_size is large, there're many parallel envs. - # Similar concept as PQN https://arxiv.org/abs/2407.04811 - next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1] - - outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits))) - - return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs diff --git a/models/sparse_embedding.py b/models/sparse_embedding.py deleted file mode 100644 index c701524b..00000000 --- a/models/sparse_embedding.py +++ /dev/null @@ -1,132 +0,0 @@ -from typing import Union - -import torch -from torch import nn -import torch.distributed as dist -from torch.optim.optimizer import Optimizer, ParamsT - -from models.common import trunc_normal_init_ - - -class CastedSparseEmbedding(nn.Module): - def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype): - super().__init__() - self.cast_to = cast_to - - # Real Weights - # Truncated LeCun normal init - self.weights = nn.Buffer( - trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True - ) - - # Local weights and IDs - # Local embeddings, with gradient, not persistent - self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False) - # Local embedding IDs, not persistent - self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - if not self.training: - # Test mode, no gradient - return self.weights[inputs].to(self.cast_to) - - # Training mode, fill puzzle embedding from weights - with torch.no_grad(): - self.local_weights.copy_(self.weights[inputs]) - self.local_ids.copy_(inputs) - - return self.local_weights.to(self.cast_to) - - -class CastedSparseEmbeddingSignSGD_Distributed(Optimizer): - def __init__( - self, - params: ParamsT, - - world_size: int, - lr: Union[float, torch.Tensor] = 1e-3, - weight_decay: float = 1e-2, - ): - if not 0.0 <= lr: - raise ValueError(f"Invalid learning rate: {lr}") - if not 0.0 <= weight_decay: - raise ValueError(f"Invalid weight_decay value: {weight_decay}") - - defaults = dict( - lr=lr, - weight_decay=weight_decay, - world_size=world_size - ) - super().__init__(params, defaults) - - @torch.no_grad - def step(self, closure=None): # type: ignore - for group in self.param_groups: - # Find the sparse embedding weights - local_weights_grad = None - local_ids = None - weights = None - - assert len(group["params"]) == 3 - for p in group["params"]: - if p.requires_grad: - local_weights_grad = p.grad - elif p.ndim == 1: - local_ids = p - elif p.ndim == 2: - weights = p - else: - assert False - - assert local_weights_grad is not None - assert local_ids is not None - assert weights is not None - - # Apply SignSGD - # Adam ≈ SignSGD if gradient is very sparse - _sparse_emb_signsgd_dist( - local_weights_grad, - local_ids, - weights, - - lr=group["lr"], - weight_decay=group["weight_decay"], - world_size=group["world_size"] - ) - - -def _sparse_emb_signsgd_dist( - local_weights_grad: torch.Tensor, - local_ids: torch.Tensor, - weights: torch.Tensor, - - lr: float, - weight_decay: float, - world_size: int -) -> None: - N, D = local_weights_grad.shape - - # All-gather - all_weights_grad = local_weights_grad - all_ids = local_ids - - if world_size > 1: - all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device) - all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device) - - dist.all_gather_into_tensor(all_weights_grad, local_weights_grad) - dist.all_gather_into_tensor(all_ids, local_ids) - - # Unique - grad_ids, inv = all_ids.unique(return_inverse=True) - - grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device) - grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad) - - # SignSGD with decoupled weight decay - p = weights[grad_ids] - - p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr) - - # Write updated slices back - weights[grad_ids] = p diff --git a/models/topological.py b/models/topological.py index 22b28b00..eea473bb 100644 --- a/models/topological.py +++ b/models/topological.py @@ -120,17 +120,17 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Te components = (degree / (degree.sum(dim=-1, keepdim=True) + 1e-6)).sum(dim=-1) component_counts.append(components) - component_counts = torch.stack(component_counts, dim=-1) # (bs, n, num_steps) + component_counts = torch.stack(component_counts, dim=-1) # (bs, num_steps) # Betti-0: max components across filtration (persistent) - betti_0 = component_counts.max(dim=-1).values.mean(dim=-1, keepdim=True) + betti_0 = component_counts.max(dim=-1, keepdim=True).values # Betti-1: loops appear when components merge but don't fill in # Approximate: count "births" of 1-cycles # A 1-cycle is born when two previously separate components connect # but the enclosed region is not yet filled - diffs = component_counts[:, :, 1:] - component_counts[:, :, :-1] - loop_evidence = (diffs < 0).float().sum(dim=-1).mean(dim=-1, keepdim=True) + diffs = component_counts[:, 1:] - component_counts[:, :-1] + loop_evidence = (diffs < 0).float().sum(dim=-1, keepdim=True) betti_1 = torch.sigmoid(loop_evidence) # Topological feature vector: combines local and global topology diff --git a/pretrain.py b/pretrain.py deleted file mode 100644 index 245cb5c7..00000000 --- a/pretrain.py +++ /dev/null @@ -1,453 +0,0 @@ -from typing import Optional, Any, Sequence, List -from dataclasses import dataclass -import os -import math -import yaml -import shutil - -import torch -import torch.distributed as dist -from torch import nn -from torch.utils.data import DataLoader - -import tqdm -import wandb -import coolname -import hydra -import pydantic -from omegaconf import DictConfig -from adam_atan2 import AdamATan2 - -from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata -from utils.functions import load_model_class, get_model_source_path -from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed - - -class LossConfig(pydantic.BaseModel): - model_config = pydantic.ConfigDict(extra='allow') - - name: str - - -class ArchConfig(pydantic.BaseModel): - model_config = pydantic.ConfigDict(extra='allow') - - name: str - loss: LossConfig - - -class PretrainConfig(pydantic.BaseModel): - # Config - arch: ArchConfig - # Data - data_path: str - - # Hyperparams - global_batch_size: int - epochs: int - - lr: float - lr_min_ratio: float - lr_warmup_steps: int - - weight_decay: float - beta1: float - beta2: float - - # Puzzle embedding - puzzle_emb_lr: float - puzzle_emb_weight_decay: float - - # Names - project_name: Optional[str] = None - run_name: Optional[str] = None - checkpoint_path: Optional[str] = None - - # Extras - seed: int = 0 - checkpoint_every_eval: bool = False - eval_interval: Optional[int] = None - eval_save_outputs: List[str] = [] - - -@dataclass -class TrainState: - model: nn.Module - optimizers: Sequence[torch.optim.Optimizer] - optimizer_lrs: Sequence[float] - carry: Any - - step: int - total_steps: int - - -def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs): - dataset = PuzzleDataset(PuzzleDatasetConfig( - seed=config.seed, - - dataset_path=config.data_path, - - rank=rank, - num_replicas=world_size, - - **kwargs - ), split=split) - dataloader = DataLoader( - dataset, - batch_size=None, - - num_workers=1, - prefetch_factor=8, - - pin_memory=True, - persistent_workers=True - ) - return dataloader, dataset.metadata - - -def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int): - model_cfg = dict( - **config.arch.__pydantic_extra__, # type: ignore - - batch_size=config.global_batch_size // world_size, - - vocab_size=train_metadata.vocab_size, - seq_len=train_metadata.seq_len, - num_puzzle_identifiers=train_metadata.num_puzzle_identifiers, - causal=False # Non-autoregressive - ) - - # Instantiate model with loss head - model_cls = load_model_class(config.arch.name) - loss_head_cls = load_model_class(config.arch.loss.name) - - with torch.device("cuda"): - model: nn.Module = model_cls(model_cfg) - model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore - if "DISABLE_COMPILE" not in os.environ: - model = torch.compile(model, dynamic=False) # type: ignore - - # Broadcast parameters from rank 0 - if world_size > 1: - with torch.no_grad(): - for param in list(model.parameters()) + list(model.buffers()): - dist.broadcast(param, src=0) - - # Optimizers and lr - optimizers = [ - CastedSparseEmbeddingSignSGD_Distributed( - model.model.puzzle_emb.buffers(), # type: ignore - - lr=0, # Needs to be set by scheduler - weight_decay=config.puzzle_emb_weight_decay, - - world_size=world_size - ), - AdamATan2( - model.parameters(), - - lr=0, # Needs to be set by scheduler - weight_decay=config.weight_decay, - betas=(config.beta1, config.beta2) - ) - ] - optimizer_lrs = [ - config.puzzle_emb_lr, - config.lr - ] - - return model, optimizers, optimizer_lrs - - -def cosine_schedule_with_warmup_lr_lambda( - current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5 -): - if current_step < num_warmup_steps: - return base_lr * float(current_step) / float(max(1, num_warmup_steps)) - - progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) - return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))) - - -def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int): - # Estimated total training steps - total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size) - - # Model - model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size) - - return TrainState( - step=0, - total_steps=total_steps, - - model=model, - optimizers=optimizers, - optimizer_lrs=optimizer_lrs, - carry=None - ) - - -def save_train_state(config: PretrainConfig, train_state: TrainState): - # FIXME: Only saved model. - if config.checkpoint_path is None: - return - - os.makedirs(config.checkpoint_path, exist_ok=True) - torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}")) - - -def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState): - return cosine_schedule_with_warmup_lr_lambda( - current_step=train_state.step, - base_lr=base_lr, - num_warmup_steps=round(config.lr_warmup_steps), - num_training_steps=train_state.total_steps, - min_ratio=config.lr_min_ratio - ) - - -def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int): - train_state.step += 1 - if train_state.step > train_state.total_steps: # At most train_total_steps - return - - # To device - batch = {k: v.cuda() for k, v in batch.items()} - - # Init carry if it is None - if train_state.carry is None: - with torch.device("cuda"): - train_state.carry = train_state.model.initial_carry(batch) # type: ignore - - # Forward - train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[]) - - ((1 / global_batch_size) * loss).backward() - - # Allreduce - if world_size > 1: - for param in train_state.model.parameters(): - if param.grad is not None: - dist.all_reduce(param.grad) - - # Apply optimizer - lr_this_step = None - for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs): - lr_this_step = compute_lr(base_lr, config, train_state) - - for param_group in optim.param_groups: - param_group['lr'] = lr_this_step - - optim.step() - optim.zero_grad() - - # Reduce metrics - if len(metrics): - assert not any(v.requires_grad for v in metrics.values()) - - metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order. - # Reduce and reconstruct - metric_values = torch.stack([metrics[k] for k in metric_keys]) - if world_size > 1: - dist.reduce(metric_values, dst=0) - - if rank == 0: - metric_values = metric_values.cpu().numpy() - reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)} - - # Postprocess - count = max(reduced_metrics["count"], 1) # Avoid NaNs - reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()} - - reduced_metrics["train/lr"] = lr_this_step - return reduced_metrics - - -def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int): - with torch.inference_mode(): - set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)} - - all_preds = {} - - metric_keys = [] - metric_values = None - metric_global_batch_size = [0 for _ in range(len(set_ids))] - - carry = None - for set_name, batch, global_batch_size in eval_loader: - # To device - batch = {k: v.cuda() for k, v in batch.items()} - with torch.device("cuda"): - carry = train_state.model.initial_carry(batch) # type: ignore - - # Forward - while True: - carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs) - - if all_finish: - break - - for collection in (batch, preds): - for k, v in collection.items(): - if k in config.eval_save_outputs: - all_preds.setdefault(k, []) - all_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory - - del carry, preds, batch, all_finish - - # Aggregate - set_id = set_ids[set_name] - - if metric_values is None: - metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order. - metric_values = torch.zeros((len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda") - - metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys]) - metric_global_batch_size[set_id] += global_batch_size - - if len(all_preds) and config.checkpoint_path is not None: - all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()} - - os.makedirs(config.checkpoint_path, exist_ok=True) - torch.save(all_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}")) - - # Logging - # Reduce to rank 0 - if metric_values is not None: - if world_size > 1: - dist.reduce(metric_values, dst=0) - - if rank == 0: - reduced_metrics = metric_values.cpu().numpy() - reduced_metrics = {set_name: {metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)} - for set_id, set_name in enumerate(set_ids)} - - # Postprocess - for set_name, metrics in reduced_metrics.items(): - count = metrics.pop("count") - reduced_metrics[set_name] = {k: v / count for k, v in metrics.items()} - - return reduced_metrics - - -def save_code_and_config(config: PretrainConfig): - if config.checkpoint_path is None or wandb.run is None: - return - - os.makedirs(config.checkpoint_path, exist_ok=True) - - # Copy code - code_list = [ - get_model_source_path(config.arch.name), - get_model_source_path(config.arch.loss.name) - ] - for code_file in code_list: - if code_file is not None: - code_name = os.path.basename(code_file) - - shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name)) - - # Dump config as yaml - config_file = os.path.join(config.checkpoint_path, "all_config.yaml") - with open(config_file, "wt") as f: - yaml.dump(config.model_dump(), f) - - # Log code - wandb.run.log_code(config.checkpoint_path) - - -def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig: - objects = [None] - if rank == 0: - config = PretrainConfig(**hydra_config) # type: ignore - - # Naming - if config.project_name is None: - config.project_name = f"{os.path.basename(config.data_path).capitalize()} ACT-torch" - if config.run_name is None: - config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}" - if config.checkpoint_path is None: - config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name) - - objects = [config] - - if world_size > 1: - dist.broadcast_object_list(objects, src=0) - - return objects[0] # type: ignore - - -@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None) -def launch(hydra_config: DictConfig): - RANK = 0 - WORLD_SIZE = 1 - - # Initialize distributed training if in distributed environment (e.g. torchrun) - if "LOCAL_RANK" in os.environ: - # Initialize distributed, default device and dtype - dist.init_process_group(backend="nccl") - - RANK = dist.get_rank() - WORLD_SIZE = dist.get_world_size() - - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - - # Load sync'ed config - config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE) - - # Seed RNGs to ensure consistency - torch.random.manual_seed(config.seed + RANK) - - # Dataset - train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs - total_iters = config.epochs // train_epochs_per_iter - - assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs." - - train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) - eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) - - # Train state - train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE) - - # Progress bar and logger - progress_bar = None - if RANK == 0: - progress_bar = tqdm.tqdm(total=train_state.total_steps) - - wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore - wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0) - save_code_and_config(config) - - # Training Loop - for _iter_id in range(total_iters): - print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}") - - ############ Train Iter - train_state.model.train() - for set_name, batch, global_batch_size in train_loader: - metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE) - - if RANK == 0 and metrics is not None: - wandb.log(metrics, step=train_state.step) - progress_bar.update(train_state.step - progress_bar.n) # type: ignore - - ############ Evaluation - train_state.model.eval() - metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE) - - if RANK == 0 and metrics is not None: - wandb.log(metrics, step=train_state.step) - - ############ Checkpointing - if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)): - save_train_state(config, train_state) - - # finalize - if dist.is_initialized(): - dist.destroy_process_group() - wandb.finish() - - -if __name__ == "__main__": - launch() diff --git a/puzzle_dataset.py b/puzzle_dataset.py deleted file mode 100644 index 2782403c..00000000 --- a/puzzle_dataset.py +++ /dev/null @@ -1,199 +0,0 @@ -import os -import json - -import numpy as np -import pydantic - -import torch -from torch.utils.data import IterableDataset, get_worker_info - -from models.losses import IGNORE_LABEL_ID -from dataset.common import PuzzleDatasetMetadata - - -def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int): - # Pack examples into a full batch - batch = [] - batch_puzzle_indices = [] - current_size = 0 - - while (start_index < group_order.size) and (current_size < global_batch_size): - # Pick a group and a puzzle from that group - group_id = group_order[start_index] - puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1]) - start_index += 1 - - # Get range of the puzzle - puzzle_start = puzzle_indices[puzzle_id] - puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start) - - append_size = min(puzzle_size, global_batch_size - current_size) - - # Put into batch - batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32)) - batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False)) - - current_size += append_size - - return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices) - - -class PuzzleDatasetConfig(pydantic.BaseModel): - seed: int - dataset_path: str - global_batch_size: int - test_set_mode: bool - - epochs_per_iter: int # Batch X epochs in an iteration to reduce overhead. - - rank: int - num_replicas: int - - -class PuzzleDataset(IterableDataset): - def __init__(self, config: PuzzleDatasetConfig, split: str = "train"): - super().__init__() - self.config = config - self.split = split - self.metadata = self._load_metadata() - - # Checks - assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}." - self.local_batch_size = self.config.global_batch_size // self.config.num_replicas - - # State - self._data = None - self._iters = 0 - - def _load_metadata(self) -> PuzzleDatasetMetadata: - with open(os.path.join(self.config.dataset_path, self.split, "dataset.json"), "r") as f: - return PuzzleDatasetMetadata(**json.load(f)) - - def _lazy_load_dataset(self): - if self._data is not None: - return - - field_mmap_modes = { - "inputs": "r", - "labels": "r", - - # Keep indices in memory - "puzzle_identifiers": None, - "puzzle_indices": None, - "group_indices": None - } - - # Load data - self._data = {} - for set_name in self.metadata.sets: - # Load subset - self._data[set_name] = { - field_name: np.load(os.path.join(self.config.dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode) - for field_name, mmap_mode in field_mmap_modes.items() - } - - def _collate_batch(self, batch): - # Convert dtype - batch = {k: v.astype(np.int32) for k, v in batch.items()} - - # Convert ignore label IDs - if self.metadata.ignore_label_id is not None: - batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID - - # Pad - if batch["puzzle_identifiers"].size < self.local_batch_size: - pad_size = self.local_batch_size - batch["puzzle_identifiers"].size - - pad_values = { - "inputs": self.metadata.pad_id, - "labels": IGNORE_LABEL_ID, - - "puzzle_identifiers": self.metadata.blank_identifier_id - } - batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()} - - # To tensor - return {k: torch.from_numpy(v) for k, v in batch.items()} - - def _iter_test(self): - for set_name, dataset in self._data.items(): # type: ignore - total_examples = len(dataset["inputs"]) - - # Load examples one by one - start_index = 0 - while start_index < total_examples: - # Compute indices - end_index = min(total_examples, start_index + self.config.global_batch_size) - - local_start = start_index + self.config.rank * self.local_batch_size - local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index) - - # Get batch of examples, and also puzzle IDs - puzzle_indices = [] - puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1 - for i in range(local_start, local_end): - while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]: - puzzle_index += 1 - - puzzle_indices.append(puzzle_index) - - batch = self._collate_batch({ - "inputs": dataset["inputs"][local_start: local_end], - "labels": dataset["labels"][local_start: local_end], - "puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices] - }) - - yield set_name, batch, end_index - start_index - - # Advance to next batch - start_index += self.config.global_batch_size - - def _iter_train(self): - for set_name, dataset in self._data.items(): # type: ignore - # Increase epoch count - self._iters += 1 - - # Randomly shuffle groups - rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters)) - - group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)]) - start_index = 0 - - while start_index < group_order.size: - start_index, batch_indices, batch_puzzle_indices = _sample_batch( - rng, - group_order=group_order, - puzzle_indices=dataset["puzzle_indices"], - group_indices=dataset["group_indices"], - start_index=start_index, - global_batch_size=self.config.global_batch_size, - ) - - # Select current rank and collate - global_effective_batch_size = batch_puzzle_indices.size # Global effective batch size, excluding pads - - # Drop last batch - if global_effective_batch_size < self.config.global_batch_size: - break - - batch_indices = batch_indices [self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size] - batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size] - batch = self._collate_batch({ - "inputs": dataset["inputs"][batch_indices], - "labels": dataset["labels"][batch_indices], - "puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices] - }) - - yield set_name, batch, global_effective_batch_size - - def __iter__(self): - worker_info = get_worker_info() - assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported." - - self._lazy_load_dataset() - - # Iterate using specified mode - if self.config.test_set_mode: - yield from self._iter_test() - else: - yield from self._iter_train() diff --git a/puzzle_visualizer.html b/puzzle_visualizer.html deleted file mode 100644 index bcefdf1c..00000000 --- a/puzzle_visualizer.html +++ /dev/null @@ -1,426 +0,0 @@ - - - - - ARC‐Converted Dataset Visualizer (Upload Local Folder) - - - -

ARC‐Converted Dataset Visualizer (Local Directory)

- -
- - - -

- - - - - - - - - -
- -
-
-
-
-
- - - - - - - From a6414ad7f837791e72dc8612f6a56033922dd4e0 Mon Sep 17 00:00:00 2001 From: jeevesh415 Date: Sun, 17 May 2026 16:26:20 +0530 Subject: [PATCH 5/7] Regenerate full line-level CODE_ADDRESS_INDEX after repo changes --- CODE_ADDRESS_INDEX.md | 572 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 523 insertions(+), 49 deletions(-) diff --git a/CODE_ADDRESS_INDEX.md b/CODE_ADDRESS_INDEX.md index 69d97f5f..9e35997f 100644 --- a/CODE_ADDRESS_INDEX.md +++ b/CODE_ADDRESS_INDEX.md @@ -1,51 +1,525 @@ # CODE ADDRESS INDEX -Auto-generated repository file map. - -- `.github/workflows/sync-from-upstream.yml` | text | lines: 19 -- `.gitignore` | text | lines: 169 -- `.gitmodules` | text | lines: 9 -- `.vscode/launch.json` | text | lines: 26 -- `.vscode/settings.json` | text | lines: 3 -- `LICENSE` | text | lines: 202 -- `README.md` | text | lines: 112 -- `assets/hrm.png` | binary | lines: 0 -- `assets/npyjs.js` | text | lines: 176 -- `config/vjepa_10b.yaml` | text | lines: 85 -- `config/vjepa_micro.yaml` | text | lines: 34 -- `dataset/common.py` | text | lines: 51 -- `dataset/generate_dummy_data.py` | text | lines: 29 -- `dataset/video_dataset.py` | text | lines: 108 -- `docs/FRONTIER_GAP_ANALYSIS.md` | text | lines: 35 -- `docs/HUMAN_VISION_EXECUTION_EVAL_SPEC.md` | text | lines: 48 -- `docs/RIGOROUS_DEVELOPMENT_PROTOCOL.md` | text | lines: 33 -- `evaluate_perception.py` | text | lines: 90 -- `evaluate_world_model.py` | text | lines: 166 -- `models/adaptive_depth.py` | text | lines: 199 -- `models/common.py` | text | lines: 32 -- `models/hybrid_ssm.py` | text | lines: 228 -- `models/information_bottleneck.py` | text | lines: 227 -- `models/layers.py` | text | lines: 167 -- `models/losses.py` | text | lines: 101 -- `models/multimodal_grounding.py` | text | lines: 258 -- `models/muon_optimizer.py` | text | lines: 191 -- `models/proper_equivariance.py` | text | lines: 335 -- `models/spectral_conv.py` | text | lines: 193 -- `models/topological.py` | text | lines: 215 -- `models/ttt_layer.py` | text | lines: 206 -- `models/uncertainty.py` | text | lines: 206 -- `models/vjepa/flow_matching.py` | text | lines: 281 -- `models/vjepa/gaussian_splatting.py` | text | lines: 203 -- `models/vjepa/layers.py` | text | lines: 157 -- `models/vjepa/losses.py` | text | lines: 31 -- `models/vjepa/memory.py` | text | lines: 392 -- `models/vjepa/physics_engine.py` | text | lines: 107 -- `models/vjepa/planning.py` | text | lines: 467 -- `models/vjepa/predictor.py` | text | lines: 260 -- `models/vjepa/symplectic_integrator.py` | text | lines: 137 -- `models/vjepa/utils.py` | text | lines: 44 -- `models/vjepa/vit.py` | text | lines: 94 -- `models/vjepa/vjepa_model.py` | text | lines: 141 -- `requirements.txt` | text | lines: 16 -- `utils/functions.py` | text | lines: 19 -- `vjepa_train.py` | text | lines: 221 +Comprehensive repository address map. Updated to current line-level state. + +Total files indexed: **48** + +## `.github/workflows/sync-from-upstream.yml` +- Type: text +- Total lines: 19 +- Address anchors: + - L1: `name: Auto Sync from Upstream` + - L2: `on:` + - L6: `jobs:` + +## `.gitignore` +- Type: text +- Total lines: 169 +- Address anchors: none detected + +## `.gitmodules` +- Type: text +- Total lines: 9 +- Address anchors: none detected + +## `.vscode/launch.json` +- Type: text +- Total lines: 26 +- Address anchors: none detected + +## `.vscode/settings.json` +- Type: text +- Total lines: 3 +- Address anchors: none detected + +## `LICENSE` +- Type: text +- Total lines: 202 +- Address anchors: none detected + +## `README.md` +- Type: text +- Total lines: 112 +- Address anchors: + - L1: `# Visual Execution Model (VEM)` + - L9: `## Purpose` + - L17: `## Vision` + - L20: `## Goal` + - L28: `## Technical Architecture (Concept Map)` + - L30: `### 1) Spatio-Temporal Representation (Vision Encoder)` + - L35: `### 2) Geometric Inductive Biases` + - L40: `### 3) Continuous-Time Latent Dynamics` + - L45: `### 4) Hierarchical Predictive Reasoning` + - L50: `### 5) World Rendering and Latent Scene Composition` + - L54: `### 6) Latent Planning & Decision Support` + - L58: `### 7) Multi-Modal and Robustness Extensions` + - L62: `### 8) Training Stack` + - L70: `## Repository Workflow (Single Framework)` + - L72: `### Configurations` + - L76: `### Training Entrypoint` + - L79: `# or` + - L86: `### Practical Notes` + - L96: `### Final Execution Checklist (Do This)` + - L104: `## Roadmap Direction` + +## `assets/hrm.png` +- Type: binary/non-utf8 +- Total lines: 0 +- Address anchors: n/a + +## `assets/npyjs.js` +- Type: text +- Total lines: 176 +- Address anchors: none detected + +## `check_integrations.py` +- Type: text +- Total lines: 56 +- Address anchors: + - L8: `def load_config(path: str):` + - L13: `def main():` + +## `config/vjepa_10b.yaml` +- Type: text +- Total lines: 85 +- Address anchors: + - L2: `encoder:` + - L14: `predictor:` + - L72: `training:` + +## `config/vjepa_micro.yaml` +- Type: text +- Total lines: 34 +- Address anchors: + - L2: `encoder:` + - L14: `predictor:` + - L27: `training:` + +## `dataset/common.py` +- Type: text +- Total lines: 51 +- Address anchors: + - L12: `class PuzzleDatasetMetadata(pydantic.BaseModel):` + - L27: `def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:` + - L50: `def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:` + +## `dataset/generate_dummy_data.py` +- Type: text +- Total lines: 29 +- Address anchors: + - L5: `def generate_dummy_video(path, frames=32, res=(224, 224)):` + +## `dataset/video_dataset.py` +- Type: text +- Total lines: 108 +- Address anchors: + - L8: `class AdvancedVideoDataset(IterableDataset):` + - L13: `def __init__(self,` + - L33: `def _get_video_stream(self, path):` + - L42: `def _generate_3d_block_mask(self):` + - L66: `def __iter__(self):` + - L106: `def get_dataloader(video_paths, batch_size=1, **kwargs):` + +## `docs/FRONTIER_GAP_ANALYSIS.md` +- Type: text +- Total lines: 35 +- Address anchors: + - L1: `# Frontier Capability Gap Analysis and Implementation Plan` + - L3: `## Scope` + - L9: `## Comparison Matrix` + - L20: `## Implemented in this change` + - L22: `### 1) MCTS action prior upgrade` + - L28: `## Why this is prioritized first` + - L31: `## Next technical steps (ordered)` + +## `docs/HUMAN_VISION_EXECUTION_EVAL_SPEC.md` +- Type: text +- Total lines: 48 +- Address anchors: + - L1: `# Human-Vision + Execution Evaluation Spec (Initial)` + - L5: `## Purpose Alignment` + - L13: `## Track A: Perception Robustness (implemented baseline)` + - L27: `## Track B: World-Model Dynamics (implemented baseline)` + - L39: `## Track C: Execution/Cognition (next)` + - L44: `## Promotion Rule (Phase-1/2)` + +## `docs/RIGOROUS_DEVELOPMENT_PROTOCOL.md` +- Type: text +- Total lines: 33 +- Address anchors: + - L1: `# Rigorous Development Protocol (Phase-1)` + - L5: `## Gate A — Sanity / Determinism` + - L10: `## Gate B — World-model Metrics` + - L25: `## Gate C — Change Promotion` + - L31: `## Notes` + +## `evaluate_perception.py` +- Type: text +- Total lines: 90 +- Address anchors: + - L26: `def apply_perturbation(video: torch.Tensor, mode: str) -> torch.Tensor:` + - L39: `def latent_consistency(model: VJEPA, video: torch.Tensor, perturbed: torch.Tensor) -> float:` + - L46: `def main() -> None:` + +## `evaluate_world_model.py` +- Type: text +- Total lines: 166 +- Address anchors: + - L27: `def set_seed(seed: int) -> None:` + - L34: `class EvalManifest:` + - L46: `def get_commit_hash(default: str = "unknown") -> str:` + - L62: `def latent_rollout(` + - L76: `def evaluate_metrics(model: VJEPA, device: torch.device, rollout_steps: int, num_actions: int) -> Dict[str, float]:` + - L116: `def main() -> None:` + +## `models/adaptive_depth.py` +- Type: text +- Total lines: 195 +- Address anchors: + - L24: `class AdaptiveDepthController(nn.Module):` + - L41: `def __init__(` + - L54: `def should_continue(` + - L116: `class AdaptiveDepthWrapper(nn.Module):` + - L131: `def __init__(` + - L146: `def forward(` + +## `models/common.py` +- Type: text +- Total lines: 32 +- Address anchors: + - L7: `def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):` + +## `models/hybrid_ssm.py` +- Type: text +- Total lines: 228 +- Address anchors: + - L26: `class SelectiveSSM(nn.Module):` + - L41: `def __init__(` + - L85: `def forward(self, x: torch.Tensor) -> torch.Tensor:` + - L141: `class HybridSSMAttentionBlock(nn.Module):` + - L161: `def __init__(` + - L199: `def forward(` + +## `models/information_bottleneck.py` +- Type: text +- Total lines: 227 +- Address anchors: + - L28: `class VariationalInformationBottleneck(nn.Module):` + - L42: `def __init__(` + - L71: `def _reparameterize(` + - L90: `def forward(` + - L135: `class InformationBottleneckAttention(nn.Module):` + - L150: `def __init__(` + - L179: `def forward(` + +## `models/layers.py` +- Type: text +- Total lines: 167 +- Address anchors: + - L13: `def flash_attn_func(q, k, v, causal=False):` + - L29: `def _find_multiple(a, b):` + - L33: `def rotate_half(x: torch.Tensor):` + - L40: `def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):` + - L53: `class CastedLinear(nn.Module):` + - L54: `def __init__(self,` + - L68: `def forward(self, input: torch.Tensor) -> torch.Tensor:` + - L72: `class CastedEmbedding(nn.Module):` + - L73: `def __init__(self,` + - L86: `def forward(self, input: torch.Tensor) -> torch.Tensor:` + - L90: `class RotaryEmbedding(nn.Module):` + - L91: `def __init__(self, dim, max_position_embeddings, base, device=None):` + - L104: `def forward(self):` + - L108: `class Attention(nn.Module):` + - L109: `def __init__(self, hidden_size, head_dim, num_heads, num_key_value_heads, causal=False):` + - L122: `def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:` + - L148: `class SwiGLU(nn.Module):` + - L149: `def __init__(self, hidden_size: int, expansion: float):` + - L156: `def forward(self, x):` + - L161: `def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float) -> torch.Tensor:` + +## `models/losses.py` +- Type: text +- Total lines: 101 +- Address anchors: + - L11: `def s(x, epsilon=1e-30):` + - L19: `def log_stablemax(x, dim=-1):` + - L24: `def stablemax_cross_entropy(logits, labels, ignore_index: int = -100):` + - L34: `def softmax_cross_entropy(logits, labels, ignore_index: int = -100):` + - L40: `class ACTLossHead(nn.Module):` + - L41: `def __init__(self, model: nn.Module, loss_type: str):` + - L46: `def initial_carry(self, *args, **kwargs):` + - L49: `def forward(` + +## `models/multimodal_grounding.py` +- Type: text +- Total lines: 258 +- Address anchors: + - L24: `class ModalityEncoder(nn.Module):` + - L37: `def __init__(` + - L60: `def forward(` + - L88: `class CrossModalAttention(nn.Module):` + - L101: `def __init__(` + - L122: `def forward(` + - L158: `class MultiModalGrounding(nn.Module):` + - L180: `def __init__(` + - L213: `def forward(` + +## `models/muon_optimizer.py` +- Type: text +- Total lines: 191 +- Address anchors: + - L26: `class Muon(Optimizer):` + - L44: `def __init__(` + - L67: `def step(self, closure=None):` + - L137: `def _newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 5) -> torch.Tensor:` + - L170: `def _distributed_allreduce_grads(` + +## `models/proper_equivariance.py` +- Type: text +- Total lines: 335 +- Address anchors: + - L27: `class SO3Rotation(nn.Module):` + - L37: `def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:` + - L74: `def matrix_to_quaternion(R: torch.Tensor) -> torch.Tensor:` + - L93: `class WignerDMatrices(nn.Module):` + - L107: `def wigner_d_small(beta: torch.Tensor, l: int) -> torch.Tensor:` + - L170: `def rotation_to_euler(R: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:` + - L187: `class ProperSE3EquivariantLayer(nn.Module):` + - L204: `def __init__(` + - L243: `def _positional_encoding(self, positions: torch.Tensor) -> torch.Tensor:` + - L259: `def forward(` + +## `models/spectral_conv.py` +- Type: text +- Total lines: 193 +- Address anchors: + - L25: `class GraphLaplacian(nn.Module):` + - L34: `def __init__(self, k_neighbors: int = 8):` + - L38: `def forward(self, x: torch.Tensor) -> torch.Tensor:` + - L74: `class SpectralGraphConv(nn.Module):` + - L88: `def __init__(` + - L120: `def _chebyshev_polynomials(` + - L159: `def forward(self, x: torch.Tensor) -> torch.Tensor:` + +## `models/topological.py` +- Type: text +- Total lines: 215 +- Address anchors: + - L25: `class DifferentiableBettiNumbers(nn.Module):` + - L41: `def __init__(` + - L68: `def _compute_distance_matrix(self, x: torch.Tensor) -> torch.Tensor:` + - L74: `def _soft_threshold(` + - L87: `def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:` + - L156: `class TopologicalAwareness(nn.Module):` + - L171: `def __init__(` + - L191: `def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:` + +## `models/ttt_layer.py` +- Type: text +- Total lines: 206 +- Address anchors: + - L25: `class TTTLinear(nn.Module):` + - L45: `def __init__(` + - L74: `def forward(` + - L148: `class TTTLinearWithAttention(nn.Module):` + - L163: `def __init__(` + - L191: `def forward(` + +## `models/uncertainty.py` +- Type: text +- Total lines: 206 +- Address anchors: + - L24: `class VariationalLinear(nn.Module):` + - L41: `def __init__(` + - L61: `def forward(self, x: torch.Tensor) -> torch.Tensor:` + - L78: `def kl_divergence(self) -> torch.Tensor:` + - L106: `class UncertaintyQuantification(nn.Module):` + - L121: `def __init__(` + - L151: `def forward(` + +## `models/vjepa/flow_matching.py` +- Type: text +- Total lines: 281 +- Address anchors: + - L28: `class SinusoidalTimeEmbedding(nn.Module):` + - L31: `def __init__(self, dim: int):` + - L35: `def forward(self, t: torch.Tensor) -> torch.Tensor:` + - L53: `class VelocityField(nn.Module):` + - L61: `def __init__(self, dim: int, hidden_dim: int, condition_dim: int):` + - L89: `def forward(` + - L125: `class ConditionalFlowMatching(nn.Module):` + - L150: `def __init__(` + - L165: `def forward(` + - L211: `def sample(` + - L257: `def sample_rectified(` + +## `models/vjepa/gaussian_splatting.py` +- Type: text +- Total lines: 203 +- Address anchors: + - L28: `class LatentGaussianSplatting(nn.Module):` + - L40: `def __init__(self, dim: int, num_gaussians: int = 256):` + - L75: `def _parse_gaussians(self, params: torch.Tensor) -> dict:` + - L118: `def _quaternion_to_rotation_matrix(q: torch.Tensor) -> torch.Tensor:` + - L146: `def forward(` + +## `models/vjepa/layers.py` +- Type: text +- Total lines: 157 +- Address anchors: + - L8: `class LieGroupEquivariantLayer(nn.Module):` + - L14: `def __init__(self, dim: int, rank: int = 8):` + - L31: `def forward(self, x: torch.Tensor, group_element: torch.Tensor) -> torch.Tensor:` + - L50: `class LatentRayMarcher(nn.Module):` + - L56: `def __init__(self, dim: int, num_samples: int = 16):` + - L71: `def forward(self, latents: torch.Tensor, ray_dirs: torch.Tensor) -> torch.Tensor:` + - L113: `def apply_rotary_pos_emb_3d(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):` + - L114: `def rotate_half(x):` + - L128: `class RotaryEmbedding3D(nn.Module):` + - L129: `def __init__(self, dim: int, max_t: int, max_h: int, max_w: int, base: float = 10000.0, device=None):` + - L141: `def _get_freqs(self, length: int, dim: int, device):` + - L147: `def _build_cache(self, device):` + - L152: `def forward(self, t: int, h: int, w: int) -> Tuple[torch.Tensor, torch.Tensor]:` + +## `models/vjepa/losses.py` +- Type: text +- Total lines: 31 +- Address anchors: + - L4: `def vicreg_loss(x, y, sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0):` + - L22: `def covariance_loss(z):` + +## `models/vjepa/memory.py` +- Type: text +- Total lines: 392 +- Address anchors: + - L28: `class ResonatorNetwork(nn.Module):` + - L49: `def __init__(` + - L66: `def set_cleanup_memory(self, memory: torch.Tensor) -> None:` + - L70: `def cleanup(self, x: torch.Tensor) -> torch.Tensor:` + - L96: `def resonator_step(` + - L119: `def _unbind(self, composite: torch.Tensor, key: torch.Tensor) -> torch.Tensor:` + - L125: `def forward(` + - L178: `class HolographicMemory(nn.Module):` + - L198: `def __init__(` + - L219: `def _bind_hrr(self, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:` + - L225: `def _unbind_hrr(self, composite: torch.Tensor, key: torch.Tensor) -> torch.Tensor:` + - L231: `def _bind_fhrr(self, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:` + - L243: `def _unbind_fhrr(self, composite: torch.Tensor, key: torch.Tensor) -> torch.Tensor:` + - L249: `def bind(self, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:` + - L255: `def unbind(self, composite: torch.Tensor, key: torch.Tensor) -> torch.Tensor:` + - L261: `def superpose(self, vectors: torch.Tensor, dim: int = 1) -> torch.Tensor:` + - L284: `def forward(` + - L312: `def retrieve(self, memory: torch.Tensor, key: torch.Tensor) -> torch.Tensor:` + - L327: `def retrieve_with_cleanup(` + - L348: `def multi_retrieve(` + - L385: `def set_cleanup_memory(self, memory: torch.Tensor) -> None:` + +## `models/vjepa/physics_engine.py` +- Type: text +- Total lines: 107 +- Address anchors: + - L6: `class HRMPhysicsODE(nn.Module):` + - L12: `def __init__(self, dim: int, action_dim: Optional[int] = 128):` + - L30: `def forward(self, t: float, z: torch.Tensor) -> torch.Tensor:` + - L61: `class ContinuousTimeHRM(nn.Module):` + - L67: `def __init__(self, dim: int, action_dim: int = 128):` + - L71: `def forward(self, z: torch.Tensor, delta_t: torch.Tensor | float = 1.0, action: Optional[torch.Tensor] = None):` + +## `models/vjepa/planning.py` +- Type: text +- Total lines: 467 +- Address anchors: + - L28: `class MCTSNode:` + - L49: `def __init__(` + - L67: `def mean_value(self) -> float:` + - L74: `def is_expanded(self) -> bool:` + - L79: `def effective_visits(self) -> int:` + - L83: `def puct_score(self, parent_visits: int, c_puct: float = 1.41) -> float:` + - L120: `class MCTS:` + - L139: `def __init__(` + - L160: `def _imagine_future(` + - L192: `def _select(self, node: MCTSNode) -> MCTSNode:` + - L222: `def _expand(` + - L289: `def _backpropagate(self, node: MCTSNode, value: float) -> None:` + - L315: `def _get_action_probabilities(self, root: MCTSNode, num_actions: int) -> torch.Tensor:` + - L349: `def plan(` + - L402: `class LatentPlannerMCTS:` + - L417: `def __init__(` + - L433: `def plan(` + - L451: `def plan_with_uncertainty(` + +## `models/vjepa/predictor.py` +- Type: text +- Total lines: 260 +- Address anchors: + - L22: `class VJEPAPredictorInner(nn.Module):` + - L28: `def __init__(self,` + - L156: `def forward(self,` + +## `models/vjepa/symplectic_integrator.py` +- Type: text +- Total lines: 137 +- Address anchors: + - L31: `class SymplecticEulerIntegrator(nn.Module):` + - L48: `def __init__(self, dim: int, action_dim: Optional[int] = None):` + - L67: `def set_action(self, action: Optional[torch.Tensor]) -> None:` + - L71: `def hamiltonian(self, q: torch.Tensor, p: torch.Tensor) -> torch.Tensor:` + - L87: `def forward(` + - L126: `def compute_energy(self, z: torch.Tensor) -> torch.Tensor:` + +## `models/vjepa/utils.py` +- Type: text +- Total lines: 44 +- Address anchors: + - L3: `def get_block_mask(t, h, w, mask_ratio=0.6):` + - L21: `def apply_mask(x, mask):` + +## `models/vjepa/vit.py` +- Type: text +- Total lines: 94 +- Address anchors: + - L9: `class PatchEmbed3D(nn.Module):` + - L15: `def __init__(self, patch_size=(2, 16, 16), in_chans=3, embed_dim=768):` + - L20: `def forward(self, x):` + - L30: `class VisionTransformerBlock(nn.Module):` + - L31: `def __init__(self, dim, num_heads, expansion, norm_eps=1e-5):` + - L43: `def _forward_inner(self, x, cos_sin):` + - L48: `def forward(self, x, cos_sin):` + - L54: `class VisionEncoder(nn.Module):` + - L55: `def __init__(self,` + - L86: `def forward(self, x):` + +## `models/vjepa/vjepa_model.py` +- Type: text +- Total lines: 141 +- Address anchors: + - L12: `class VJEPA(nn.Module):` + - L23: `def __init__(self,` + - L81: `def update_target_encoder(self):` + - L86: `def forward(self, batch: Dict[str, torch.Tensor]):` + - L139: `class VisualExecutionModel(VJEPA):` + +## `requirements.txt` +- Type: text +- Total lines: 16 +- Address anchors: none detected + +## `utils/functions.py` +- Type: text +- Total lines: 19 +- Address anchors: + - L5: `def load_model_class(identifier: str, prefix: str = "models."):` + - L15: `def get_model_source_path(identifier: str, prefix: str = "models."):` + +## `vjepa_train.py` +- Type: text +- Total lines: 221 +- Address anchors: + - L19: `def build_optimizer(model, config):` + - L93: `class CombinedOptimizer:` + - L98: `def __init__(self, optimizers):` + - L101: `def zero_grad(self, set_to_none=False):` + - L105: `def step(self, closure=None):` + - L110: `def param_groups(self):` + - L117: `def train(config_path="config/vjepa_micro.yaml"):` + From a3e1075ebd4baedea69e5dd3aafa1da7eb2176b2 Mon Sep 17 00:00:00 2001 From: jeevesh415 Date: Sun, 17 May 2026 16:32:42 +0530 Subject: [PATCH 6/7] Restore previously removed legacy modules and resync CODE_ADDRESS_INDEX --- CODE_ADDRESS_INDEX.md | 154 +++++++++++- arc_eval.ipynb | 252 +++++++++++++++++++ config/arch/hrm_v1.yaml | 21 ++ config/cfg_pretrain.yaml | 31 +++ dataset/build_arc_dataset.py | 291 ++++++++++++++++++++++ dataset/build_maze_dataset.py | 142 +++++++++++ evaluate.py | 68 +++++ models/hrm/hrm_act_v1.py | 283 +++++++++++++++++++++ models/sparse_embedding.py | 132 ++++++++++ pretrain.py | 453 ++++++++++++++++++++++++++++++++++ puzzle_dataset.py | 199 +++++++++++++++ puzzle_visualizer.html | 426 ++++++++++++++++++++++++++++++++ 12 files changed, 2451 insertions(+), 1 deletion(-) create mode 100644 arc_eval.ipynb create mode 100644 config/arch/hrm_v1.yaml create mode 100644 config/cfg_pretrain.yaml create mode 100644 dataset/build_arc_dataset.py create mode 100644 dataset/build_maze_dataset.py create mode 100644 evaluate.py create mode 100644 models/hrm/hrm_act_v1.py create mode 100644 models/sparse_embedding.py create mode 100644 pretrain.py create mode 100644 puzzle_dataset.py create mode 100644 puzzle_visualizer.html diff --git a/CODE_ADDRESS_INDEX.md b/CODE_ADDRESS_INDEX.md index 9e35997f..7568c669 100644 --- a/CODE_ADDRESS_INDEX.md +++ b/CODE_ADDRESS_INDEX.md @@ -2,7 +2,7 @@ Comprehensive repository address map. Updated to current line-level state. -Total files indexed: **48** +Total files indexed: **59** ## `.github/workflows/sync-from-upstream.yml` - Type: text @@ -62,6 +62,11 @@ Total files indexed: **48** - L96: `### Final Execution Checklist (Do This)` - L104: `## Roadmap Direction` +## `arc_eval.ipynb` +- Type: text +- Total lines: 252 +- Address anchors: none detected + ## `assets/hrm.png` - Type: binary/non-utf8 - Total lines: 0 @@ -79,6 +84,44 @@ Total files indexed: **48** - L8: `def load_config(path: str):` - L13: `def main():` +## `config/arch/hrm_v1.yaml` +- Type: text +- Total lines: 21 +- Address anchors: + - L1: `name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1` + - L2: `loss:` + - L6: `halt_exploration_prob: 0.1` + - L7: `halt_max_steps: 16` + - L9: `H_cycles: 2` + - L10: `L_cycles: 2` + - L12: `H_layers: 4` + - L13: `L_layers: 4` + - L15: `hidden_size: 512` + - L16: `num_heads: 8 # min(2, hidden_size // 64)` + - L17: `expansion: 4` + - L19: `puzzle_emb_ndim: ${.hidden_size}` + - L21: `pos_encodings: rope` + +## `config/cfg_pretrain.yaml` +- Type: text +- Total lines: 31 +- Address anchors: + - L3: `defaults:` + - L7: `hydra:` + - L11: `data_path: data/arc-aug-1000` + - L14: `global_batch_size: 768` + - L16: `epochs: 100000` + - L17: `eval_interval: 10000` + - L18: `checkpoint_every_eval: True` + - L20: `lr: 1e-4` + - L21: `lr_min_ratio: 1.0` + - L22: `lr_warmup_steps: 2000` + - L25: `beta1: 0.9` + - L26: `beta2: 0.95` + - L27: `weight_decay: 0.1` + - L28: `puzzle_emb_weight_decay: 0.1` + - L31: `puzzle_emb_lr: 1e-2` + ## `config/vjepa_10b.yaml` - Type: text - Total lines: 85 @@ -95,6 +138,31 @@ Total files indexed: **48** - L14: `predictor:` - L27: `training:` +## `dataset/build_arc_dataset.py` +- Type: text +- Total lines: 291 +- Address anchors: + - L19: `class DataProcessConfig(BaseModel):` + - L37: `class ARCPuzzle:` + - L43: `def arc_grid_to_np(grid: List[List[int]]):` + - L54: `def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):` + - L81: `def puzzle_hash(puzzle: dict):` + - L83: `def _grid_hash(grid: np.ndarray):` + - L98: `def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):` + - L122: `def _map_grid(grid: np.ndarray):` + - L148: `def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataProcessConfig):` + - L184: `def convert_dataset(config: DataProcessConfig):` + - L286: `def main(config: DataProcessConfig):` + +## `dataset/build_maze_dataset.py` +- Type: text +- Total lines: 142 +- Address anchors: + - L22: `class DataProcessConfig(BaseModel):` + - L30: `def convert_subset(set_name: str, config: DataProcessConfig):` + - L89: `def _seq_to_numpy(seq):` + - L136: `def preprocess_data(config: DataProcessConfig):` + ## `dataset/common.py` - Type: text - Total lines: 51 @@ -153,6 +221,13 @@ Total files indexed: **48** - L25: `## Gate C — Change Promotion` - L31: `## Notes` +## `evaluate.py` +- Type: text +- Total lines: 68 +- Address anchors: + - L13: `class EvalConfig(pydantic.BaseModel):` + - L19: `def launch():` + ## `evaluate_perception.py` - Type: text - Total lines: 90 @@ -189,6 +264,31 @@ Total files indexed: **48** - Address anchors: - L7: `def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):` +## `models/hrm/hrm_act_v1.py` +- Type: text +- Total lines: 283 +- Address anchors: + - L16: `class HierarchicalReasoningModel_ACTV1InnerCarry:` + - L22: `class HierarchicalReasoningModel_ACTV1Carry:` + - L31: `class HierarchicalReasoningModel_ACTV1Config(BaseModel):` + - L60: `class HierarchicalReasoningModel_ACTV1Block(nn.Module):` + - L61: `def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:` + - L77: `def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor:` + - L86: `class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module):` + - L87: `def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]):` + - L92: `def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor:` + - L102: `class HierarchicalReasoningModel_ACTV1_Inner(nn.Module):` + - L103: `def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:` + - L146: `def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):` + - L168: `def empty_carry(self, batch_size: int):` + - L174: `def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry):` + - L180: `def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:` + - L216: `class HierarchicalReasoningModel_ACTV1(nn.Module):` + - L219: `def __init__(self, config_dict: dict):` + - L225: `def puzzle_emb(self):` + - L228: `def initial_carry(self, batch: Dict[str, torch.Tensor]):` + - L240: `def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:` + ## `models/hybrid_ssm.py` - Type: text - Total lines: 228 @@ -289,6 +389,18 @@ Total files indexed: **48** - L243: `def _positional_encoding(self, positions: torch.Tensor) -> torch.Tensor:` - L259: `def forward(` +## `models/sparse_embedding.py` +- Type: text +- Total lines: 132 +- Address anchors: + - L11: `class CastedSparseEmbedding(nn.Module):` + - L12: `def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype):` + - L28: `def forward(self, inputs: torch.Tensor) -> torch.Tensor:` + - L41: `class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):` + - L42: `def __init__(` + - L63: `def step(self, closure=None): # type: ignore` + - L98: `def _sparse_emb_signsgd_dist(` + ## `models/spectral_conv.py` - Type: text - Total lines: 193 @@ -499,6 +611,46 @@ Total files indexed: **48** - L86: `def forward(self, batch: Dict[str, torch.Tensor]):` - L139: `class VisualExecutionModel(VJEPA):` +## `pretrain.py` +- Type: text +- Total lines: 453 +- Address anchors: + - L26: `class LossConfig(pydantic.BaseModel):` + - L32: `class ArchConfig(pydantic.BaseModel):` + - L39: `class PretrainConfig(pydantic.BaseModel):` + - L74: `class TrainState:` + - L84: `def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs):` + - L108: `def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):` + - L162: `def cosine_schedule_with_warmup_lr_lambda(` + - L172: `def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int):` + - L190: `def save_train_state(config: PretrainConfig, train_state: TrainState):` + - L199: `def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState):` + - L209: `def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int):` + - L266: `def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int):` + - L333: `def save_code_and_config(config: PretrainConfig):` + - L359: `def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig:` + - L381: `def launch(hydra_config: DictConfig):` + +## `puzzle_dataset.py` +- Type: text +- Total lines: 199 +- Address anchors: + - L14: `def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):` + - L41: `class PuzzleDatasetConfig(pydantic.BaseModel):` + - L53: `class PuzzleDataset(IterableDataset):` + - L54: `def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):` + - L68: `def _load_metadata(self) -> PuzzleDatasetMetadata:` + - L72: `def _lazy_load_dataset(self):` + - L95: `def _collate_batch(self, batch):` + - L118: `def _iter_test(self):` + - L151: `def _iter_train(self):` + - L189: `def __iter__(self):` + +## `puzzle_visualizer.html` +- Type: text +- Total lines: 426 +- Address anchors: none detected + ## `requirements.txt` - Type: text - Total lines: 16 diff --git a/arc_eval.ipynb b/arc_eval.ipynb new file mode 100644 index 00000000..b2786b8a --- /dev/null +++ b/arc_eval.ipynb @@ -0,0 +1,252 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "from glob import glob\n", + "import hashlib\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.colors as mcolors\n", + "\n", + "import torch\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "from numba import njit\n", + "\n", + "from dataset.common import inverse_dihedral_transform\n", + "\n", + "\n", + "DATASET_PATH = \"data/arc-aug-1000\" # ARC-1\n", + "# DATASET_PATH = \"data/arc-2-aug-1000\" # ARC-2\n", + "\n", + "CHECKPOINT_PATH = \"checkpoints/Arc-aug-1000 ACT-torch/HierarchicalReasoningModel_ACTV1 amphibian-turaco/step_414456\"\n", + "\n", + "\n", + "PAD_PUZZLE_IDENTIFIER = 0\n", + "\n", + "# Visualization\n", + "ARC_COLOR_MAP = mcolors.ListedColormap([\n", + " \"#000000\", # symbol_0: black\n", + " \"#0074D9\", # symbol_1: blue\n", + " \"#FF4136\", # symbol_2: red\n", + " \"#2ECC40\", # symbol_3: green\n", + " \"#FFDC00\", # symbol_4: yellow\n", + " \"#AAAAAA\", # symbol_5: grey\n", + " \"#F012BE\", # symbol_6: fuschia\n", + " \"#FF851B\", # symbol_7: orange\n", + " \"#7FDBFF\", # symbol_8: teal\n", + " \"#870C25\" # symbol_9: brown\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def load_identifiers_and_preds(dataset_path: str, checkpoint_path: str):\n", + " # Load puzzle identifiers\n", + " with open(os.path.join(dataset_path, \"identifiers.json\"), \"r\") as f:\n", + " identifier_map = json.load(f)\n", + " \n", + " # Load preds\n", + " all_preds = {}\n", + " for filename in glob(f\"{checkpoint_path}_all_preds.*\"):\n", + " preds = torch.load(filename)\n", + " for k, v in preds.items():\n", + " all_preds.setdefault(k, [])\n", + " all_preds[k].append(v)\n", + " \n", + " del preds\n", + "\n", + " all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()}\n", + " \n", + " # Remove paddings\n", + " mask = all_preds[\"puzzle_identifiers\"] != PAD_PUZZLE_IDENTIFIER\n", + " all_preds = {k: v[mask] for k, v in all_preds.items()}\n", + "\n", + " return identifier_map, all_preds\n", + "\n", + "\n", + "def inverse_aug(name: str, grid: np.ndarray):\n", + " if \"_\" not in name:\n", + " return grid\n", + "\n", + " trans_id, perm = name.split(\"_\")[-2:]\n", + " trans_id = int(trans_id[1:]) # Remove \"t\" letter\n", + " inv_perm = np.argsort(list(perm))\n", + " \n", + " return inv_perm[inverse_dihedral_transform(grid, trans_id)]\n", + "\n", + "\n", + "def grid_hash(grid: np.ndarray):\n", + " return hash((grid.tobytes(), grid.shape))\n", + "\n", + "\n", + "@njit\n", + "def crop(grid: np.ndarray):\n", + " # Find maximum-sized rectangle without any EOS token inside.\n", + " grid = grid.reshape(30, 30)\n", + "\n", + " max_area = 0\n", + " max_size = (0, 0)\n", + " nr, nc = grid.shape\n", + " \n", + " num_c = nc\n", + " for num_r in range(1, nr + 1):\n", + " # Scan for maximum c\n", + " for c in range(1, num_c + 1):\n", + " x = grid[num_r - 1, c - 1]\n", + " if (x < 2) | (x > 11):\n", + " num_c = c - 1\n", + " break\n", + " \n", + " area = num_r * num_c\n", + " if area > max_area:\n", + " max_area = area\n", + " max_size = (num_r, num_c)\n", + "\n", + " return grid[:max_size[0], :max_size[1]] - 2\n", + "\n", + "\n", + "def test(visualize, Ks=[1, 2, 10, 100, 1000]):\n", + " identifier_map, all_preds = load_identifiers_and_preds(DATASET_PATH, CHECKPOINT_PATH)\n", + " \n", + " global_hmap = {}\n", + " \n", + " # Get puzzles and corresponding answers\n", + " puzzle_labels = {}\n", + " for identifier, input, label in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], all_preds[\"labels\"]):\n", + " name = identifier_map[identifier]\n", + " if \"_\" not in name: # Not-augmented\n", + " puzzle_labels.setdefault(name, {})\n", + " \n", + " input = crop(input.numpy())\n", + " label = crop(label.numpy())\n", + "\n", + " input_hash = grid_hash(input)\n", + " label_hash = grid_hash(label)\n", + "\n", + " global_hmap[input_hash] = input\n", + " global_hmap[label_hash] = label\n", + "\n", + " assert input_hash not in puzzle_labels[name]\n", + " puzzle_labels[name][input_hash] = label_hash\n", + " \n", + " print (\"Number of puzzles\", len(puzzle_labels))\n", + " \n", + " # Argmax prediction\n", + " preds = all_preds[\"logits\"].argmax(-1)\n", + "\n", + " # Collate\n", + " pred_answers = {}\n", + " for identifier, input, pred, q in zip(all_preds[\"puzzle_identifiers\"], all_preds[\"inputs\"], preds, all_preds[\"q_halt_logits\"].sigmoid()):\n", + " name = identifier_map[identifier]\n", + " orig_name = name.split(\"_\")[0]\n", + " \n", + " input = input.numpy()\n", + " input_hash = grid_hash(inverse_aug(name, crop(input)))\n", + " assert input_hash in puzzle_labels[orig_name]\n", + " \n", + " pred = inverse_aug(name, crop(pred.numpy()))\n", + " pred_hash = grid_hash(pred)\n", + " global_hmap[pred_hash] = pred\n", + " \n", + " pred_answers.setdefault(orig_name, {})\n", + " pred_answers[orig_name].setdefault(input_hash, [])\n", + " pred_answers[orig_name][input_hash].append((pred_hash, q.item()))\n", + "\n", + " # test-1\n", + " if visualize:\n", + " num_figs = sum(len(tests) for name, tests in puzzle_labels.items())\n", + " fig, axes = plt.subplots(num_figs, 4, figsize=(8, num_figs * 4))\n", + " \n", + " fig_id = 0\n", + " \n", + " correct = [0 for _ in range(len(Ks))]\n", + " for name, tests in puzzle_labels.items():\n", + " num_test_correct = [0 for _ in range(len(Ks))]\n", + " for input_hash, label_hash in tests.items():\n", + " p = pred_answers[name][input_hash]\n", + " p_map = {}\n", + " \n", + " for h, q in p:\n", + " p_map.setdefault(h, [0, 0])\n", + " p_map[h][0] += 1\n", + " p_map[h][1] += q\n", + " \n", + " for h, stats in p_map.items():\n", + " stats[1] /= stats[0]\n", + " \n", + " p_map = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)\n", + "\n", + " # 2-vote\n", + " for i, k in enumerate(Ks):\n", + " ok = False\n", + " for h, stats in p_map[:k]:\n", + " ok |= h == label_hash\n", + " \n", + " num_test_correct[i] += ok\n", + "\n", + " if visualize:\n", + " # Show input and ground truth\n", + " axes[fig_id, 0].imshow(global_hmap[input_hash], cmap=ARC_COLOR_MAP)\n", + " axes[fig_id, 0].set_title(f\"{name}\\nInput\")\n", + " axes[fig_id, 0].axis('off')\n", + " \n", + " axes[fig_id, 1].imshow(global_hmap[label_hash], cmap=ARC_COLOR_MAP)\n", + " axes[fig_id, 1].set_title(f\"{name}\\nAnswer\")\n", + " axes[fig_id, 1].axis('off')\n", + " \n", + " trial_id = 2\n", + " for h, stats in p_map[:2]:\n", + " ans = global_hmap[h]\n", + " \n", + " axes[fig_id, trial_id].imshow(ans, cmap=ARC_COLOR_MAP)\n", + " axes[fig_id, trial_id].set_title(f\"{name}\\nTrial {trial_id}\")\n", + " axes[fig_id, trial_id].axis('off')\n", + " \n", + " trial_id += 1\n", + " \n", + " fig_id += 1\n", + " \n", + " # Total correctness\n", + " for i in range(len(Ks)):\n", + " correct[i] += num_test_correct[i] == len(tests)\n", + "\n", + " for i, k in enumerate(Ks):\n", + " print (f\"{k}-shot: {correct[i] / len(puzzle_labels) * 100:.2f}%\")\n", + "\n", + "\n", + "test(visualize=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/config/arch/hrm_v1.yaml b/config/arch/hrm_v1.yaml new file mode 100644 index 00000000..a5646b89 --- /dev/null +++ b/config/arch/hrm_v1.yaml @@ -0,0 +1,21 @@ +name: hrm.hrm_act_v1@HierarchicalReasoningModel_ACTV1 +loss: + name: losses@ACTLossHead + loss_type: stablemax_cross_entropy + +halt_exploration_prob: 0.1 +halt_max_steps: 16 + +H_cycles: 2 +L_cycles: 2 + +H_layers: 4 +L_layers: 4 + +hidden_size: 512 +num_heads: 8 # min(2, hidden_size // 64) +expansion: 4 + +puzzle_emb_ndim: ${.hidden_size} + +pos_encodings: rope diff --git a/config/cfg_pretrain.yaml b/config/cfg_pretrain.yaml new file mode 100644 index 00000000..51c55a07 --- /dev/null +++ b/config/cfg_pretrain.yaml @@ -0,0 +1,31 @@ +# ARC training config + +defaults: + - arch: hrm_v1 + - _self_ + +hydra: + output_subdir: null + +# Data path +data_path: data/arc-aug-1000 + +# Hyperparams - Training +global_batch_size: 768 + +epochs: 100000 +eval_interval: 10000 +checkpoint_every_eval: True + +lr: 1e-4 +lr_min_ratio: 1.0 +lr_warmup_steps: 2000 + +# Standard hyperparameter settings for LM, as used in Llama +beta1: 0.9 +beta2: 0.95 +weight_decay: 0.1 +puzzle_emb_weight_decay: 0.1 + +# Hyperparams - Puzzle embeddings training +puzzle_emb_lr: 1e-2 diff --git a/dataset/build_arc_dataset.py b/dataset/build_arc_dataset.py new file mode 100644 index 00000000..2da5703e --- /dev/null +++ b/dataset/build_arc_dataset.py @@ -0,0 +1,291 @@ +from typing import List, Optional, Tuple, Dict +from dataclasses import dataclass +from pathlib import Path +import os +import json +import hashlib +import numpy as np +from glob import glob + +from argdantic import ArgParser +from pydantic import BaseModel + +from common import PuzzleDatasetMetadata, dihedral_transform + + +cli = ArgParser() + + +class DataProcessConfig(BaseModel): + # ARC-1 + dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI/data", "dataset/raw-data/ConceptARC/corpus"] + output_dir: str = "data/arc-aug-1000" + + # ARC-2 + # dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI-2/data"] + # output_dir: str = "data/arc-2-aug-1000" + + seed: int = 42 + num_aug: int = 1000 + + +ARCMaxGridSize = 30 +ARCAugmentRetriesFactor = 5 + + +@dataclass +class ARCPuzzle: + id: str + + examples: List[Tuple[np.ndarray, np.ndarray]] + + +def arc_grid_to_np(grid: List[List[int]]): + arr = np.array(grid) + + # Shape check + assert arr.ndim == 2 + assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize + # Element check + assert np.all((arr >= 0) & (arr <= 9)) + return arr.astype(np.uint8) + + +def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool): + # PAD: 0, : 1, digits: 2 ... 11 + # Compute random top-left pad + if do_translation: + pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1) + pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1) + else: + pad_r = pad_c = 0 + + # Pad grid + result = [] + for grid in [inp, out]: + nrow, ncol = grid.shape + grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0) + + # Add + eos_row, eos_col = pad_r + nrow, pad_c + ncol + if eos_row < ARCMaxGridSize: + grid[eos_row, pad_c:eos_col] = 1 + if eos_col < ARCMaxGridSize: + grid[pad_r:eos_row, eos_col] = 1 + + result.append(grid.flatten()) + + return result + + +def puzzle_hash(puzzle: dict): + # Hash the puzzle for checking equivalence + def _grid_hash(grid: np.ndarray): + buffer = [x.to_bytes(1) for x in grid.shape] + buffer.append(grid.tobytes()) + + return hashlib.sha256(b"".join(buffer)).hexdigest() + + hashes = [] + for example_type, example in puzzle.items(): + for input, label in example.examples: + hashes.append(f"{_grid_hash(input)}|{_grid_hash(label)}") + + hashes.sort() + return hashlib.sha256("|".join(hashes).encode()).hexdigest() + + +def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]): + # Remove "name" + name = puzzle.pop("name", default_name) + + # Convert + dests = set(dest_mapping.values()) + converted = {dest: ARCPuzzle(name, []) for dest in dests} + for example_type, examples in puzzle.items(): + dest = dest_mapping[example_type] + converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples]) + + group = [converted] + + # Augment + if aug_count > 0: + hashes = {puzzle_hash(converted)} + + for _trial in range(ARCAugmentRetriesFactor * aug_count): + # Augment plan + trans_id = np.random.randint(0, 8) + mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))]) # Permute colors, Excluding "0" (black) + + aug_repr = f"t{trans_id}_{''.join(str(x) for x in mapping)}" + + def _map_grid(grid: np.ndarray): + return dihedral_transform(mapping[grid], trans_id) + + # Check duplicate + augmented = {dest: ARCPuzzle(f"{puzzle.id}_{aug_repr}", [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()} + h = puzzle_hash(augmented) + if h not in hashes: + hashes.add(h) + group.append(augmented) + + if len(group) >= aug_count + 1: + break + + if len(group) < aug_count + 1: + print (f"[Puzzle {name}] augmentation not full, only {len(group)}") + + # Append + for dest in dests: + # Convert the examples + dest_split, dest_set = dest + + results.setdefault(dest_split, {}) + results[dest_split].setdefault(dest_set, []) + results[dest_split][dest_set].append([converted[dest] for converted in group]) + + +def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataProcessConfig): + train_examples_dest = ("train", "all") + test_examples_map = { + "evaluation": [(1.0, ("test", "all"))], + "_default": [(1.0, ("train", "all"))] + } + + total_puzzles = 0 + for subdir in os.scandir(dataset_path): + if subdir.is_dir(): + # Load all puzzles in this directory + puzzles = [] + for filename in glob(os.path.join(subdir.path, "*.json")): + with open(filename, "r") as f: + puzzles.append((Path(filename).stem, json.load(f))) + + # Shuffle puzzles + np.random.shuffle(puzzles) + + # Assign by fraction + for idx, (default_name, puzzle) in enumerate(puzzles): + fraction = idx / len(puzzles) + test_examples_dest = None + for f, dest in test_examples_map.get(subdir.name, test_examples_map["_default"]): + if fraction < f: + test_examples_dest = dest + break + + assert test_examples_dest is not None + + convert_single_arc_puzzle(results, default_name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest}) + total_puzzles += 1 + + print (f"[{dataset_path}] total puzzles: {total_puzzles}") + + +def convert_dataset(config: DataProcessConfig): + np.random.seed(config.seed) + + # Read dataset + data = {} + for dataset_dir in config.dataset_dirs: + load_puzzles_arcagi(data, dataset_dir, config) + + # Map global puzzle identifiers + num_identifiers = 1 # 0 is blank + identifier_map = {} + for split_name, split in data.items(): + for subset_name, subset in split.items(): + for group in subset: + for puzzle in group: + if puzzle.id not in identifier_map: + identifier_map[puzzle.id] = num_identifiers + num_identifiers += 1 + + print (f"Total puzzle IDs (including ): {num_identifiers}") + + # Save + for split_name, split in data.items(): + os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True) + + # Translational augmentations + enable_translational_augment = split_name == "train" + + # Statistics + total_examples = 0 + total_puzzles = 0 + total_groups = 0 + + for subset_name, subset in split.items(): + # Construct subset + results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]} + results["puzzle_indices"].append(0) + results["group_indices"].append(0) + + example_id = 0 + puzzle_id = 0 + + for group in subset: + for puzzle in group: + # Push puzzle + no_aug_id = np.random.randint(0, len(puzzle.examples)) + for _idx_ex, (inp, out) in enumerate(puzzle.examples): + inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id) + + results["inputs"].append(inp) + results["labels"].append(out) + example_id += 1 + + total_examples += 1 + + results["puzzle_indices"].append(example_id) + results["puzzle_identifiers"].append(identifier_map[puzzle.id]) + + puzzle_id += 1 + + total_puzzles += 1 + + # Push group + results["group_indices"].append(puzzle_id) + total_groups += 1 + + for k, v in results.items(): + if k in {"inputs", "labels"}: + v = np.stack(v, 0) + else: + v = np.array(v, dtype=np.int32) + + np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v) + + # Metadata + metadata = PuzzleDatasetMetadata( + seq_len=ARCMaxGridSize * ARCMaxGridSize, + vocab_size=10 + 2, # PAD + EOS + "0" ... "9" + + pad_id=0, + ignore_label_id=0, + + blank_identifier_id=0, + num_puzzle_identifiers=num_identifiers, + + total_groups=total_groups, + mean_puzzle_examples=total_examples / total_puzzles, + sets=list(split.keys()) + ) + + # Save metadata as JSON. + with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f: + json.dump(metadata.model_dump(), f) + + # Save IDs mapping + with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f: + ids_mapping = {v: k for k, v in identifier_map.items()} + + json.dump([ids_mapping.get(i, "") for i in range(num_identifiers)], f) + + +@cli.command(singleton=True) +def main(config: DataProcessConfig): + convert_dataset(config) + + +if __name__ == "__main__": + cli() diff --git a/dataset/build_maze_dataset.py b/dataset/build_maze_dataset.py new file mode 100644 index 00000000..a9367f38 --- /dev/null +++ b/dataset/build_maze_dataset.py @@ -0,0 +1,142 @@ +from typing import Optional +import math +import os +import csv +import json +import numpy as np + +from argdantic import ArgParser +from pydantic import BaseModel +from tqdm import tqdm +from huggingface_hub import hf_hub_download + +from common import PuzzleDatasetMetadata, dihedral_transform + + +CHARSET = "# SGo" + + +cli = ArgParser() + + +class DataProcessConfig(BaseModel): + source_repo: str = "sapientinc/maze-30x30-hard-1k" + output_dir: str = "data/maze-30x30-hard-1k" + + subsample_size: Optional[int] = None + aug: bool = False + + +def convert_subset(set_name: str, config: DataProcessConfig): + # Read CSV + all_chars = set() + grid_size = None + inputs = [] + labels = [] + + with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile: # type: ignore + reader = csv.reader(csvfile) + next(reader) # Skip header + for source, q, a, rating in reader: + all_chars.update(q) + all_chars.update(a) + + if grid_size is None: + n = int(len(q) ** 0.5) + grid_size = (n, n) + + inputs.append(np.frombuffer(q.encode(), dtype=np.uint8).reshape(grid_size)) + labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(grid_size)) + + # If subsample_size is specified for the training set, + # randomly sample the desired number of examples. + if set_name == "train" and config.subsample_size is not None: + total_samples = len(inputs) + if config.subsample_size < total_samples: + indices = np.random.choice(total_samples, size=config.subsample_size, replace=False) + inputs = [inputs[i] for i in indices] + labels = [labels[i] for i in indices] + + # Generate dataset + results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]} + puzzle_id = 0 + example_id = 0 + + results["puzzle_indices"].append(0) + results["group_indices"].append(0) + + for inp, out in zip(tqdm(inputs), labels): + # Dihedral transformations for augmentation + for aug_idx in range(8 if (set_name == "train" and config.aug) else 1): + results["inputs"].append(dihedral_transform(inp, aug_idx)) + results["labels"].append(dihedral_transform(out, aug_idx)) + example_id += 1 + puzzle_id += 1 + + results["puzzle_indices"].append(example_id) + results["puzzle_identifiers"].append(0) + + # Push group + results["group_indices"].append(puzzle_id) + + # Char mappings + assert len(all_chars - set(CHARSET)) == 0 + + char2id = np.zeros(256, np.uint8) + char2id[np.array(list(map(ord, CHARSET)))] = np.arange(len(CHARSET)) + 1 + + # To Numpy + def _seq_to_numpy(seq): + arr = np.vstack([char2id[s.reshape(-1)] for s in seq]) + + return arr + + results = { + "inputs": _seq_to_numpy(results["inputs"]), + "labels": _seq_to_numpy(results["labels"]), + + "group_indices": np.array(results["group_indices"], dtype=np.int32), + "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32), + "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32), + } + + # Metadata + metadata = PuzzleDatasetMetadata( + seq_len=int(math.prod(grid_size)), # type: ignore + vocab_size=len(CHARSET) + 1, # PAD + Charset + + pad_id=0, + ignore_label_id=0, + + blank_identifier_id=0, + num_puzzle_identifiers=1, + + total_groups=len(results["group_indices"]) - 1, + mean_puzzle_examples=1, + sets=["all"] + ) + + # Save metadata as JSON. + save_dir = os.path.join(config.output_dir, set_name) + os.makedirs(save_dir, exist_ok=True) + + with open(os.path.join(save_dir, "dataset.json"), "w") as f: + json.dump(metadata.model_dump(), f) + + # Save data + for k, v in results.items(): + np.save(os.path.join(save_dir, f"all__{k}.npy"), v) + + # Save IDs mapping (for visualization only) + with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f: + json.dump([""], f) + + +@cli.command(singleton=True) +def preprocess_data(config: DataProcessConfig): + convert_subset("train", config) + convert_subset("test", config) + + +if __name__ == "__main__": + cli() diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 00000000..71ee7530 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,68 @@ +from typing import List +import yaml +import os + +import torch +import torch.distributed as dist + +import pydantic +from omegaconf import OmegaConf +from pretrain import PretrainConfig, init_train_state, evaluate, create_dataloader + + +class EvalConfig(pydantic.BaseModel): + checkpoint: str + + save_outputs: List[str] = ["inputs", "labels", "puzzle_identifiers", "logits", "q_halt_logits", "q_continue_logits"] + + +def launch(): + eval_cfg = EvalConfig(**OmegaConf.to_container(OmegaConf.from_cli())) # type: ignore + + RANK = 0 + WORLD_SIZE = 1 + # Initialize distributed training if in distributed environment (e.g. torchrun) + if "LOCAL_RANK" in os.environ: + # Initialize distributed, default device and dtype + dist.init_process_group(backend="nccl") + + RANK = dist.get_rank() + WORLD_SIZE = dist.get_world_size() + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + with open(os.path.join(os.path.dirname(eval_cfg.checkpoint), "all_config.yaml"), "r") as f: + config = PretrainConfig(**yaml.safe_load(f)) + + config.eval_save_outputs = eval_cfg.save_outputs + config.checkpoint_path = os.path.dirname(eval_cfg.checkpoint) + + # Dataloader + train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) + eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) + + # Models + train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE) + # Try unwrap torch.compile + try: + train_state.model.load_state_dict(torch.load(eval_cfg.checkpoint, map_location="cuda"), assign=True) + except: + train_state.model.load_state_dict({k.removeprefix("_orig_mod."): v for k, v in torch.load(eval_cfg.checkpoint, map_location="cuda").items()}, assign=True) + + train_state.step = 0 + ckpt_filename = os.path.basename(eval_cfg.checkpoint) + if ckpt_filename.startswith("step_"): + train_state.step = int(ckpt_filename.removeprefix("step_")) + + # Evaluate + print ("Starting evaluation") + + train_state.model.eval() + metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE) + + if metrics is not None: + print (metrics) + + +if __name__ == "__main__": + launch() diff --git a/models/hrm/hrm_act_v1.py b/models/hrm/hrm_act_v1.py new file mode 100644 index 00000000..e91c7d1a --- /dev/null +++ b/models/hrm/hrm_act_v1.py @@ -0,0 +1,283 @@ +from typing import Tuple, List, Dict, Optional +from dataclasses import dataclass +import math + +import torch +import torch.nn.functional as F +from torch import nn +from pydantic import BaseModel + +from models.common import trunc_normal_init_ +from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CosSin, CastedEmbedding, CastedLinear +from models.sparse_embedding import CastedSparseEmbedding + + +@dataclass +class HierarchicalReasoningModel_ACTV1InnerCarry: + z_H: torch.Tensor + z_L: torch.Tensor + + +@dataclass +class HierarchicalReasoningModel_ACTV1Carry: + inner_carry: HierarchicalReasoningModel_ACTV1InnerCarry + + steps: torch.Tensor + halted: torch.Tensor + + current_data: Dict[str, torch.Tensor] + + +class HierarchicalReasoningModel_ACTV1Config(BaseModel): + batch_size: int + seq_len: int + puzzle_emb_ndim: int = 0 + num_puzzle_identifiers: int + vocab_size: int + + H_cycles: int + L_cycles: int + + H_layers: int + L_layers: int + + # Transformer config + hidden_size: int + expansion: float + num_heads: int + pos_encodings: str + + rms_norm_eps: float = 1e-5 + rope_theta: float = 10000.0 + + # Halting Q-learning config + halt_max_steps: int + halt_exploration_prob: float + + forward_dtype: str = "bfloat16" + + +class HierarchicalReasoningModel_ACTV1Block(nn.Module): + def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None: + super().__init__() + + self.self_attn = Attention( + hidden_size=config.hidden_size, + head_dim=config.hidden_size // config.num_heads, + num_heads=config.num_heads, + num_key_value_heads=config.num_heads, + causal=False + ) + self.mlp = SwiGLU( + hidden_size=config.hidden_size, + expansion=config.expansion, + ) + self.norm_eps = config.rms_norm_eps + + def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: + # Post Norm + # Self Attention + hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps) + # Fully Connected + hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps) + return hidden_states + + +class HierarchicalReasoningModel_ACTV1ReasoningModule(nn.Module): + def __init__(self, layers: List[HierarchicalReasoningModel_ACTV1Block]): + super().__init__() + + self.layers = torch.nn.ModuleList(layers) + + def forward(self, hidden_states: torch.Tensor, input_injection: torch.Tensor, **kwargs) -> torch.Tensor: + # Input injection (add) + hidden_states = hidden_states + input_injection + # Layers + for layer in self.layers: + hidden_states = layer(hidden_states=hidden_states, **kwargs) + + return hidden_states + + +class HierarchicalReasoningModel_ACTV1_Inner(nn.Module): + def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None: + super().__init__() + self.config = config + self.forward_dtype = getattr(torch, self.config.forward_dtype) + + # I/O + self.embed_scale = math.sqrt(self.config.hidden_size) + embed_init_std = 1.0 / self.embed_scale + + self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) + self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False) + self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True) + + self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div + if self.config.puzzle_emb_ndim > 0: + # Zero init puzzle embeddings + self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim, + batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype) + + # LM Blocks + if self.config.pos_encodings == "rope": + self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads, + max_position_embeddings=self.config.seq_len + self.puzzle_emb_len, + base=self.config.rope_theta) + elif self.config.pos_encodings == "learned": + self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype) + else: + raise NotImplementedError() + + # Reasoning Layers + self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)]) + self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)]) + + # Initial states + self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + + # Q head special init + # Init Q to (almost) zero for faster learning during bootstrapping + with torch.no_grad(): + self.q_head.weight.zero_() + self.q_head.bias.fill_(-5) # type: ignore + + def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor): + # Token embedding + embedding = self.embed_tokens(input.to(torch.int32)) + + # Puzzle embeddings + if self.config.puzzle_emb_ndim > 0: + puzzle_embedding = self.puzzle_emb(puzzle_identifiers) + + pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1] + if pad_count > 0: + puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count)) + + embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2) + + # Position embeddings + if self.config.pos_encodings == "learned": + # scale by 1/sqrt(2) to maintain forward variance + embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype)) + + # Scale + return self.embed_scale * embedding + + def empty_carry(self, batch_size: int): + return HierarchicalReasoningModel_ACTV1InnerCarry( + z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + ) + + def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry): + return HierarchicalReasoningModel_ACTV1InnerCarry( + z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H), + z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L), + ) + + def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + seq_info = dict( + cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None, + ) + + # Input encoding + input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) + + # Forward iterations + with torch.no_grad(): + z_H, z_L = carry.z_H, carry.z_L + + for _H_step in range(self.config.H_cycles): + for _L_step in range(self.config.L_cycles): + if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)): + z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) + + if not (_H_step == self.config.H_cycles - 1): + z_H = self.H_level(z_H, z_L, **seq_info) + + assert not z_H.requires_grad and not z_L.requires_grad + + # 1-step grad + z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) + z_H = self.H_level(z_H, z_L, **seq_info) + + # LM Outputs + new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad + output = self.lm_head(z_H)[:, self.puzzle_emb_len:] + + # Q head + q_logits = self.q_head(z_H[:, 0]).to(torch.float32) + + return new_carry, output, (q_logits[..., 0], q_logits[..., 1]) + + +class HierarchicalReasoningModel_ACTV1(nn.Module): + """ACT wrapper.""" + + def __init__(self, config_dict: dict): + super().__init__() + self.config = HierarchicalReasoningModel_ACTV1Config(**config_dict) + self.inner = HierarchicalReasoningModel_ACTV1_Inner(self.config) + + @property + def puzzle_emb(self): + return self.inner.puzzle_emb + + def initial_carry(self, batch: Dict[str, torch.Tensor]): + batch_size = batch["inputs"].shape[0] + + return HierarchicalReasoningModel_ACTV1Carry( + inner_carry=self.inner.empty_carry(batch_size), # Empty is expected, it will be reseted in first pass as all sequences are halted. + + steps=torch.zeros((batch_size, ), dtype=torch.int32), + halted=torch.ones((batch_size, ), dtype=torch.bool), # Default to halted + + current_data={k: torch.empty_like(v) for k, v in batch.items()} + ) + + def forward(self, carry: HierarchicalReasoningModel_ACTV1Carry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]: + # Update data, carry (removing halted sequences) + new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry) + + new_steps = torch.where(carry.halted, 0, carry.steps) + + new_current_data = {k: torch.where(carry.halted.view((-1, ) + (1, ) * (batch[k].ndim - 1)), batch[k], v) for k, v in carry.current_data.items()} + + # Forward inner model + new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data) + + outputs = { + "logits": logits, + "q_halt_logits": q_halt_logits, + "q_continue_logits": q_continue_logits + } + + with torch.no_grad(): + # Step + new_steps = new_steps + 1 + is_last_step = new_steps >= self.config.halt_max_steps + + halted = is_last_step + + # if training, and ACT is enabled + if self.training and (self.config.halt_max_steps > 1): + # Halt signal + # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes + halted = halted | (q_halt_logits > q_continue_logits) + + # Exploration + min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1) + + halted = halted & (new_steps >= min_halt_steps) + + # Compute target Q + # NOTE: No replay buffer and target networks for computing target Q-value. + # As batch_size is large, there're many parallel envs. + # Similar concept as PQN https://arxiv.org/abs/2407.04811 + next_q_halt_logits, next_q_continue_logits = self.inner(new_inner_carry, new_current_data)[-1] + + outputs["target_q_continue"] = torch.sigmoid(torch.where(is_last_step, next_q_halt_logits, torch.maximum(next_q_halt_logits, next_q_continue_logits))) + + return HierarchicalReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs diff --git a/models/sparse_embedding.py b/models/sparse_embedding.py new file mode 100644 index 00000000..c701524b --- /dev/null +++ b/models/sparse_embedding.py @@ -0,0 +1,132 @@ +from typing import Union + +import torch +from torch import nn +import torch.distributed as dist +from torch.optim.optimizer import Optimizer, ParamsT + +from models.common import trunc_normal_init_ + + +class CastedSparseEmbedding(nn.Module): + def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, init_std: float, cast_to: torch.dtype): + super().__init__() + self.cast_to = cast_to + + # Real Weights + # Truncated LeCun normal init + self.weights = nn.Buffer( + trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True + ) + + # Local weights and IDs + # Local embeddings, with gradient, not persistent + self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False) + # Local embedding IDs, not persistent + self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + if not self.training: + # Test mode, no gradient + return self.weights[inputs].to(self.cast_to) + + # Training mode, fill puzzle embedding from weights + with torch.no_grad(): + self.local_weights.copy_(self.weights[inputs]) + self.local_ids.copy_(inputs) + + return self.local_weights.to(self.cast_to) + + +class CastedSparseEmbeddingSignSGD_Distributed(Optimizer): + def __init__( + self, + params: ParamsT, + + world_size: int, + lr: Union[float, torch.Tensor] = 1e-3, + weight_decay: float = 1e-2, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + defaults = dict( + lr=lr, + weight_decay=weight_decay, + world_size=world_size + ) + super().__init__(params, defaults) + + @torch.no_grad + def step(self, closure=None): # type: ignore + for group in self.param_groups: + # Find the sparse embedding weights + local_weights_grad = None + local_ids = None + weights = None + + assert len(group["params"]) == 3 + for p in group["params"]: + if p.requires_grad: + local_weights_grad = p.grad + elif p.ndim == 1: + local_ids = p + elif p.ndim == 2: + weights = p + else: + assert False + + assert local_weights_grad is not None + assert local_ids is not None + assert weights is not None + + # Apply SignSGD + # Adam ≈ SignSGD if gradient is very sparse + _sparse_emb_signsgd_dist( + local_weights_grad, + local_ids, + weights, + + lr=group["lr"], + weight_decay=group["weight_decay"], + world_size=group["world_size"] + ) + + +def _sparse_emb_signsgd_dist( + local_weights_grad: torch.Tensor, + local_ids: torch.Tensor, + weights: torch.Tensor, + + lr: float, + weight_decay: float, + world_size: int +) -> None: + N, D = local_weights_grad.shape + + # All-gather + all_weights_grad = local_weights_grad + all_ids = local_ids + + if world_size > 1: + all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device) + all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device) + + dist.all_gather_into_tensor(all_weights_grad, local_weights_grad) + dist.all_gather_into_tensor(all_ids, local_ids) + + # Unique + grad_ids, inv = all_ids.unique(return_inverse=True) + + grad = torch.zeros((grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device) + grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad) + + # SignSGD with decoupled weight decay + p = weights[grad_ids] + + p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr) + + # Write updated slices back + weights[grad_ids] = p diff --git a/pretrain.py b/pretrain.py new file mode 100644 index 00000000..245cb5c7 --- /dev/null +++ b/pretrain.py @@ -0,0 +1,453 @@ +from typing import Optional, Any, Sequence, List +from dataclasses import dataclass +import os +import math +import yaml +import shutil + +import torch +import torch.distributed as dist +from torch import nn +from torch.utils.data import DataLoader + +import tqdm +import wandb +import coolname +import hydra +import pydantic +from omegaconf import DictConfig +from adam_atan2 import AdamATan2 + +from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata +from utils.functions import load_model_class, get_model_source_path +from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed + + +class LossConfig(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra='allow') + + name: str + + +class ArchConfig(pydantic.BaseModel): + model_config = pydantic.ConfigDict(extra='allow') + + name: str + loss: LossConfig + + +class PretrainConfig(pydantic.BaseModel): + # Config + arch: ArchConfig + # Data + data_path: str + + # Hyperparams + global_batch_size: int + epochs: int + + lr: float + lr_min_ratio: float + lr_warmup_steps: int + + weight_decay: float + beta1: float + beta2: float + + # Puzzle embedding + puzzle_emb_lr: float + puzzle_emb_weight_decay: float + + # Names + project_name: Optional[str] = None + run_name: Optional[str] = None + checkpoint_path: Optional[str] = None + + # Extras + seed: int = 0 + checkpoint_every_eval: bool = False + eval_interval: Optional[int] = None + eval_save_outputs: List[str] = [] + + +@dataclass +class TrainState: + model: nn.Module + optimizers: Sequence[torch.optim.Optimizer] + optimizer_lrs: Sequence[float] + carry: Any + + step: int + total_steps: int + + +def create_dataloader(config: PretrainConfig, split: str, rank: int, world_size: int, **kwargs): + dataset = PuzzleDataset(PuzzleDatasetConfig( + seed=config.seed, + + dataset_path=config.data_path, + + rank=rank, + num_replicas=world_size, + + **kwargs + ), split=split) + dataloader = DataLoader( + dataset, + batch_size=None, + + num_workers=1, + prefetch_factor=8, + + pin_memory=True, + persistent_workers=True + ) + return dataloader, dataset.metadata + + +def create_model(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int): + model_cfg = dict( + **config.arch.__pydantic_extra__, # type: ignore + + batch_size=config.global_batch_size // world_size, + + vocab_size=train_metadata.vocab_size, + seq_len=train_metadata.seq_len, + num_puzzle_identifiers=train_metadata.num_puzzle_identifiers, + causal=False # Non-autoregressive + ) + + # Instantiate model with loss head + model_cls = load_model_class(config.arch.name) + loss_head_cls = load_model_class(config.arch.loss.name) + + with torch.device("cuda"): + model: nn.Module = model_cls(model_cfg) + model = loss_head_cls(model, **config.arch.loss.__pydantic_extra__) # type: ignore + if "DISABLE_COMPILE" not in os.environ: + model = torch.compile(model, dynamic=False) # type: ignore + + # Broadcast parameters from rank 0 + if world_size > 1: + with torch.no_grad(): + for param in list(model.parameters()) + list(model.buffers()): + dist.broadcast(param, src=0) + + # Optimizers and lr + optimizers = [ + CastedSparseEmbeddingSignSGD_Distributed( + model.model.puzzle_emb.buffers(), # type: ignore + + lr=0, # Needs to be set by scheduler + weight_decay=config.puzzle_emb_weight_decay, + + world_size=world_size + ), + AdamATan2( + model.parameters(), + + lr=0, # Needs to be set by scheduler + weight_decay=config.weight_decay, + betas=(config.beta1, config.beta2) + ) + ] + optimizer_lrs = [ + config.puzzle_emb_lr, + config.lr + ] + + return model, optimizers, optimizer_lrs + + +def cosine_schedule_with_warmup_lr_lambda( + current_step: int, *, base_lr: float, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5 +): + if current_step < num_warmup_steps: + return base_lr * float(current_step) / float(max(1, num_warmup_steps)) + + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return base_lr * (min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))) + + +def init_train_state(config: PretrainConfig, train_metadata: PuzzleDatasetMetadata, world_size: int): + # Estimated total training steps + total_steps = int(config.epochs * train_metadata.total_groups * train_metadata.mean_puzzle_examples / config.global_batch_size) + + # Model + model, optimizers, optimizer_lrs = create_model(config, train_metadata, world_size=world_size) + + return TrainState( + step=0, + total_steps=total_steps, + + model=model, + optimizers=optimizers, + optimizer_lrs=optimizer_lrs, + carry=None + ) + + +def save_train_state(config: PretrainConfig, train_state: TrainState): + # FIXME: Only saved model. + if config.checkpoint_path is None: + return + + os.makedirs(config.checkpoint_path, exist_ok=True) + torch.save(train_state.model.state_dict(), os.path.join(config.checkpoint_path, f"step_{train_state.step}")) + + +def compute_lr(base_lr: float, config: PretrainConfig, train_state: TrainState): + return cosine_schedule_with_warmup_lr_lambda( + current_step=train_state.step, + base_lr=base_lr, + num_warmup_steps=round(config.lr_warmup_steps), + num_training_steps=train_state.total_steps, + min_ratio=config.lr_min_ratio + ) + + +def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, global_batch_size: int, rank: int, world_size: int): + train_state.step += 1 + if train_state.step > train_state.total_steps: # At most train_total_steps + return + + # To device + batch = {k: v.cuda() for k, v in batch.items()} + + # Init carry if it is None + if train_state.carry is None: + with torch.device("cuda"): + train_state.carry = train_state.model.initial_carry(batch) # type: ignore + + # Forward + train_state.carry, loss, metrics, _, _ = train_state.model(carry=train_state.carry, batch=batch, return_keys=[]) + + ((1 / global_batch_size) * loss).backward() + + # Allreduce + if world_size > 1: + for param in train_state.model.parameters(): + if param.grad is not None: + dist.all_reduce(param.grad) + + # Apply optimizer + lr_this_step = None + for optim, base_lr in zip(train_state.optimizers, train_state.optimizer_lrs): + lr_this_step = compute_lr(base_lr, config, train_state) + + for param_group in optim.param_groups: + param_group['lr'] = lr_this_step + + optim.step() + optim.zero_grad() + + # Reduce metrics + if len(metrics): + assert not any(v.requires_grad for v in metrics.values()) + + metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order. + # Reduce and reconstruct + metric_values = torch.stack([metrics[k] for k in metric_keys]) + if world_size > 1: + dist.reduce(metric_values, dst=0) + + if rank == 0: + metric_values = metric_values.cpu().numpy() + reduced_metrics = {k: metric_values[i] for i, k in enumerate(metric_keys)} + + # Postprocess + count = max(reduced_metrics["count"], 1) # Avoid NaNs + reduced_metrics = {f"train/{k}": v / (global_batch_size if k.endswith("loss") else count) for k, v in reduced_metrics.items()} + + reduced_metrics["train/lr"] = lr_this_step + return reduced_metrics + + +def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int): + with torch.inference_mode(): + set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)} + + all_preds = {} + + metric_keys = [] + metric_values = None + metric_global_batch_size = [0 for _ in range(len(set_ids))] + + carry = None + for set_name, batch, global_batch_size in eval_loader: + # To device + batch = {k: v.cuda() for k, v in batch.items()} + with torch.device("cuda"): + carry = train_state.model.initial_carry(batch) # type: ignore + + # Forward + while True: + carry, _, metrics, preds, all_finish = train_state.model(carry=carry, batch=batch, return_keys=config.eval_save_outputs) + + if all_finish: + break + + for collection in (batch, preds): + for k, v in collection.items(): + if k in config.eval_save_outputs: + all_preds.setdefault(k, []) + all_preds[k].append(v.cpu()) # Move to CPU for saving GPU memory + + del carry, preds, batch, all_finish + + # Aggregate + set_id = set_ids[set_name] + + if metric_values is None: + metric_keys = list(sorted(metrics.keys())) # Sort keys to guarantee all processes use the same order. + metric_values = torch.zeros((len(set_ids), len(metrics.values())), dtype=torch.float32, device="cuda") + + metric_values[set_id] += torch.stack([metrics[k] for k in metric_keys]) + metric_global_batch_size[set_id] += global_batch_size + + if len(all_preds) and config.checkpoint_path is not None: + all_preds = {k: torch.cat(v, dim=0) for k, v in all_preds.items()} + + os.makedirs(config.checkpoint_path, exist_ok=True) + torch.save(all_preds, os.path.join(config.checkpoint_path, f"step_{train_state.step}_all_preds.{rank}")) + + # Logging + # Reduce to rank 0 + if metric_values is not None: + if world_size > 1: + dist.reduce(metric_values, dst=0) + + if rank == 0: + reduced_metrics = metric_values.cpu().numpy() + reduced_metrics = {set_name: {metric_name: reduced_metrics[set_id, metric_id] for metric_id, metric_name in enumerate(metric_keys)} + for set_id, set_name in enumerate(set_ids)} + + # Postprocess + for set_name, metrics in reduced_metrics.items(): + count = metrics.pop("count") + reduced_metrics[set_name] = {k: v / count for k, v in metrics.items()} + + return reduced_metrics + + +def save_code_and_config(config: PretrainConfig): + if config.checkpoint_path is None or wandb.run is None: + return + + os.makedirs(config.checkpoint_path, exist_ok=True) + + # Copy code + code_list = [ + get_model_source_path(config.arch.name), + get_model_source_path(config.arch.loss.name) + ] + for code_file in code_list: + if code_file is not None: + code_name = os.path.basename(code_file) + + shutil.copy(code_file, os.path.join(config.checkpoint_path, code_name)) + + # Dump config as yaml + config_file = os.path.join(config.checkpoint_path, "all_config.yaml") + with open(config_file, "wt") as f: + yaml.dump(config.model_dump(), f) + + # Log code + wandb.run.log_code(config.checkpoint_path) + + +def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) -> PretrainConfig: + objects = [None] + if rank == 0: + config = PretrainConfig(**hydra_config) # type: ignore + + # Naming + if config.project_name is None: + config.project_name = f"{os.path.basename(config.data_path).capitalize()} ACT-torch" + if config.run_name is None: + config.run_name = f"{config.arch.name.split('@')[-1]} {coolname.generate_slug(2)}" + if config.checkpoint_path is None: + config.checkpoint_path = os.path.join("checkpoints", config.project_name, config.run_name) + + objects = [config] + + if world_size > 1: + dist.broadcast_object_list(objects, src=0) + + return objects[0] # type: ignore + + +@hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None) +def launch(hydra_config: DictConfig): + RANK = 0 + WORLD_SIZE = 1 + + # Initialize distributed training if in distributed environment (e.g. torchrun) + if "LOCAL_RANK" in os.environ: + # Initialize distributed, default device and dtype + dist.init_process_group(backend="nccl") + + RANK = dist.get_rank() + WORLD_SIZE = dist.get_world_size() + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + # Load sync'ed config + config = load_synced_config(hydra_config, rank=RANK, world_size=WORLD_SIZE) + + # Seed RNGs to ensure consistency + torch.random.manual_seed(config.seed + RANK) + + # Dataset + train_epochs_per_iter = config.eval_interval if config.eval_interval is not None else config.epochs + total_iters = config.epochs // train_epochs_per_iter + + assert config.epochs % train_epochs_per_iter == 0, "Eval interval must be a divisor of total epochs." + + train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) + eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE) + + # Train state + train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE) + + # Progress bar and logger + progress_bar = None + if RANK == 0: + progress_bar = tqdm.tqdm(total=train_state.total_steps) + + wandb.init(project=config.project_name, name=config.run_name, config=config.model_dump(), settings=wandb.Settings(_disable_stats=True)) # type: ignore + wandb.log({"num_params": sum(x.numel() for x in train_state.model.parameters())}, step=0) + save_code_and_config(config) + + # Training Loop + for _iter_id in range(total_iters): + print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}") + + ############ Train Iter + train_state.model.train() + for set_name, batch, global_batch_size in train_loader: + metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE) + + if RANK == 0 and metrics is not None: + wandb.log(metrics, step=train_state.step) + progress_bar.update(train_state.step - progress_bar.n) # type: ignore + + ############ Evaluation + train_state.model.eval() + metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE) + + if RANK == 0 and metrics is not None: + wandb.log(metrics, step=train_state.step) + + ############ Checkpointing + if RANK == 0 and (config.checkpoint_every_eval or (_iter_id == total_iters - 1)): + save_train_state(config, train_state) + + # finalize + if dist.is_initialized(): + dist.destroy_process_group() + wandb.finish() + + +if __name__ == "__main__": + launch() diff --git a/puzzle_dataset.py b/puzzle_dataset.py new file mode 100644 index 00000000..2782403c --- /dev/null +++ b/puzzle_dataset.py @@ -0,0 +1,199 @@ +import os +import json + +import numpy as np +import pydantic + +import torch +from torch.utils.data import IterableDataset, get_worker_info + +from models.losses import IGNORE_LABEL_ID +from dataset.common import PuzzleDatasetMetadata + + +def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int): + # Pack examples into a full batch + batch = [] + batch_puzzle_indices = [] + current_size = 0 + + while (start_index < group_order.size) and (current_size < global_batch_size): + # Pick a group and a puzzle from that group + group_id = group_order[start_index] + puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1]) + start_index += 1 + + # Get range of the puzzle + puzzle_start = puzzle_indices[puzzle_id] + puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start) + + append_size = min(puzzle_size, global_batch_size - current_size) + + # Put into batch + batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32)) + batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False)) + + current_size += append_size + + return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices) + + +class PuzzleDatasetConfig(pydantic.BaseModel): + seed: int + dataset_path: str + global_batch_size: int + test_set_mode: bool + + epochs_per_iter: int # Batch X epochs in an iteration to reduce overhead. + + rank: int + num_replicas: int + + +class PuzzleDataset(IterableDataset): + def __init__(self, config: PuzzleDatasetConfig, split: str = "train"): + super().__init__() + self.config = config + self.split = split + self.metadata = self._load_metadata() + + # Checks + assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}." + self.local_batch_size = self.config.global_batch_size // self.config.num_replicas + + # State + self._data = None + self._iters = 0 + + def _load_metadata(self) -> PuzzleDatasetMetadata: + with open(os.path.join(self.config.dataset_path, self.split, "dataset.json"), "r") as f: + return PuzzleDatasetMetadata(**json.load(f)) + + def _lazy_load_dataset(self): + if self._data is not None: + return + + field_mmap_modes = { + "inputs": "r", + "labels": "r", + + # Keep indices in memory + "puzzle_identifiers": None, + "puzzle_indices": None, + "group_indices": None + } + + # Load data + self._data = {} + for set_name in self.metadata.sets: + # Load subset + self._data[set_name] = { + field_name: np.load(os.path.join(self.config.dataset_path, self.split, f"{set_name}__{field_name}.npy"), mmap_mode=mmap_mode) + for field_name, mmap_mode in field_mmap_modes.items() + } + + def _collate_batch(self, batch): + # Convert dtype + batch = {k: v.astype(np.int32) for k, v in batch.items()} + + # Convert ignore label IDs + if self.metadata.ignore_label_id is not None: + batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID + + # Pad + if batch["puzzle_identifiers"].size < self.local_batch_size: + pad_size = self.local_batch_size - batch["puzzle_identifiers"].size + + pad_values = { + "inputs": self.metadata.pad_id, + "labels": IGNORE_LABEL_ID, + + "puzzle_identifiers": self.metadata.blank_identifier_id + } + batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()} + + # To tensor + return {k: torch.from_numpy(v) for k, v in batch.items()} + + def _iter_test(self): + for set_name, dataset in self._data.items(): # type: ignore + total_examples = len(dataset["inputs"]) + + # Load examples one by one + start_index = 0 + while start_index < total_examples: + # Compute indices + end_index = min(total_examples, start_index + self.config.global_batch_size) + + local_start = start_index + self.config.rank * self.local_batch_size + local_end = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index) + + # Get batch of examples, and also puzzle IDs + puzzle_indices = [] + puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1 + for i in range(local_start, local_end): + while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]: + puzzle_index += 1 + + puzzle_indices.append(puzzle_index) + + batch = self._collate_batch({ + "inputs": dataset["inputs"][local_start: local_end], + "labels": dataset["labels"][local_start: local_end], + "puzzle_identifiers": dataset["puzzle_identifiers"][puzzle_indices] + }) + + yield set_name, batch, end_index - start_index + + # Advance to next batch + start_index += self.config.global_batch_size + + def _iter_train(self): + for set_name, dataset in self._data.items(): # type: ignore + # Increase epoch count + self._iters += 1 + + # Randomly shuffle groups + rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters)) + + group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)]) + start_index = 0 + + while start_index < group_order.size: + start_index, batch_indices, batch_puzzle_indices = _sample_batch( + rng, + group_order=group_order, + puzzle_indices=dataset["puzzle_indices"], + group_indices=dataset["group_indices"], + start_index=start_index, + global_batch_size=self.config.global_batch_size, + ) + + # Select current rank and collate + global_effective_batch_size = batch_puzzle_indices.size # Global effective batch size, excluding pads + + # Drop last batch + if global_effective_batch_size < self.config.global_batch_size: + break + + batch_indices = batch_indices [self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size] + batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size] + batch = self._collate_batch({ + "inputs": dataset["inputs"][batch_indices], + "labels": dataset["labels"][batch_indices], + "puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices] + }) + + yield set_name, batch, global_effective_batch_size + + def __iter__(self): + worker_info = get_worker_info() + assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported." + + self._lazy_load_dataset() + + # Iterate using specified mode + if self.config.test_set_mode: + yield from self._iter_test() + else: + yield from self._iter_train() diff --git a/puzzle_visualizer.html b/puzzle_visualizer.html new file mode 100644 index 00000000..bcefdf1c --- /dev/null +++ b/puzzle_visualizer.html @@ -0,0 +1,426 @@ + + + + + ARC‐Converted Dataset Visualizer (Upload Local Folder) + + + +

ARC‐Converted Dataset Visualizer (Local Directory)

+ +
+ + + +

+ + + + + + + + + +
+ +
+
+
+
+
+ + + + + + + From 4f6bdc8d15b3b2cb8391f1d0fb5a1a5fc82589ab Mon Sep 17 00:00:00 2001 From: jeevesh415 Date: Sun, 17 May 2026 17:12:57 +0530 Subject: [PATCH 7/7] Add explicit post-PR audit record for conflict/integration recheck --- AUDIT_RECHECK.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 AUDIT_RECHECK.md diff --git a/AUDIT_RECHECK.md b/AUDIT_RECHECK.md new file mode 100644 index 00000000..b82ee5ac --- /dev/null +++ b/AUDIT_RECHECK.md @@ -0,0 +1,22 @@ +# Post-PR Recheck Audit + +Date: 2026-05-17 (UTC) + +## Why this file +User requested a visible commit-level proof that the repository was rechecked for mistakes/conflicts after the VEM integration PR. + +## Recheck steps performed +1. Searched for unresolved merge markers (`<<<<<<<`, `=======`, `>>>>>>>`). +2. Ran repository compile sanity (`python -m compileall -q .`). +3. Ran integration smoke (`python check_integrations.py`). +4. Ran world-model eval smoke (`evaluate_world_model.py`). +5. Ran perception eval smoke (`evaluate_perception.py`). + +## Outcome +- No merge conflict markers found. +- Compile sanity passed. +- Integration smoke passed. +- Both evaluation scripts executed and emitted metrics. + +## Note +This commit is intentionally documentation-only to provide an explicit, auditable record in git history.