From a9aee1fd56a6d1d52e99e17118f77d4d20a37a99 Mon Sep 17 00:00:00 2001 From: Thomson Lam Date: Mon, 30 Mar 2026 13:49:52 -0400 Subject: [PATCH 1/2] feat: environment metrics for training --- docs/envspec.md | 104 +++++-- docs/ml_architecture.md | 111 -------- docs/planning/impl-plan.md | 29 +- docs/planning/proposal.md | 215 --------------- docs/planning/train-plan.md | 535 ++++++++++++++++++++++++++++++++++++ docs/rl_agent.md | 137 --------- src/models/fire_env.py | 23 ++ 7 files changed, 669 insertions(+), 485 deletions(-) delete mode 100644 docs/ml_architecture.md delete mode 100644 docs/planning/proposal.md create mode 100644 docs/planning/train-plan.md delete mode 100644 docs/rl_agent.md diff --git a/docs/envspec.md b/docs/envspec.md index 2d2f69f..637252a 100644 --- a/docs/envspec.md +++ b/docs/envspec.md @@ -1,8 +1,10 @@ -# Fire Environment Specification (Current Implementation) +# Fire Environment Specification -This document describes the currently implemented RL environment in `src/models/fire_env.py` and its training/evaluation usage in `src/models/train_rl_agent.py` and `src/models/evaluate_agents.py`. +This document describes the RL environment in `src/models/fire_env.py` and its training/evaluation usage in `src/models/train_rl_agent.py` and `src/models/evaluate_agents.py`. -It is intentionally code-first and only documents behavior that exists in the current codebase. +It remains code-first for implemented behavior, but it also records the benchmark-alignment requirements that training and evaluation code must satisfy before full canonical runs are launched. + +The concrete benchmark execution plan lives in `docs/planning/train-plan.md`. --- @@ -174,14 +176,14 @@ Optimization intent: --- -## 7) Training Process (Current Code) +## 7) Training Process Current training script: - `src/models/train_rl_agent.py` -- algorithm currently implemented in this script: `PPO` (Stable-Baselines3) +- currently implemented learned method: `PPO` (Stable-Baselines3) -Canonical training flow: +Current implemented flow: 1. load seeded train split dataset 2. create vectorized benchmark envs (`n_envs`) @@ -189,13 +191,32 @@ Canonical training flow: 4. save model to `src/models/tactical_ppo_agent.zip` 5. run quick evaluation on train and optional val/holdout datasets +Benchmark-aligned target flow: + +1. use a unified runner with `--algo {ppo,a2c,dqn}` +2. keep the same benchmark-mode dataset path and split validation for all methods +3. use vectorized envs for `PPO` and `A2C` +4. use a single benchmark env for `DQN` by default +5. write checkpoint metrics every fixed training interval +6. save per-run config and per-checkpoint metrics to disk +7. choose the best checkpoint by validation `asset_survival_rate` +8. 
run final split-wise evaluation on train/val/holdout after training completes + Current benchmark evaluation script: - `src/models/evaluate_agents.py` - evaluates agents across splits (`train`, `val`, `holdout`) -- supported evaluated agents: `ppo`, `greedy`, `random` +- currently implemented evaluated agents: `ppo`, `greedy`, `random` - can output JSON summary via `--output` +Benchmark-aligned target support: + +- `ppo` +- `a2c` +- `dqn` +- `greedy` +- `random` + Transparency outputs from current code: - training console output: timesteps, env count, dataset path/count, quick split metrics @@ -203,6 +224,13 @@ Transparency outputs from current code: - evaluation console summary per split/agent - optional evaluation JSON with aggregate metrics +Required benchmark transparency outputs: + +- serialized run config per seed +- checkpoint metrics on train/val/holdout +- best-checkpoint selection record +- final evaluation JSON aggregated by seed, then across seeds + Recommended transparency plots (from saved eval JSON/logs): - split-wise mean return (`train` vs `val` vs `holdout`) @@ -212,24 +240,66 @@ Recommended transparency plots (from saved eval JSON/logs): --- -## 8) Reporting Metrics +## 8) Reporting Metrics -Primary optimization target/what the agent is trained to do: **Minimize assets damaged/lost** +Primary optimization target: **Minimize assets damaged/lost** -Additional reported metrics (already computed or directly derivable from current eval): +Frozen benchmark metric definitions: - mean episodic return -- standard deviation / variance across episodes - asset survival rate - containment success rate -- mean final burned area -- mean time to containment -- mean resource efficiency +- mean burned-area fraction: `(burned + burning + asset_burned) / 625` +- mean time to containment, conditioned on successful containment only +- mean resource efficiency: `successful_deployments / total_deployments` +- standard deviation across seeds for each reported metric +- wasted deployment rate - mean normalized burn ratio (optional in evaluator) -Report these diagnostics during training: +Important alignment notes: + +- pooled episode variance is not a substitute for seed-level aggregation +- raw burned-cell count can still be logged, but the benchmark-facing metric should be the normalized burned-area fraction +- holdout performance is for final reporting only and must not be used for tuning +- the current temporal holdout dataset has only one unique seeded record, so it is a final diagnostic only until expanded + +Report these diagnostics during training checkpoints: -- train/val/holdout gap for each metric +- train/val gap for each metric +- optional train/family-holdout gap for each metric - per-seed summary tables -- baseline comparisons (`greedy`, `random`) against PPO +- baseline comparisons (`greedy`, `random`) against learned methods + +## 9) Environment-Side Requirements For Benchmark Alignment + +The environment and evaluator together must expose enough information to compute the benchmark metrics exactly. 
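+
+As a concrete illustration, the sketch below computes the two operational deployment metrics from the final `info` dict of an episode. It is a minimal sketch, not the evaluator implementation: it assumes the `info` keys listed just below, and it uses the env's `total_deployment_attempts` counter as the shared denominator, which matches the counters the environment actually exposes.
+
+```python
+# Minimal sketch: operational metrics from the final `info` dict of an episode.
+# Assumes the deployment counters exposed by WildfireEnv.step() in this patch.
+def operational_metrics(info: dict) -> dict[str, float]:
+    attempts = info["total_deployment_attempts"]
+    if attempts == 0:
+        # Frozen zero-denominator rule: report 0.0, never NaN.
+        return {"mean_resource_efficiency": 0.0, "wasted_deployment_rate": 0.0}
+    return {
+        "mean_resource_efficiency": info["successful_deployments"] / attempts,
+        "wasted_deployment_rate": info["wasted_deployment_attempts"] / attempts,
+    }
+```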
+ +Environment-side counters or `info` fields required for clean evaluation: + +- `assets_lost` +- `step` +- `heli_left` +- `crew_left` +- count of successful helicopter deployments +- count of successful crew deployments +- count of wasted deployment attempts +- count of total deployment attempts + +Operational metric rules: + +- `mean_resource_efficiency = successful_deployments / total_deployments` +- if `total_deployments == 0`, report `0.0` +- `wasted_deployment_rate = wasted_deployments / total_deployment_attempts` +- if `total_deployment_attempts == 0`, report `0.0` + +Evaluator-side aggregation requirements: + +- aggregate per seed first, then summarize across seeds +- compute `time_to_containment` only on contained episodes +- compute normalized burn ratio against the same scenario record and evaluation seed under the deterministic non-intervention baseline defined in `docs/planning/train-plan.md` +- do not surface temporal holdout metrics during checkpoint evaluation in canonical runs +- pass `scenario_families` explicitly for canonical train, validation, and family-holdout runs rather than relying on the environment default + +Verification requirement before full runs: +- all of the above metrics must appear in a short smoke-test evaluation artifact before any full 5-seed training sweep is launched diff --git a/docs/ml_architecture.md b/docs/ml_architecture.md deleted file mode 100644 index 19442e4..0000000 --- a/docs/ml_architecture.md +++ /dev/null @@ -1,111 +0,0 @@ -# FireGrid ML Architecture - -## Overview - -FireGrid uses a two-stage ML pipeline: - -1. **XGBoost Spread Model** — Supervised learning. Answers: *"How far will this fire spread?"* -2. **PPO Tactical Agent** *(planned)* — Reinforcement learning. Answers: *"Where should we deploy assets to stop it?"* - ---- - -## Stage 1: XGBoost Fire Spread Model - -**File:** `backend/src/models/spread_model.py` -**Saved weights:** `backend/src/models/spread_1h_model.joblib`, `spread_3h_model.joblib` - -### What It Predicts -Two separate `XGBRegressor` models trained in parallel: -- `spread_1h_m` — predicted fire spread **radius in metres** after 1 hour -- `spread_3h_m` — predicted spread radius after 3 hours - -These radii feed the map's fire spread ring visualization. - -### Training Data Strategy -Real historical spread data requires years of CWFIS archives. For this hackathon, we generate **6,000 physics-informed synthetic samples** using the **Rothermel fire spread formula** as a label generator: - -``` -spread_rate ∝ ISI × exp(0.05039 × wind_speed) × (1 - RH/120) × FFMC_factor × slope_factor -``` - -We add Gaussian noise so XGBoost learns a smooth, non-linear approximation rather than memorizing the formula. 
- -### Feature Vector (11 inputs) - -| Feature | Source | Why It Matters | -|---|---|---| -| `wind_speed_km_h` | Open-Meteo API | Primary spread driver (exponential effect) | -| `wind_u` | Derived: `speed × cos(dir_rad)` | Eastward wind vector — fixes cyclic discontinuity | -| `wind_v` | Derived: `speed × sin(dir_rad)` | Northward wind vector | -| `temperature_c` | Open-Meteo API | Fuel pre-heating, moisture evaporation | -| `relative_humidity_pct` | Open-Meteo API | Dampens spread — key moisture indicator | -| `fwi` | CWFIS/CFFDRS | Overall Canadian fire danger index | -| `isi` | CWFIS/CFFDRS | Initial Spread Index — rate of spread | -| `bui` | CWFIS/CFFDRS | Buildup Index — available fuel load | -| `area_hectares` | DynamoDB | Current fire size context | -| `slope_pct` | Synthetic (–20 to +45%) | Uphill fires accelerate via convective preheating | -| `rh_trend_24h` | Synthetic / ECCC | RH change over 24h — drying conditions amplify danger | - -### Key Design Decisions - -**Wind U/V decomposition** — Tree-based models cannot handle cyclical features. Feeding `wind_direction_deg = 359°` and `1°` as raw numbers makes them look 358 units apart, when they are physically 2° apart. Projecting onto Cartesian `(U, V)` vectors eliminates this discontinuity. This is a standard practice in numerical weather prediction (NWP) preprocessing. - -**Slope factor** — The Rothermel (1972) model includes a slope intensification term. Fires moving uphill receive convective preheating from the rising column of hot gas ahead of the flame front, dramatically increasing spread rate. We apply `slope_factor = 1 + max(0, slope%)/20`, giving up to a 3.25× multiplier at 45% slope. - -**RH trend (temporal context)** — A fire at 40% RH is more dangerous if RH was 80% yesterday and is dropping fast. Adding `rh_trend_24h` lets the model distinguish static vs. rapidly-drying conditions without building a full time-series architecture. - -### Observed Feature Importances (v2 model) - -``` -isi ███████████████ 0.301 ← CFFDRS Initial Spread Index -wind_speed_km_h ██████████████ 0.293 ← Wind magnitude -slope_pct ███████ 0.145 ← Terrain topography -relative_humidity_pct ███ 0.064 ← Moisture damper -wind_v ███ 0.063 ← Northward component -wind_u ██ 0.048 ← Eastward component -rh_trend_24h █ 0.037 ← Temporal drying context -temperature_c █ 0.026 -area_hectares 0.009 -bui 0.009 -fwi 0.005 -``` - -### Evaluation Results - -| Model | MAE | R² | -|---|---|---| -| 1-hour spread | 421 m | 0.977 | -| 3-hour spread | 1,323 m | 0.977 | - -High R² expected on synthetic data — the model is learning the correct physical relationships between features and spread rate. - -### API Integration - -``` -GET /api/v1/predictions/{fire_id} -``` -1. Fetches fire record from DynamoDB (lat/lon, area) -2. Calls `get_fire_weather(lat, lon)` → Open-Meteo API (real-time) -3. Builds 11-feature vector (with CFFDRS fallback defaults) -4. Runs `model_1h.predict()` and `model_3h.predict()` -5. Returns `{spread_1h_m, spread_3h_m, features_used, model}` - ---- - -## Stage 2: PPO Tactical Agent *(Planned — Phase 3b)* - -**See:** `docs/ppo_plan.md` - -The RL agent takes the XGBoost spread prediction as input and outputs tactical deployment coordinates for helicopters, ground crews, and dozers. 
- ---- - -## Running the Model - -```bash -# Train from scratch and run live demo predictions -uv run python -m src.models.spread_model - -# Call via API (backend must be running) -Invoke-RestMethod http://localhost:8000/api/v1/predictions/BC-2026-001 -``` diff --git a/docs/planning/impl-plan.md b/docs/planning/impl-plan.md index 00f0f4b..54b92af 100644 --- a/docs/planning/impl-plan.md +++ b/docs/planning/impl-plan.md @@ -2,6 +2,8 @@ This file is the single source of truth for implementation and evaluation. +For the concrete training runner, tuning, checkpointing, and verification workflow, see `docs/planning/train-plan.md`. + Project direction: **Empirical comparison of standard RL algorithms on an enhanced custom wildfire simulator with one objective: protect critical assets under limited suppression budget.** @@ -194,6 +196,12 @@ If unstable, adjust only `asset` and `burn` coefficients once, then freeze. Report both in-distribution and held-out performance. +Split interpretation note: + +- family holdout and temporal holdout are distinct and must be labeled separately in reporting. +- canonical train/validation runs should pass explicit scenario families rather than relying on environment defaults. +- the current temporal holdout artifact is a single-record diagnostic and should not be treated as a full held-out benchmark until expanded. + --- ## 7) Algorithms to Benchmark @@ -212,6 +220,8 @@ Recurrent baselines are not included because we will not add and test hidden reg ## 8) Benchmark Harness and Logging (Required Infrastructure) +Exact metric definitions and the verification ladder are frozen in `docs/planning/train-plan.md`. + Requirements: 1. Unified runner for all algorithms. @@ -219,14 +229,15 @@ Requirements: 3. Metrics written to CSV/JSON per checkpoint and final summary. 4. Distinct evaluation mode with fallback heuristics disabled for RL methods. 5. Seed-aware aggregation scripts for mean/std and confidence intervals. +6. Canonical checkpoint evaluation must exclude temporal holdout. Core metrics: 1. mean episodic return 2. asset survival rate 3. containment success rate -4. final burned area -5. variance across seeds +4. mean burned-area fraction +5. standard deviation across seeds Secondary metrics: @@ -242,6 +253,12 @@ Normalized burn ratio definition: - The denominator comes from a no-action baseline rollout using the same scenario record and RNG seed. - This is an evaluation-only metric and does not modify the training reward. +Metric interpretation notes: + +- `time to containment` is conditioned on successful containment episodes only. +- `resource efficiency` is `successful_deployments / total_deployments`. +- pooled episode variance is not the benchmark aggregate; summarize per seed first, then report `mean +- std` across seeds. + --- ## 9) Static Scenario Parameter Interface @@ -361,9 +378,11 @@ flowchart TD 6. Implement scenario generator with frozen train/test families. 7. Implement snapshot cache loader and offline parameter-to-env mapping. 8. Add evaluation-only normalized burn ratio reporting. -9. Run reward sanity pass and freeze coefficients. -10. Run full multi-seed benchmarks for DQN/A2C/PPO + greedy/random. -11. Aggregate plots/tables and write limitations. +9. Add checkpoint metrics, config serialization, and best-checkpoint selection. +10. Run reward sanity pass and freeze coefficients. +11. Run algorithm smoke tests and short pilot tuning runs. +12. Run full multi-seed benchmarks for DQN/A2C/PPO + greedy/random. +13. 
Aggregate plots/tables and write limitations. --- diff --git a/docs/planning/proposal.md b/docs/planning/proposal.md deleted file mode 100644 index d2b5512..0000000 --- a/docs/planning/proposal.md +++ /dev/null @@ -1,215 +0,0 @@ -# Proposal: RL Benchmark for Wildfire Asset Protection (Option A) - -## Working Title - -**Protecting Critical Assets Under Limited Suppression Budget: An Empirical RL Benchmark in a Custom Wildfire Simulator** - ---- - -## 1) Problem Statement - -Emergency wildfire response requires fast tactical decisions under uncertainty: where to move, when to deploy suppression, and which limited resource to spend first. Real-world evaluation is expensive, risky, and not reproducible. - -This project studies a controlled version of that problem in a custom simulator with one concrete objective: - -**Protect critical assets under limited suppression budget.** - ---- - -## 2) Why Existing Techniques Are Not Fully Satisfying - -- Heuristic approaches are often brittle across different fire layouts and spread severities. -- A single PPO implementation does not establish robust algorithm suitability. -- The current hackathon codebase lacks a clean multi-seed, multi-algorithm benchmark harness. -- Existing docs overstate operational realism relative to implementation; paper claims must match implemented evidence. - -Therefore, a defensible one-week contribution is an empirical benchmark of standard RL methods on an improved simulator with real control tradeoffs. - ---- - -## 3) Intuition Behind the Developed Technique - -The developed technique is not a new RL algorithm. It is an enhanced benchmark environment and evaluation protocol designed to produce meaningful tactical tradeoffs: - -1. **Prioritization under risk**: critical assets can be lost if not protected. -2. **Planning under scarcity**: helicopter/crew actions are limited and costly. -3. **Spatial reasoning**: non-uniform spread conditions from fixed episode parameters such as spread severity and wind bias. -4. **Robustness testing**: multiple scenario families and held-out test families. - -Core intuition: better benchmark structure and rigorous evaluation produce more defensible RL evidence than adding algorithmic novelty under time pressure. - ---- - -## 4) Techniques to Tackle the Problem - -## 4.1 Algorithms to compare (with rationale) - -- **DQN**: value-based baseline for discrete tactical control. -- **A2C**: lightweight on-policy actor-critic baseline. -- **PPO**: stronger policy-gradient baseline and current repo baseline. -- **Greedy heuristic**: domain-inspired non-RL tactical baseline. -- **Random policy**: floor sanity check. - -Optional only if hidden regime shifts are added and time permits: - -- **Recurrent PPO baseline** for partial observability robustness. - -## 4.2 Environment enhancements to keep (minimum viable strong paper) - -1. **Scenario diversity** - - ignition patterns: center, edge, corner, multi-cluster - - severity levels: low/medium/high - -2. **Finite resources with costs/cooldowns** - - limited helicopter drops and crew deployments - - resource cost and cooldown penalties - -3. **Heterogeneous spread conditions** - - precomputed spread severity from the static dataset - - directional wind bias - -4. **Clean benchmark harness** - - fixed train/eval protocol - - multi-seed runs - - no fallback heuristic contamination during evaluation - -## 4.3 High-value additions (if time permits) - -1. 
**Critical assets (recommended, highest value)** - - place 2-5 assets on map - - strong penalty when assets burn - -2. **Travel/action latency** - - deployment requires position or delayed effect - -3. **Simple hidden shift test** - - stochastic wind shift mid-episode - -4. **Train/test split on scenario families** - - evaluate on held-out ignition/severity combinations - ---- - -## 5) Planned Related Work Review (4-8 papers) - -The report will cover: - -1. Core RL algorithm papers: - - DQN - - A3C/A2C - - PPO - -2. RL benchmark/reproducibility papers: - - seed sensitivity, fair comparison protocols - -3. Wildfire decision-support/spread-modeling papers: - - decision-support framing - - spread prediction and response strategy - -The literature section will support an empirical benchmarking contribution, not a novel algorithm claim. - ---- - -## 6) Experimental Plan - -## 6.1 Fixed protocol - -- Same environment generator family for all algorithms. -- Same timestep budget per algorithm. -- Same seed set (3-5 seeds). -- Same evaluation episodes/checkpoints. -- Fallback heuristic disabled during RL benchmark runs. - -## 6.2 Metrics - -Primary: - -- mean episodic return -- critical asset survival rate -- containment success rate -- final burned area (cells) - -Secondary: - -- resource efficiency (suppression impact per resource spent) -- wasted deployment rate -- time to containment -- variance across seeds - -## 6.3 Generalization evaluation - -- Train on subset of scenario families. -- Test on held-out ignition/severity combinations. -- Compare performance drop and robustness ranking across methods. - ---- - -## 7) Data Pipeline Positioning in the Paper - -Data pipeline remains **supporting context**, not the central empirical claim. - -Based on audit findings, claims must stay realistic: - -- implemented ingestion: FIRMS, CWFIS active fires, Open-Meteo, CFFDRS -- benchmark use: one-time ingestion and preprocessing into static scenario records with precomputed environment variables -- not fully implemented as production ETL: CIFFC, BC/AB ArcGIS full pipeline, ECCC Datamart orchestration, broad historical validated spread labels - -Paper wording will avoid operational overclaim and state: - -"We benchmark RL methods in an enhanced custom wildfire simulator using static snapshot-derived scenario records and fixed environment parameterization." - ---- - -## 8) One-Week Execution Plan - -Day 1: -- Freeze objective and benchmark protocol. -- Add critical assets + resource budgets to environment. - -Day 2: -- Add scenario generator and static parameter preprocessing for spread severity and wind bias. -- Add eval mode without fallback contamination. - -Day 3: -- Implement DQN/A2C runners alongside PPO. -- Standardize logs and output schema. - -Day 4: -- Pilot runs and reward sanity checks. -- Fix instability and environment calibration issues. - -Day 5: -- Full multi-seed train/eval runs. - -Day 6: -- Aggregate results, figures, and tables. -- Draft evaluation/discussion. - -Day 7: -- Final report polish with limitations and future work. - ---- - -## 9) Expected Contribution and Defensibility - -This proposal is defensible because it: - -- focuses on a single tangible objective, -- introduces genuine tactical tradeoffs, -- compares standard baselines fairly, -- reports multi-seed results under controlled scenario families, -- avoids overclaiming real-world deployment readiness. 
- -Expected claim: - -"We design an enhanced wildfire tactical suppression benchmark with protected assets, limited suppression budget, and heterogeneous spread, then compare standard RL and heuristic baselines under controlled scenario families." - ---- - -## 10) Future Work (Out of Scope This Week) - -- multi-agent coordination -- real GIS terrain integration -- historical replay validation -- complex dispatch logistics -- continuous-action aircraft routing diff --git a/docs/planning/train-plan.md b/docs/planning/train-plan.md new file mode 100644 index 0000000..2e5e818 --- /dev/null +++ b/docs/planning/train-plan.md @@ -0,0 +1,535 @@ +# Training and Benchmark Plan + +This document defines the concrete plan for implementing, validating, tuning, and benchmarking learned agents for `WildfireEnv`. + +It is narrower than `docs/planning/impl-plan.md`: this file focuses specifically on training infrastructure, metric definitions, verification steps, and the execution order needed to produce paper-ready results. + +--- + +## 1) Goal + +Benchmark standard RL methods against simple baselines on the frozen wildfire asset-protection task. + +Required benchmark methods: + +- `random` +- `greedy` +- `DQN` +- `A2C` +- `PPO` + +Primary paper question: + +> Under the same frozen environment, split datasets, and training budget, which method best protects critical assets while containing fire under limited suppression resources? + +--- + +## 2) Frozen Protocol + +Unless explicitly labeled as an ablation, use the following benchmark protocol. + +- Grid size: `25 x 25` +- Episode horizon: `150` +- Train dataset: `data/static/scenario_parameter_records_seeded_train.json` +- Validation dataset: `data/static/scenario_parameter_records_seeded_val.json` +- Holdout dataset: `data/static/scenario_parameter_records_seeded_holdout.json` +- Training seeds: `11, 22, 33, 44, 55` +- Evaluation cadence during training: every `20,000` env steps +- Checkpoint evaluation episodes per split: `20` +- Final evaluation episodes per seed for train/val: `100` + +Training budget per algorithm per seed: + +- `PPO`: `200,000` env steps +- `A2C`: `200,000` env steps +- `DQN`: `200,000` env steps + +The final paper may report a longer PPO run separately only if all compared methods receive the same additional budget or the extra run is clearly labeled as an ablation. + +### 2.1 Split semantics are two-dimensional + +Canonical benchmarking has two distinct notions of split, and both must be enforced explicitly in code. + +1. Temporal data split from the seeded scenario-record files: + - train records: `scenario_parameter_records_seeded_train.json` + - validation records: `scenario_parameter_records_seeded_val.json` + - holdout records: `scenario_parameter_records_seeded_holdout.json` +2. Scenario-family split inside `WildfireEnv`: + - in-distribution families: `TRAIN_FAMILIES` + - held-out OOD families: `HELD_OUT_FAMILIES` + +Canonical meaning of each split: + +- Train: train records + `TRAIN_FAMILIES` +- Validation: validation records + `TRAIN_FAMILIES` +- Family holdout: validation or expanded holdout records + `HELD_OUT_FAMILIES` +- Temporal holdout: holdout records + explicit family list, reported separately from train/val + +Implementation rule: + +- The training/evaluation runner must pass `scenario_families` explicitly. +- Canonical runs must not rely on the environment default of `TRAIN_FAMILIES` when split semantics matter. 
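+
+A minimal sketch of the implementation rule above is shown below. The module-level family constants, constructor keywords, and helper name are assumptions for illustration; the real construction path lives in `src/models/fire_env.py` and may differ.
+
+```python
+# Sketch only: explicit per-split family selection, never the env default.
+from src.models.fire_env import HELD_OUT_FAMILIES, TRAIN_FAMILIES, WildfireEnv
+
+SPLIT_FAMILIES = {
+    "train": TRAIN_FAMILIES,
+    "val": TRAIN_FAMILIES,
+    "family_holdout": HELD_OUT_FAMILIES,
+}
+
+def make_split_env(dataset_path: str, split: str) -> WildfireEnv:
+    return WildfireEnv(
+        scenario_dataset=dataset_path,  # assumed keyword for the dataset path
+        scenario_families=SPLIT_FAMILIES[split],  # always passed explicitly
+    )
+```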
+
+### 2.2 Current holdout limitation
+
+The current seeded temporal holdout dataset contains only one unique record.
+
+Consequences:
+
+- it is not a credible benchmark split for model selection or checkpoint-time monitoring
+- repeated rollouts across seeds do not produce a meaningful `std_across_seeds` estimate for holdout
+- it may still be used as a final, clearly labeled stress-test diagnostic
+
+Paper-ready rule:
+
+- Do not present the current single-record temporal holdout as a full held-out benchmark.
+- Before making strong final holdout claims, expand `scenario_parameter_records_seeded_holdout.json` beyond one record.
+
+---
+
+## 3) Implementation Scope
+
+### 3.1 Unified training runner
+
+Extend `src/models/train_rl_agent.py` into a unified runner with `--algo {ppo,a2c,dqn}`.
+
+Add a benchmark-safe preset in code:
+
+- `--benchmark-preset canonical`
+
+The canonical preset should fill in the frozen benchmark defaults unless the user explicitly overrides them for an ablation run.
+
+Canonical preset values:
+
+- train dataset: `data/static/scenario_parameter_records_seeded_train.json`
+- validation dataset: `data/static/scenario_parameter_records_seeded_val.json`
+- holdout dataset: `data/static/scenario_parameter_records_seeded_holdout.json`
+- train/validation families: `TRAIN_FAMILIES`
+- family-holdout families: `HELD_OUT_FAMILIES`
+- checkpoint cadence: `20,000` env steps
+- checkpoint evaluation episodes: `20`
+- checkpoint-visible splits: `train`, `val`, and optional family holdout only
+- final evaluation episodes for train/val: `100`
+- benchmark-mode env creation enabled
+
+Holdout visibility rule for the canonical preset:
+
+- Do not surface temporal holdout metrics during checkpoint evaluation or hyperparameter sweeps.
+- Temporal holdout is final-reporting-only until the holdout dataset is expanded beyond one record.
+
+Requirements:
+
+1. Shared dataset loading and benchmark-mode env construction.
+2. Shared run config serialization.
+3. Shared checkpoint evaluation path.
+4. Per-algorithm model construction.
+5. Per-algorithm output path naming.
+
+Run artifact directory naming:
+
+- `artifacts/benchmark/<run_label>/<algo>/seed_<seed>/`
+
+Required `run_label` values:
+
+- `smoke`
+- `pilot`
+- `final`
+
+Purpose:
+
+- prevent smoke tests and pilot sweeps from overwriting canonical final benchmark artifacts
+
+Canonical per-seed artifacts:
+
+- `config.json`
+- `checkpoint_metrics.json`
+- `best_checkpoint.json`
+- `best_model.zip`
+- `last_model.zip`
+- `final_eval_best.json`
+
+Optional convenience exports outside the artifact directory are allowed, but they are not the canonical benchmark outputs.
+
+Artifact semantics:
+
+- `best_model.zip` is the paper-facing artifact for that seed
+- `last_model.zip` is retained for debugging and reproducibility only
+- `best_checkpoint.json` records the selected training step and selection metric values
+
+### 3.1.1 Frozen checkpoint metric schema
+
+`checkpoint_metrics.json` should be a JSON array. 
Each element should have this structure: + +```json +{ + "algo": "ppo", + "seed": 11, + "train_steps": 20000, + "selected_for_best": false, + "splits": { + "train": { + "mean_return": 0.0, + "asset_survival_rate": 0.0, + "containment_success_rate": 0.0, + "mean_burned_area_fraction": 0.0, + "mean_time_to_containment": null, + "mean_resource_efficiency": 0.0, + "wasted_deployment_rate": 0.0, + "episodes": 20 + }, + "val": { + "mean_return": 0.0, + "asset_survival_rate": 0.0, + "containment_success_rate": 0.0, + "mean_burned_area_fraction": 0.0, + "mean_time_to_containment": null, + "mean_resource_efficiency": 0.0, + "wasted_deployment_rate": 0.0, + "episodes": 20 + } + } +} +``` + +If a family-holdout evaluation is enabled during development, it should appear under a distinct key such as: + +```json +{ + "splits": { + "family_holdout": { + "mean_return": 0.0, + "asset_survival_rate": 0.0, + "containment_success_rate": 0.0, + "mean_burned_area_fraction": 0.0, + "mean_time_to_containment": null, + "mean_resource_efficiency": 0.0, + "wasted_deployment_rate": 0.0, + "episodes": 20 + } + } +} +``` + +The training runner should update `selected_for_best` only after the full checkpoint comparison is complete. + +Temporal holdout results, if produced, should be written only in final evaluation artifacts and clearly labeled as single-record diagnostic results when that constraint still applies. + +### 3.2 Environment construction by algorithm + +Use the same benchmark environment contract for all methods. + +- `PPO`: vectorized envs +- `A2C`: vectorized envs +- `DQN`: single env by default + +Rationale: + +- `WildfireEnv` already exposes a flat observation and discrete action space compatible with all three methods. +- The main algorithm-specific difference is that `DQN` should not reuse the parallel vectorized setup intended for `PPO` and `A2C`. + +Implementation requirement for benchmark metrics: + +- the environment `info` payload or evaluator-accessible state must expose `assets_lost`, `step`, successful deployment counts, wasted deployment counts, and total deployment attempt counts +- these counters are required for `mean_time_to_containment`, `mean_resource_efficiency`, and `wasted_deployment_rate` +- zero-denominator cases for deployment-based metrics must be handled explicitly as defined in Section 4 + +### 3.3 Unified evaluation runner + +Extend `src/models/evaluate_agents.py` so it can evaluate: + +- `ppo` +- `a2c` +- `dqn` +- `greedy` +- `random` + +The rollout loop can remain shared because all learned methods expose `model.predict(...)`. + +Add a matching benchmark-safe preset in code for evaluation so canonical runs do not rely on ad hoc CLI arguments. + +The evaluation preset should: + +- use the canonical split dataset paths by default +- default to `100` episodes for train/val +- use the benchmark metric schema defined in this file +- evaluate the chosen artifact explicitly, such as `best_model.zip` or `last_model.zip` + +Canonical evaluation outputs should distinguish: + +- `train` +- `val` +- optional `family_holdout` +- optional `temporal_holdout_diagnostic` + +--- + +## 4) Metric Definitions + +These definitions must be used consistently in code, plots, tables, and paper text. + +Core metrics: + +1. `mean_return` + - Mean episodic return over evaluation episodes. +2. `asset_survival_rate` + - Fraction of episodes with `assets_lost == 0`. +3. `containment_success_rate` + - Fraction of episodes that terminate by extinguishing the fire before truncation. +4. 
`mean_burned_area_fraction` + - Mean final burned-area fraction. + - Per episode: `(burned + burning + asset_burned cells) / 625`. +5. `std_across_seeds` + - Standard deviation of the seed-level metric means, not pooled episode variance. + +Secondary metrics: + +1. `mean_time_to_containment` + - Mean step of containment over only the episodes where containment occurs. + - Report `null` or `NA` when no episode in the slice is contained. +2. `mean_resource_efficiency` + - Mean of `successful_deployments / total_deployments`. + - A successful deployment changes at least one cell state through suppression/firebreak action. + - If `total_deployments == 0`, report `0.0`. +3. `wasted_deployment_rate` + - Mean of `wasted_deployments / total_deployments_attempted`. + - If `total_deployments_attempted == 0`, report `0.0`. +4. `heldout_performance_drop` + - Difference between validation or holdout performance and train performance for the same metric. + - Use `asset_survival_rate` as the primary reported generalization drop. +5. `mean_normalized_burn_ratio` + - `final_burned_area_with_policy / final_burned_area_under_non_intervention_baseline_on_same_seed_same_record`. + - Evaluation-only diagnostic. + +Definition of `non_intervention_baseline`: + +- This is not a trained policy. +- This is not a second task. +- It is a deterministic evaluation-only baseline that never deploys suppression. +- Under the frozen 6-action benchmark, implement it as a deterministic movement-only policy because `WAIT` is not part of the action space. +- Its purpose is to estimate the damage level when the agent does not intervene with helicopter or crew actions. + +Reporting rule: + +- Final paper tables should report seed-level `mean +- std` for the core metrics. +- Secondary metrics can appear in a second table, appendix, or supplemental JSON. +- `mean_normalized_burn_ratio` should remain a secondary diagnostic, not a headline benchmark metric. + +Metrics not to use as the only model-selection criterion: + +- pooled episode variance +- raw final burned cell count without normalization +- reward alone without asset survival context + +--- + +## 5) Training-Time Logging + +Checkpoint evaluation is required to make the training process auditable. + +At each checkpoint, record per split: + +- `mean_return` +- `asset_survival_rate` +- `containment_success_rate` +- `mean_burned_area_fraction` +- `mean_time_to_containment` +- `mean_resource_efficiency` +- `wasted_deployment_rate` + +Checkpoint-visible splits for canonical runs: + +- `train` +- `val` +- optional `family_holdout` + +Do not log or inspect temporal holdout metrics during checkpoint evaluation, pilot tuning, or model selection. + +The checkpoint metric file is the source of truth for best-checkpoint selection. Do not infer best-checkpoint status later from console logs. + +Checkpoint plots to generate from the logs: + +- validation `asset_survival_rate` vs env steps +- validation `mean_return` vs env steps +- train/val gap for `asset_survival_rate` +- train/family-holdout gap for `asset_survival_rate` after final training when family holdout is enabled + +Model selection rule: + +- Select the best checkpoint by highest validation `asset_survival_rate`. +- Use validation `mean_return` as tie-breaker. +- Do not select by holdout performance. +- Save both `best_model.zip` and `last_model.zip` for every seed. +- Use `best_model.zip` for the final paper-facing evaluation unless an ablation explicitly studies last-checkpoint behavior. 
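+
+The selection rule above is mechanical enough to pin down precisely. A minimal sketch against the frozen `checkpoint_metrics.json` schema from Section 3.1.1 (the helper name is illustrative, not the runner's actual API):
+
+```python
+# Sketch only: best-checkpoint selection over checkpoint_metrics.json entries.
+def select_best_checkpoint(checkpoints: list[dict]) -> dict:
+    def selection_key(entry: dict) -> tuple[float, float]:
+        val = entry["splits"]["val"]
+        # Primary: validation asset survival; tie-breaker: validation return.
+        # Holdout splits are deliberately never consulted here.
+        return (val["asset_survival_rate"], val["mean_return"])
+    return max(checkpoints, key=selection_key)
+```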
+ +--- + +## 6) Hyperparameter Strategy + +Do not choose final hyperparameters by naked-eye inspection alone. + +Use a small, fixed validation sweep. + +### 6.1 Tuning policy + +1. Use a small coarse sweep on the validation split. +2. Tune only a few high-impact knobs per algorithm. +3. Freeze the chosen configuration before the full 5-seed benchmark. +4. Do not change hyperparameters after viewing family-holdout or temporal-holdout results. + +### 6.2 Sweep budget + +Recommended tuning budget per algorithm: + +- short pilot runs only +- `1` seed per candidate config +- smaller training budget than final runs + +The goal is to eliminate obviously bad settings, not to over-optimize. + +### 6.3 Parameters to tune + +`PPO` + +- learning rate +- `n_steps` +- entropy coefficient + +`A2C` + +- learning rate +- `n_steps` +- entropy coefficient + +`DQN` + +- learning rate +- exploration fraction / final epsilon +- target update interval +- replay buffer size + +### 6.4 Selection rule + +Choose the config with the best validation `asset_survival_rate` at the end of the pilot budget. + +Tie-breakers: + +1. validation `mean_return` +2. validation `containment_success_rate` +3. lower instability across repeated quick checks if needed + +--- + +## 7) Verification Ladder Before Full Training + +Run the following checks in order. Do not launch full 5-seed benchmarks until each prior stage passes. + +### 7.1 Environment and data contract check + +Verify that: + +- train/val/holdout seeded files load in benchmark mode +- split mismatches fail fast +- explicit `scenario_families` are passed for canonical train/val/family-holdout runs +- reset/step terminate correctly +- observations remain length `636` +- actions remain `Discrete(6)` + +### 7.2 Algorithm structure smoke test + +For each of `PPO`, `A2C`, `DQN`: + +1. instantiate the model against the benchmark env +2. run a very short training job +3. save and reload the model +4. run a short deterministic evaluation rollout +5. verify that both `best_model.zip` and `last_model.zip` are emitted when checkpointing is enabled + +This verifies: + +- the chosen env wrapper is compatible with the algorithm +- the model serialization path works +- `evaluate_agents.py` can load and score the model + +### 7.3 Checkpoint logging test + +Run one short training job with checkpoint evaluation enabled and verify that: + +- checkpoints fire at the expected step interval +- per-split metrics are written to JSON +- the best-checkpoint selection rule behaves as expected +- no fallback heuristic contaminates learned-agent evaluation +- the benchmark preset produces the frozen protocol values without needing manual CLI reconstruction +- temporal holdout metrics do not appear in checkpoint artifacts for canonical runs + +### 7.4 Reward sanity pilot + +Run `PPO` for `20,000` steps on one seed and confirm: + +- asset-loss penalties appear in the return trace +- returns are not numerically unstable +- the agent learns something better than random on train at minimum + +If reward coefficients must change, do it once here and refreeze. + +### 7.5 Per-algorithm pilot benchmark + +For each of `DQN`, `A2C`, `PPO`, run a short pilot with the candidate hyperparameters and compare against: + +- `random` +- `greedy` + +Only proceed to full benchmark if: + +- the learned method beats `random` on validation `asset_survival_rate`, or +- the run is stable enough to justify a final budget run + +--- + +## 8) Benchmark Execution Order + +1. 
Implement unified train/eval support for `ppo`, `a2c`, `dqn`. +2. Add benchmark-safe presets for training and evaluation. +3. Fix evaluator metric definitions to match this document. +4. Add checkpoint evaluation, config serialization, and canonical artifact writing. +5. Run smoke tests for all methods. +6. Run one-seed pilot tuning sweeps. +7. Freeze one config per algorithm. +8. Run full 5-seed training for `DQN`, `A2C`, `PPO`. +9. Evaluate all learned methods plus `greedy`, `random`, and the non-intervention baseline on train/val and optional family holdout. +10. Run temporal holdout evaluation only as a final, separately labeled diagnostic until the holdout dataset is expanded. +11. Aggregate seed-level means and standard deviations. +12. Produce final plots and paper tables. + +--- + +## 9) Minimum Paper-Ready Outputs + +The benchmark is ready for reporting when the following artifacts exist. + +Per learned algorithm: + +- `best_model.zip` and `last_model.zip` for each seed +- checkpoint metric logs +- final split-wise evaluation JSON for the best checkpoint +- selected hyperparameter config + +Aggregate benchmark outputs: + +- table of train/val and optional family-holdout `mean +- std` across 5 seeds +- learning curve for at least `PPO` +- holdout comparison figure across methods only when the chosen holdout benchmark is credible for that claim +- explicit note of the model-selection rule and tuning budget +- explicit note that the non-intervention baseline is an evaluation-only secondary diagnostic +- explicit note that single-record temporal holdout results, if shown, are stress-test diagnostics rather than full benchmark evidence + +--- + +## 10) Non-Goals + +This plan does not include: + +- recurrent policies +- hidden-regime-shift training +- large-scale automated hyperparameter optimization +- modifying the reward during final benchmarking +- using holdout performance for tuning diff --git a/docs/rl_agent.md b/docs/rl_agent.md deleted file mode 100644 index 74091f1..0000000 --- a/docs/rl_agent.md +++ /dev/null @@ -1,137 +0,0 @@ -# FireGrid RL Tactical Agent — Architecture & Results - -## What Is It? - -The RL Tactical Agent is the second half of the FireGrid ML pipeline. - -- **Stage 1 (XGBoost)** answers: *"How far will this fire spread?"* -- **Stage 2 (PPO Agent)** answers: *"Where should we deploy helicopters and ground crews to stop it?"* - ---- - -## Architecture: Cellular Automata + PPO - -### The Fire Simulator (`models/fire_env.py`) - -The agent learns inside a 50×50 grid-based wildfire simulator built as a [`gymnasium`](https://gymnasium.farama.org/) environment — the standard interface for RL research. - -``` -Grid cell states: - 0 = Unburned fuel - 1 = Actively burning - 2 = Burned / scorched - 3 = Suppressed (firebreak / retardant) -``` - -**Key design:** The fire doesn't spread at a fixed rate. The spread probability per timestep is computed from the XGBoost output: -```python -spread_prob = min(0.85, xgboost_spread_1h_m / 250) -``` -High ISI + low humidity = fast spread → harder game for the agent. 
- -### Agent Actions (6 discrete) - -| Action | Effect | -|---|---| -| Move N / S / E / W | Navigate to best tactical position | -| Deploy Helicopter (4) | Suppresses 3×3 cell area around agent with retardant | -| Deploy Ground Crew (5) | Places 1-cell firebreak at agent position | - -### Reward Function - -The agent is rewarded or penalized after every timestep: - -| Event | Reward | -|---|---| -| Suppress a burning cell | **+3 per cell** | -| Proactive firebreak in fire's path | **+2** | -| Fire extinguished entirely | **+100 bonus** | -| Each new cell catching fire | **−0.5 per cell** | -| Wasting resources on burned land | **−3** | -| Wasted helicopter drop (no burning cells nearby) | **−2** | - -### Policy: Proximal Policy Optimization (PPO) - -PPO is the industry standard RL algorithm (same family as OpenAI's InstructGPT). It learns by playing thousands of fire simulations, updating its policy using the reward signal. Implemented via [`stable-baselines3`](https://stable-baselines3.readthedocs.io/). - -**Observation space:** Flat 2,502-dimensional vector: -- 2,500 = 50×50 grid values (normalized 0–1) -- 2 = agent position (x, y) - -**Policy network:** `MlpPolicy` (3-layer MLP, 64 hidden units per layer) - ---- - -## Training Results - -**Hyperparameters:** -- Total timesteps: 50,000 -- Parallel environments: 4 -- Learning rate: 3×10⁻⁴, γ = 0.995, ε_clip = 0.2 -- ~2 minutes on CPU - -**Reward curve (key milestones):** - -| Timestep | Avg Episode Reward | Interpretation | -|---|---|---| -| 2,048 | **-84.8** | Agent wanders randomly, fire burns entire grid | -| 10,240 | **-14.8** | Agent starts locating fire and deploying | -| 26,624 | **+6.28** | Crosses zero — agent is net suppressing fire | -| 51,200 | **+22.6** | Agent consistently initiates suppression early | - -The clean monotonic improvement from -84 to +22 indicates stable learning. - -**Saved weights:** `backend/src/models/tactical_ppo_agent.zip` - ---- - -## Inference & Greedy Safety Fallback - -At inference time, the agent runs 60 deterministic steps in the simulator. Every time it takes a `Deploy Helicopter` or `Deploy Ground Crew` action, the grid cell is converted back to real-world lat/lon and added to the output. - -**Why a fallback?** RL agents can sometimes fail to trigger a deployment action within the inference horizon, especially on novel initial states. To guarantee the frontend always has tactical lines to draw (critical for a live demo), the system includes a **Greedy Geometric Fallback**. If the PPO produces no deployments, it places 5 waypoints geometrically around the fire's 1-hour perimeter at cardinal directions. - ---- - -## API - -``` -GET /api/v1/choke_points/{fire_id} -``` - -**Pipeline:** -1. Fetches fire record from DynamoDB (lat/lon, area) -2. Calls XGBoost → gets `spread_1h_m` to set fire danger level for the simulator -3. Runs PPO agent inference (or greedy fallback) -4. 
Returns: -```json -{ - "fire_id": "BC-2026-001", - "spread_1h_m": 465, - "spread_3h_m": 1541, - "waypoints": [ - { - "latitude": 49.912, - "longitude": -119.501, - "asset_type": "helicopter", - "rationale": "PPO recommended helicopter deployment", - "source": "ppo_agent" - } - ] -} -``` - ---- - -## Run It - -```powershell -# Retrain the agent (optional — weights already saved) -uv run python -m src.models.train_rl_agent - -# Quick 30-second training test -uv run python -m src.models.train_rl_agent --timesteps 10000 - -# Test via API (backend must be running) -Invoke-RestMethod http://localhost:8000/api/v1/choke_points/BC-2026-001 -``` diff --git a/src/models/fire_env.py b/src/models/fire_env.py index 202aa28..6ae1537 100644 --- a/src/models/fire_env.py +++ b/src/models/fire_env.py @@ -639,6 +639,10 @@ def __init__( self.crew_cd: int = 0 self.assets_lost: int = 0 self.initial_asset_count: int = 0 + self.successful_heli_deployments: int = 0 + self.successful_crew_deployments: int = 0 + self.wasted_deployment_attempts: int = 0 + self.total_deployment_attempts: int = 0 self._ignition_seed_used: int | None = None self._layout_seed_used: int | None = None self._ignition_rng: np.random.Generator | None = None @@ -655,6 +659,10 @@ def reset(self, seed: int | None = None, options: dict | None = None): self.grid = np.zeros((self.grid_size, self.grid_size), dtype=np.int32) self.step_count = 0 self.assets_lost = 0 + self.successful_heli_deployments = 0 + self.successful_crew_deployments = 0 + self.wasted_deployment_attempts = 0 + self.total_deployment_attempts = 0 self._active_parameter_record = None self._active_record_id = self._scenario.record_id self._ignition_seed_used = None @@ -774,6 +782,12 @@ def step(self, action: int): "step": self.step_count, "assets_lost": self.assets_lost, "assets_remaining": self.initial_asset_count - self.assets_lost, + "successful_heli_deployments": self.successful_heli_deployments, + "successful_crew_deployments": self.successful_crew_deployments, + "successful_deployments": self.successful_heli_deployments + + self.successful_crew_deployments, + "wasted_deployment_attempts": self.wasted_deployment_attempts, + "total_deployment_attempts": self.total_deployment_attempts, "heli_left": self.heli_left, "crew_left": self.crew_left, "scenario": self._scenario, @@ -1014,6 +1028,7 @@ def _execute_action(self, action: int) -> tuple[float, bool, bool]: elif action == MOVE_W and c > 0: self.agent_pos[1] -= 1 elif action == DEPLOY_HELICOPTER: + self.total_deployment_attempts += 1 if self.heli_left > 0 and self.heli_cd == 0: # Suppress 3x3 area around agent suppressed = 0 @@ -1029,26 +1044,34 @@ def _execute_action(self, action: int) -> tuple[float, bool, bool]: self.heli_cd = self.heli_cooldown_duration heli_used = True if suppressed > 0: + self.successful_heli_deployments += 1 reward += suppressed * 3.0 else: + self.wasted_deployment_attempts += 1 reward -= 1.0 # wasted else: + self.wasted_deployment_attempts += 1 reward -= 1.0 # blocked by budget or cooldown elif action == DEPLOY_CREW: + self.total_deployment_attempts += 1 if self.crew_left > 0 and self.crew_cd == 0: cell = self.grid[r, c] if cell == BURNING: self.grid[r, c] = SUPPRESSED + self.successful_crew_deployments += 1 reward += 3.0 elif cell == UNBURNED: self.grid[r, c] = SUPPRESSED + self.successful_crew_deployments += 1 reward += 2.0 # firebreak else: + self.wasted_deployment_attempts += 1 reward -= 1.0 # wasted self.crew_left -= 1 self.crew_cd = self.crew_cooldown_duration crew_used = True else: + 
self.wasted_deployment_attempts += 1
             reward -= 1.0  # blocked by budget or cooldown
 
         return reward, heli_used, crew_used

From f90659a515e23993e1e76bb1fe1f4e0474116125 Mon Sep 17 00:00:00 2001
From: Thomson Lam
Date: Mon, 30 Mar 2026 15:00:44 -0400
Subject: [PATCH 2/2] feat: training code + benchmark, sh & ps1 wrapper scripts
 for reproducing

---
 .gitignore                                |   1 +
 README.md                                 | 124 ++++-
 scripts/run_benchmark_eval.ps1            |  88 ++++
 scripts/run_benchmark_eval.sh             |  93 ++++
 scripts/run_benchmark_train.ps1           |  98 ++++
 scripts/run_benchmark_train.sh            |  96 ++++
 src/models/benchmarking.py                | 428 +++++++++++++++
 src/models/evaluate_agents.py             | 390 ++++++--------
 src/models/rl_agent.py                    | 164 ------
 src/models/train_rl_agent.py              | 612 ++++++++++++++--------
 tests/models/test_benchmarking_metrics.py | 132 +++++
 11 files changed, 1622 insertions(+), 604 deletions(-)
 create mode 100644 scripts/run_benchmark_eval.ps1
 create mode 100755 scripts/run_benchmark_eval.sh
 create mode 100644 scripts/run_benchmark_train.ps1
 create mode 100755 scripts/run_benchmark_train.sh
 create mode 100644 src/models/benchmarking.py
 delete mode 100644 src/models/rl_agent.py
 create mode 100644 tests/models/test_benchmarking_metrics.py

diff --git a/.gitignore b/.gitignore
index 4e951d0..9ff611a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,6 +11,7 @@ wheels/
 .env
 
 data/
+outputs/
 
 #logs
 wandb
diff --git a/README.md b/README.md
index b80e33a..42c6371 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,68 @@
 
 Empirical RL benchmark for wildfire tactical suppression. Compares DQN, A2C, PPO, and heuristic baselines on a 25x25 grid environment with critical assets and finite suppression budgets.
 
+The physics-informed environment is paired with environment records built from the [Alberta Historical Wildfires Database](https://open.alberta.ca/opendata/wildfire-data).
+ +## Project Tree + +```text +firebot/ +├── README.md +├── pyproject.toml +├── uv.lock +├── ruff.toml +├── lefthook.yml +├── fp-historical-wildfire-data-dictionary-2006-2025.pdf # from the dataset download +├── data/ +│ └── static/ +│ ├── fp-historical-wildfire-data-2006-2025.csv # raw Alberta historical wildfire CSV +│ ├── snapshot_records.json # full normalized snapshot records from raw CSV +│ ├── snapshot_records_train.json # train-year snapshot subset +│ ├── snapshot_records_val.json # validation-year snapshot subset +│ ├── snapshot_records_holdout.json # holdout-year snapshot subset +│ ├── scenario_parameter_records.json # full unseeded environment parameter records +│ ├── scenario_parameter_records_train.json # train split unseeded records +│ ├── scenario_parameter_records_val.json # validation split unseeded records +│ ├── scenario_parameter_records_holdout.json # holdout split unseeded records +│ ├── scenario_parameter_records_seeded.json # full seeded records with ignition/layout seeds +│ ├── scenario_parameter_records_seeded_train.json # train runtime records +│ ├── scenario_parameter_records_seeded_val.json # validation runtime records +│ └── scenario_parameter_records_seeded_holdout.json # temporal holdout runtime records +├── docs/ +│ ├── data-pipeline.md +│ ├── envspec.md +│ └── planning/ +│ ├── env-checklist.md +│ ├── impl-plan.md +│ └── train-plan.md +├── src/ +│ ├── __init__.py +│ ├── ingestion/ +│ │ ├── __init__.py +│ │ ├── clean_historical.py # row cleaning and required-field checks +│ │ ├── cffdrs.py # CFFDRS station ingestion, not used +│ │ ├── weather.py # legacy Open-Meteo weather fetch helpers, not used +│ │ └── static_dataset.py # builds snapshot/scenario parameter records in data/static +│ └── models/ # environment, training, evaluation, and shared benchmark utilities +│ ├── __init__.py +│ ├── fire_env.py # WildfireEnv implementation and benchmark env construction helpers +│ ├── benchmarking.py # shared benchmark presets, rollout metrics, and aggregation functions +│ ├── train_rl_agent.py # unified PPO/A2C/DQN trainer with checkpoint and final evaluation artifacts +│ └── evaluate_agents.py # classdef for PPO/A2C/DQN plus greedy/random baselines +├── scripts/ +│ ├── run_benchmark_train.sh # bash script for smoke validation then full 5-seed benchmark training +│ ├── run_benchmark_train.ps1 # powershell equivalent +│ ├── run_benchmark_eval.sh # bash script for post-training benchmark evaluation by seed +│ └── run_benchmark_eval.ps1 # powershell equivalent +├── tests/ +│ ├── conftest.py +│ └── models/ # environment and benchmark metric contract tests +│ ├── test_fire_env_setup_contract.py # benchmark-mode env loading/split/schema contract tests +│ └── test_benchmarking_metrics.py # benchmark metric/preset/aggregation tests +├── outputs/ # generated training and evaluation artifacts (gitignored) +└── drd-archive/ # archived prototype code from the earlier DRD proposal +``` + ## Setup Requirements: [uv](https://docs.astral.sh/uv/getting-started/installation/) @@ -114,30 +176,68 @@ uv run python -m src.ingestion.static_dataset --fire-records path/to/fire_record ### Training -After building the dataset, you can train by running: +For controlled and reproducible benchmark training, use the script wrappers in `scripts/`. 
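+
+If you need a single ad hoc run outside the wrappers (for example, one ablation seed), the wrappers ultimately drive the unified trainer module, which can be invoked directly along the lines of `uv run python -m src.models.train_rl_agent --algo ppo --benchmark-preset canonical`; these flag names follow `docs/planning/train-plan.md`, so check the module's `--help` output for the authoritative interface.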
+
+Run from project root on macOS/Linux (bash):
 
 ```bash
-uv run python -m src.models.train_rl_agent --scenario-dataset data/static/scenario_parameter_records_seeded_train.json --val-dataset data/static/scenario_parameter_records_seeded_val.json --holdout-dataset data/static/scenario_parameter_records_seeded_holdout.json
+./scripts/run_benchmark_train.sh
 ```
 
-The seeded scenario parameter files are the canonical benchmark inputs for `FireEnv` and PPO training.
+Run from project root on Windows (PowerShell):
 
-The builder also writes year-based split files for the benchmark:
+```powershell
+./scripts/run_benchmark_train.ps1
+```
 
-- `train`: `2006-2022`
-- `val`: `2023`
-- `holdout`: `2024-2025`
+Script runs:
+
+- Stage 1 (smoke): runs short validation training for `ppo`, `a2c`, `dqn` on one seed
+- Stage 2 (smoke eval): loads smoke `best_model.zip` artifacts and runs evaluator sanity checks
+- Stage 3 (formal): runs full canonical training for all three algorithms across 5 seeds (`11,22,33,44,55`)
+- Uses artifact root `outputs/benchmark/` and keeps default trainer settings for env count, timesteps, and checkpoint cadence on formal runs
+
+Training script environment overrides (optional):
+
+- `ARTIFACT_ROOT` (default `outputs/benchmark`)
+- `SMOKE_TIMESTEPS` (default `20000`, one canonical checkpoint interval)
+- `SMOKE_SEED` (default `11`)
+- `SMOKE_EVAL_EPISODES` (default `5`)
+- `FINAL_SEEDS_CSV` (default `11,22,33,44,55`)
+- `ALGO_ORDER_CSV` (default `ppo,a2c,dqn`)
 
-Training command:
+After `run_benchmark_train` completes, run the benchmark evaluation wrappers.
+
+Run from project root on macOS/Linux (bash):
 
 ```bash
-uv run python -m src.models.train_rl_agent --scenario-dataset data/static/scenario_parameter_records_seeded_train.json --val-dataset data/static/scenario_parameter_records_seeded_val.json --holdout-dataset data/static/scenario_parameter_records_seeded_holdout.json
+./scripts/run_benchmark_eval.sh
 ```
 
-General split benchmark evaluation (PPO + baselines):
+Run from project root on Windows (PowerShell):
 
-```bash
-uv run python -m src.models.evaluate_agents --agents ppo,greedy,random --train-dataset data/static/scenario_parameter_records_seeded_train.json --val-dataset data/static/scenario_parameter_records_seeded_val.json --holdout-dataset data/static/scenario_parameter_records_seeded_holdout.json --episodes 20 --seeds 42,43,44
-```
+```powershell
+./scripts/run_benchmark_eval.ps1
+```
 
+The evaluation wrapper defaults below can be overridden via env vars or by editing the `.ps1` and `.sh` scripts.
+
+- `ARTIFACT_ROOT` (default `outputs/benchmark`)
+- `RUN_LABEL` (default `final`)
+- `EVAL_SEEDS_CSV` (default `11,22,33,44,55`)
+- `EVAL_EPISODES` (default `100`)
+- `AGENTS` (default `ppo,a2c,dqn,greedy,random`)
+- `OUTPUT_DIR` (default `outputs/benchmark/<RUN_LABEL>/eval`)
+- `INCLUDE_FAMILY_HOLDOUT` (`0` or `1`, default `0`)
+- `INCLUDE_TEMPORAL_HOLDOUT` (`0` or `1`, default `0`)
+- `NO_NORMALIZED_BURN` (`0` or `1`, default `0`)
+
+The seeded scenario parameter files are the benchmark inputs for `FireEnv` training and script-driven evaluation.
+
+The builder also writes year-based split files for the benchmark:
+
+- `train`: `2006-2022`
+- `val`: `2023`
+- `holdout`: `2024-2025`
+
 The dataset builder prints cleaning/drop summaries to stdout and uses progress bars when `tqdm` is available. 
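+
+Once per-seed evaluation files exist under `outputs/benchmark/final/eval/seed_<seed>.json`, final tables require seed-level aggregation: summarize each seed first, then report mean and standard deviation across seed means, as frozen in `docs/planning/train-plan.md`. A minimal aggregation sketch, assuming each per-seed JSON exposes an agent-to-split-to-metric mapping (the actual file layout may differ):
+
+```python
+# Sketch: aggregate validation asset_survival_rate for PPO across seed files.
+import json
+import statistics
+from pathlib import Path
+
+values = []
+for path in sorted(Path("outputs/benchmark/final/eval").glob("seed_*.json")):
+    report = json.loads(path.read_text())
+    values.append(report["ppo"]["val"]["asset_survival_rate"])  # assumed layout
+
+# Sample standard deviation across the 5 seed-level means.
+print(f"mean={statistics.mean(values):.3f} std={statistics.stdev(values):.3f}")
+```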
diff --git a/scripts/run_benchmark_eval.ps1 b/scripts/run_benchmark_eval.ps1 new file mode 100644 index 0000000..00bfa39 --- /dev/null +++ b/scripts/run_benchmark_eval.ps1 @@ -0,0 +1,88 @@ +Set-StrictMode -Version Latest +$ErrorActionPreference = "Stop" + +$RootDir = Split-Path -Parent $PSScriptRoot +Set-Location $RootDir + +if (-not (Get-Command uv -ErrorAction SilentlyContinue)) { + Write-Error "'uv' is not installed or not on PATH. Install: https://docs.astral.sh/uv/getting-started/installation/" +} + +$ArtifactRoot = if ($env:ARTIFACT_ROOT) { $env:ARTIFACT_ROOT } else { "outputs/benchmark" } +$RunLabel = if ($env:RUN_LABEL) { $env:RUN_LABEL } else { "final" } +$EvalSeedsCsv = if ($env:EVAL_SEEDS_CSV) { $env:EVAL_SEEDS_CSV } else { "11,22,33,44,55" } +$EvalEpisodes = if ($env:EVAL_EPISODES) { [int]$env:EVAL_EPISODES } else { 100 } +$Agents = if ($env:AGENTS) { $env:AGENTS } else { "ppo,a2c,dqn,greedy,random" } + +$TrainDataset = if ($env:TRAIN_DATASET) { $env:TRAIN_DATASET } else { "data/static/scenario_parameter_records_seeded_train.json" } +$ValDataset = if ($env:VAL_DATASET) { $env:VAL_DATASET } else { "data/static/scenario_parameter_records_seeded_val.json" } +$HoldoutDataset = if ($env:HOLDOUT_DATASET) { $env:HOLDOUT_DATASET } else { "data/static/scenario_parameter_records_seeded_holdout.json" } + +$DefaultOutputDir = Join-Path -Path $ArtifactRoot -ChildPath "$RunLabel/eval" +$OutputDir = if ($env:OUTPUT_DIR) { $env:OUTPUT_DIR } else { $DefaultOutputDir } +$IncludeFamilyHoldout = if ($env:INCLUDE_FAMILY_HOLDOUT) { [int]$env:INCLUDE_FAMILY_HOLDOUT } else { 0 } +$IncludeTemporalHoldout = if ($env:INCLUDE_TEMPORAL_HOLDOUT) { [int]$env:INCLUDE_TEMPORAL_HOLDOUT } else { 0 } +$NoNormalizedBurn = if ($env:NO_NORMALIZED_BURN) { [int]$env:NO_NORMALIZED_BURN } else { 0 } + +foreach ($dataset in @($TrainDataset, $ValDataset, $HoldoutDataset)) { + if (-not (Test-Path $dataset)) { + Write-Error "Missing dataset '$dataset'." + } +} + +$EvalSeeds = $EvalSeedsCsv -split "," | ForEach-Object { $_.Trim() } | Where-Object { $_ -ne "" } +New-Item -ItemType Directory -Force -Path $OutputDir | Out-Null + +Write-Host "== Benchmark evaluation configuration ==" +Write-Host "artifact_root : $ArtifactRoot" +Write-Host "run_label : $RunLabel" +Write-Host "eval_seeds : $EvalSeedsCsv" +Write-Host "eval_episodes_per_seed : $EvalEpisodes" +Write-Host "agents : $Agents" +Write-Host "output_dir : $OutputDir" +Write-Host "" + +foreach ($seed in $EvalSeeds) { + $PpoModel = Join-Path -Path $ArtifactRoot -ChildPath "$RunLabel/ppo/seed_$seed/best_model.zip" + $A2cModel = Join-Path -Path $ArtifactRoot -ChildPath "$RunLabel/a2c/seed_$seed/best_model.zip" + $DqnModel = Join-Path -Path $ArtifactRoot -ChildPath "$RunLabel/dqn/seed_$seed/best_model.zip" + + foreach ($modelPath in @($PpoModel, $A2cModel, $DqnModel)) { + if (-not (Test-Path $modelPath)) { + Write-Error "Missing model '$modelPath'. 
Run training first: ./scripts/run_benchmark_train.ps1" + } + } + + $OutputJson = Join-Path -Path $OutputDir -ChildPath "seed_$seed.json" + Write-Host "[EVAL] seed=$seed -> $OutputJson" + + $CmdArgs = @( + "run", "python", "-m", "src.models.evaluate_agents", + "--agents", $Agents, + "--ppo-model", $PpoModel, + "--a2c-model", $A2cModel, + "--dqn-model", $DqnModel, + "--train-dataset", $TrainDataset, + "--val-dataset", $ValDataset, + "--holdout-dataset", $HoldoutDataset, + "--seeds", $seed, + "--episodes", "$EvalEpisodes", + "--run-label", $RunLabel, + "--output", $OutputJson + ) + + if ($IncludeFamilyHoldout -eq 1) { + $CmdArgs += "--include-family-holdout" + } + if ($IncludeTemporalHoldout -eq 1) { + $CmdArgs += "--include-temporal-holdout" + } + if ($NoNormalizedBurn -eq 1) { + $CmdArgs += "--no-normalized-burn" + } + + & uv @CmdArgs +} + +Write-Host "" +Write-Host "Evaluation complete. Results are in '$OutputDir'." diff --git a/scripts/run_benchmark_eval.sh b/scripts/run_benchmark_eval.sh new file mode 100755 index 0000000..12e0e79 --- /dev/null +++ b/scripts/run_benchmark_eval.sh @@ -0,0 +1,93 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +if ! command -v uv >/dev/null 2>&1; then + echo "ERROR: 'uv' is not installed or not on PATH." + echo "Install: https://docs.astral.sh/uv/getting-started/installation/" + exit 1 +fi + +ARTIFACT_ROOT="${ARTIFACT_ROOT:-outputs/benchmark}" +RUN_LABEL="${RUN_LABEL:-final}" +EVAL_SEEDS_CSV="${EVAL_SEEDS_CSV:-11,22,33,44,55}" +EVAL_EPISODES="${EVAL_EPISODES:-100}" +AGENTS="${AGENTS:-ppo,a2c,dqn,greedy,random}" + +TRAIN_DATASET="${TRAIN_DATASET:-data/static/scenario_parameter_records_seeded_train.json}" +VAL_DATASET="${VAL_DATASET:-data/static/scenario_parameter_records_seeded_val.json}" +HOLDOUT_DATASET="${HOLDOUT_DATASET:-data/static/scenario_parameter_records_seeded_holdout.json}" + +OUTPUT_DIR="${OUTPUT_DIR:-$ARTIFACT_ROOT/$RUN_LABEL/eval}" +INCLUDE_FAMILY_HOLDOUT="${INCLUDE_FAMILY_HOLDOUT:-0}" +INCLUDE_TEMPORAL_HOLDOUT="${INCLUDE_TEMPORAL_HOLDOUT:-0}" +NO_NORMALIZED_BURN="${NO_NORMALIZED_BURN:-0}" + +for dataset in "$TRAIN_DATASET" "$VAL_DATASET" "$HOLDOUT_DATASET"; do + if [[ ! -f "$dataset" ]]; then + echo "ERROR: Missing dataset '$dataset'." + exit 1 + fi +done + +IFS=',' read -r -a EVAL_SEEDS <<< "$EVAL_SEEDS_CSV" +mkdir -p "$OUTPUT_DIR" + +echo "== Benchmark evaluation configuration ==" +echo "artifact_root : $ARTIFACT_ROOT" +echo "run_label : $RUN_LABEL" +echo "eval_seeds : $EVAL_SEEDS_CSV" +echo "eval_episodes_per_seed : $EVAL_EPISODES" +echo "agents : $AGENTS" +echo "output_dir : $OUTPUT_DIR" +echo + +for seed in "${EVAL_SEEDS[@]}"; do + seed_trimmed="${seed// /}" + ppo_model="$ARTIFACT_ROOT/$RUN_LABEL/ppo/seed_${seed_trimmed}/best_model.zip" + a2c_model="$ARTIFACT_ROOT/$RUN_LABEL/a2c/seed_${seed_trimmed}/best_model.zip" + dqn_model="$ARTIFACT_ROOT/$RUN_LABEL/dqn/seed_${seed_trimmed}/best_model.zip" + + for model_path in "$ppo_model" "$a2c_model" "$dqn_model"; do + if [[ ! -f "$model_path" ]]; then + echo "ERROR: Missing model '$model_path'." 
+ echo "Run training first: ./scripts/run_benchmark_train.sh" + exit 1 + fi + done + + output_json="$OUTPUT_DIR/seed_${seed_trimmed}.json" + echo "[EVAL] seed=$seed_trimmed -> $output_json" + + cmd=( + uv run python -m src.models.evaluate_agents + --agents "$AGENTS" + --ppo-model "$ppo_model" + --a2c-model "$a2c_model" + --dqn-model "$dqn_model" + --train-dataset "$TRAIN_DATASET" + --val-dataset "$VAL_DATASET" + --holdout-dataset "$HOLDOUT_DATASET" + --seeds "$seed_trimmed" + --episodes "$EVAL_EPISODES" + --run-label "$RUN_LABEL" + --output "$output_json" + ) + + if [[ "$INCLUDE_FAMILY_HOLDOUT" == "1" ]]; then + cmd+=(--include-family-holdout) + fi + if [[ "$INCLUDE_TEMPORAL_HOLDOUT" == "1" ]]; then + cmd+=(--include-temporal-holdout) + fi + if [[ "$NO_NORMALIZED_BURN" == "1" ]]; then + cmd+=(--no-normalized-burn) + fi + + "${cmd[@]}" +done + +echo +echo "Evaluation complete. Results are in '$OUTPUT_DIR'." diff --git a/scripts/run_benchmark_train.ps1 b/scripts/run_benchmark_train.ps1 new file mode 100644 index 0000000..8f57a19 --- /dev/null +++ b/scripts/run_benchmark_train.ps1 @@ -0,0 +1,98 @@ +Set-StrictMode -Version Latest +$ErrorActionPreference = "Stop" + +$RootDir = Split-Path -Parent $PSScriptRoot +Set-Location $RootDir + +if (-not (Get-Command uv -ErrorAction SilentlyContinue)) { + Write-Error "'uv' is not installed or not on PATH. Install: https://docs.astral.sh/uv/getting-started/installation/" +} + +$ArtifactRoot = if ($env:ARTIFACT_ROOT) { $env:ARTIFACT_ROOT } else { "outputs/benchmark" } +# Default smoke length is one canonical checkpoint interval. +$SmokeTimesteps = if ($env:SMOKE_TIMESTEPS) { [int]$env:SMOKE_TIMESTEPS } else { 20000 } +$SmokeSeed = if ($env:SMOKE_SEED) { [int]$env:SMOKE_SEED } else { 11 } +$SmokeEvalEpisodes = if ($env:SMOKE_EVAL_EPISODES) { [int]$env:SMOKE_EVAL_EPISODES } else { 5 } +$FinalSeedsCsv = if ($env:FINAL_SEEDS_CSV) { $env:FINAL_SEEDS_CSV } else { "11,22,33,44,55" } +$AlgoOrderCsv = if ($env:ALGO_ORDER_CSV) { $env:ALGO_ORDER_CSV } else { "ppo,a2c,dqn" } + +$TrainDataset = "data/static/scenario_parameter_records_seeded_train.json" +$ValDataset = "data/static/scenario_parameter_records_seeded_val.json" +$HoldoutDataset = "data/static/scenario_parameter_records_seeded_holdout.json" + +foreach ($dataset in @($TrainDataset, $ValDataset, $HoldoutDataset)) { + if (-not (Test-Path $dataset)) { + Write-Error "Missing dataset '$dataset'. Run dataset build first: uv run python -m src.ingestion.static_dataset --target-count 50000 --raw-alberta-csv data/static/fp-historical-wildfire-data-2006-2025.csv" + } +} + +$AlgoOrder = $AlgoOrderCsv -split "," | ForEach-Object { $_.Trim() } | Where-Object { $_ -ne "" } +$FinalSeeds = $FinalSeedsCsv -split "," | ForEach-Object { [int]$_.Trim() } + +Write-Host "== Benchmark training configuration ==" +Write-Host "artifact_root : $ArtifactRoot" +Write-Host "algo_order : $AlgoOrderCsv" +Write-Host "smoke_seed : $SmokeSeed" +Write-Host "smoke_timesteps : $SmokeTimesteps" +Write-Host "smoke_eval_episodes: $SmokeEvalEpisodes" +Write-Host "final_seeds : $FinalSeedsCsv" +Write-Host "" +Write-Host "Note: full runs keep default trainer timesteps/envs/checkpoint cadence." 
+Write-Host "" + +function Invoke-SmokeTrain { + param( + [Parameter(Mandatory = $true)][string]$Algo + ) + + Write-Host "[SMOKE] Training $Algo (seed=$SmokeSeed, timesteps=$SmokeTimesteps)" + uv run python -m src.models.train_rl_agent ` + --algo $Algo ` + --run-label smoke ` + --seed $SmokeSeed ` + --timesteps $SmokeTimesteps ` + --artifact-root $ArtifactRoot +} + +function Invoke-FinalTrain { + param( + [Parameter(Mandatory = $true)][string]$Algo, + [Parameter(Mandatory = $true)][int]$Seed + ) + + Write-Host "[FINAL] Training $Algo (seed=$Seed, default timesteps/envs)" + uv run python -m src.models.train_rl_agent ` + --algo $Algo ` + --run-label final ` + --seed $Seed ` + --artifact-root $ArtifactRoot +} + +Write-Host "== Stage 1/3: Algorithm smoke training ==" +foreach ($algo in $AlgoOrder) { + Invoke-SmokeTrain -Algo $algo +} + +Write-Host "" +# TODO: Check seed! +Write-Host "== Stage 2/3: Smoke evaluation (load + score sanity check) ==" +uv run python -m src.models.evaluate_agents ` + --agents ppo,a2c,dqn,greedy,random ` + --ppo-model "$ArtifactRoot/smoke/ppo/seed_$SmokeSeed/best_model.zip" ` + --a2c-model "$ArtifactRoot/smoke/a2c/seed_$SmokeSeed/best_model.zip" ` + --dqn-model "$ArtifactRoot/smoke/dqn/seed_$SmokeSeed/best_model.zip" ` + --seeds "$SmokeSeed" ` + --episodes $SmokeEvalEpisodes ` + --run-label smoke ` + --output "$ArtifactRoot/smoke/eval_smoke.json" + +Write-Host "" +Write-Host "== Stage 3/3: Full 5-seed benchmark training ==" +foreach ($algo in $AlgoOrder) { + foreach ($seed in $FinalSeeds) { + Invoke-FinalTrain -Algo $algo -Seed $seed + } +} + +Write-Host "" +Write-Host "All runs finished. Artifacts are under '$ArtifactRoot'." diff --git a/scripts/run_benchmark_train.sh b/scripts/run_benchmark_train.sh new file mode 100755 index 0000000..6d4d092 --- /dev/null +++ b/scripts/run_benchmark_train.sh @@ -0,0 +1,96 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +if ! command -v uv >/dev/null 2>&1; then + echo "ERROR: 'uv' is not installed or not on PATH." + echo "Install: https://docs.astral.sh/uv/getting-started/installation/" + exit 1 +fi + +ARTIFACT_ROOT="${ARTIFACT_ROOT:-outputs/benchmark}" +# Default smoke length is one canonical checkpoint interval. +SMOKE_TIMESTEPS="${SMOKE_TIMESTEPS:-20000}" +SMOKE_SEED="${SMOKE_SEED:-11}" +SMOKE_EVAL_EPISODES="${SMOKE_EVAL_EPISODES:-5}" +FINAL_SEEDS_CSV="${FINAL_SEEDS_CSV:-11,22,33,44,55}" +ALGO_ORDER_CSV="${ALGO_ORDER_CSV:-ppo,a2c,dqn}" + +TRAIN_DATASET="data/static/scenario_parameter_records_seeded_train.json" +VAL_DATASET="data/static/scenario_parameter_records_seeded_val.json" +HOLDOUT_DATASET="data/static/scenario_parameter_records_seeded_holdout.json" + +for dataset in "$TRAIN_DATASET" "$VAL_DATASET" "$HOLDOUT_DATASET"; do + if [[ ! -f "$dataset" ]]; then + echo "ERROR: Missing dataset '$dataset'." 
+ echo "Run dataset build first: uv run python -m src.ingestion.static_dataset --target-count 50000 --raw-alberta-csv data/static/fp-historical-wildfire-data-2006-2025.csv" + exit 1 + fi +done + +IFS=',' read -r -a ALGO_ORDER <<< "$ALGO_ORDER_CSV" +IFS=',' read -r -a FINAL_SEEDS <<< "$FINAL_SEEDS_CSV" + +echo "== Benchmark training configuration ==" +echo "artifact_root : $ARTIFACT_ROOT" +echo "algo_order : $ALGO_ORDER_CSV" +echo "smoke_seed : $SMOKE_SEED" +echo "smoke_timesteps : $SMOKE_TIMESTEPS" +echo "smoke_eval_episodes: $SMOKE_EVAL_EPISODES" +echo "final_seeds : $FINAL_SEEDS_CSV" +echo +echo "Note: full runs keep default trainer timesteps/envs/checkpoint cadence." +echo + +train_smoke() { + local algo="$1" + echo "[SMOKE] Training $algo (seed=$SMOKE_SEED, timesteps=$SMOKE_TIMESTEPS)" + uv run python -m src.models.train_rl_agent \ + --algo "$algo" \ + --run-label smoke \ + --seed "$SMOKE_SEED" \ + --timesteps "$SMOKE_TIMESTEPS" \ + --artifact-root "$ARTIFACT_ROOT" +} + +train_final() { + local algo="$1" + local seed="$2" + echo "[FINAL] Training $algo (seed=$seed, default timesteps/envs)" + uv run python -m src.models.train_rl_agent \ + --algo "$algo" \ + --run-label final \ + --seed "$seed" \ + --artifact-root "$ARTIFACT_ROOT" +} + +echo "== Stage 1/3: Algorithm smoke training ==" +for algo in "${ALGO_ORDER[@]}"; do + train_smoke "$algo" +done + +echo +echo "== Stage 2/3: Smoke evaluation (load + score sanity check) ==" +uv run python -m src.models.evaluate_agents \ + --agents ppo,a2c,dqn,greedy,random \ + --ppo-model "$ARTIFACT_ROOT/smoke/ppo/seed_${SMOKE_SEED}/best_model.zip" \ + --a2c-model "$ARTIFACT_ROOT/smoke/a2c/seed_${SMOKE_SEED}/best_model.zip" \ + --dqn-model "$ARTIFACT_ROOT/smoke/dqn/seed_${SMOKE_SEED}/best_model.zip" \ + --seeds "$SMOKE_SEED" \ + --episodes "$SMOKE_EVAL_EPISODES" \ + --run-label smoke \ + --output "$ARTIFACT_ROOT/smoke/eval_smoke.json" + +echo +# TODO: Check seeds for correctness! +echo "== Stage 3/3: Full 5-seed benchmark training ==" +for algo in "${ALGO_ORDER[@]}"; do + for seed in "${FINAL_SEEDS[@]}"; do + train_final "$algo" "$seed" + done +done + +echo +echo "All runs finished. Artifacts are under '$ARTIFACT_ROOT'." 
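The wrappers above are thin loops over the module CLIs, so one-off runs can call the trainer and evaluator directly. A minimal sketch for a single `pilot`-label A2C run on one seed, using only flags defined by `train_rl_agent.py` and `evaluate_agents.py` below (the seed and episode counts are illustrative):

```bash
# One standalone pilot training run; trainer defaults fill in timesteps, env count,
# and checkpoint cadence, and the best checkpoint is picked on val asset survival.
# --artifact-root matches the wrappers; the trainer's own default is artifacts/benchmark.
uv run python -m src.models.train_rl_agent \
  --algo a2c \
  --run-label pilot \
  --seed 22 \
  --artifact-root outputs/benchmark

# Score that checkpoint against the greedy and random baselines on the same seed.
uv run python -m src.models.evaluate_agents \
  --agents a2c,greedy,random \
  --a2c-model outputs/benchmark/pilot/a2c/seed_22/best_model.zip \
  --seeds 22 \
  --episodes 20 \
  --run-label pilot \
  --output outputs/benchmark/pilot/eval/seed_22.json
```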
diff --git a/src/models/benchmarking.py b/src/models/benchmarking.py new file mode 100644 index 0000000..0e29997 --- /dev/null +++ b/src/models/benchmarking.py @@ -0,0 +1,428 @@ +"""Shared benchmark config and evaluation helpers for wildfire RL runs.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np + +from src.models.fire_env import ( + ASSET_BURNED, + BURNED, + BURNING, + DEPLOY_CREW, + DEPLOY_HELICOPTER, + HELD_OUT_FAMILIES, + MOVE_E, + MOVE_N, + MOVE_S, + MOVE_W, + TRAIN_FAMILIES, + WildfireEnv, + create_benchmark_env, + load_scenario_parameter_records, +) + +CANONICAL_TRAIN_DATASET = Path("data/static/scenario_parameter_records_seeded_train.json") +CANONICAL_VAL_DATASET = Path("data/static/scenario_parameter_records_seeded_val.json") +CANONICAL_HOLDOUT_DATASET = Path("data/static/scenario_parameter_records_seeded_holdout.json") +CANONICAL_TRAINING_SEEDS = [11, 22, 33, 44, 55] +CANONICAL_CHECKPOINT_INTERVAL_STEPS = 20_000 +CANONICAL_CHECKPOINT_EVAL_EPISODES = 20 +CANONICAL_FINAL_EVAL_EPISODES = 100 +CANONICAL_TIMESTEPS_BY_ALGO = { + "ppo": 200_000, + "a2c": 200_000, + "dqn": 200_000, +} +RUN_LABELS = ("smoke", "pilot", "final") + +ROLLOUT_AGENT_TYPES = ("ppo", "a2c", "dqn", "greedy", "random", "non_intervention") + + +@dataclass(frozen=True) +class SplitConfig: + """Configuration for evaluating one split.""" + + name: str + expected_split: str + dataset_path: Path + scenario_families: list[tuple[str, str, str]] + + +def canonical_train_preset(algo: str) -> dict[str, Any]: + """Return canonical training defaults for one algorithm.""" + algo_name = algo.lower() + if algo_name not in CANONICAL_TIMESTEPS_BY_ALGO: + msg = f"Unknown algorithm '{algo_name}' for canonical preset" + raise ValueError(msg) + + return { + "algo": algo_name, + "total_timesteps": CANONICAL_TIMESTEPS_BY_ALGO[algo_name], + "train_dataset": CANONICAL_TRAIN_DATASET, + "val_dataset": CANONICAL_VAL_DATASET, + "holdout_dataset": CANONICAL_HOLDOUT_DATASET, + "train_families": list(TRAIN_FAMILIES), + "val_families": list(TRAIN_FAMILIES), + "family_holdout_families": list(HELD_OUT_FAMILIES), + "checkpoint_interval_steps": CANONICAL_CHECKPOINT_INTERVAL_STEPS, + "checkpoint_eval_episodes": CANONICAL_CHECKPOINT_EVAL_EPISODES, + "final_eval_episodes": CANONICAL_FINAL_EVAL_EPISODES, + "checkpoint_visible_splits": ["train", "val"], + "benchmark_mode": True, + "holdout_checkpoint_visibility": False, + } + + +def canonical_eval_preset() -> dict[str, Any]: + """Return canonical evaluation defaults.""" + return { + "train_dataset": CANONICAL_TRAIN_DATASET, + "val_dataset": CANONICAL_VAL_DATASET, + "holdout_dataset": CANONICAL_HOLDOUT_DATASET, + "episodes": CANONICAL_FINAL_EVAL_EPISODES, + "seeds": list(CANONICAL_TRAINING_SEEDS), + "train_families": list(TRAIN_FAMILIES), + "val_families": list(TRAIN_FAMILIES), + "family_holdout_families": list(HELD_OUT_FAMILIES), + "temporal_holdout_families": list(TRAIN_FAMILIES), + } + + +def load_records(path: Path, *, expected_split: str) -> list[dict]: + """Load validated benchmark records for one split.""" + return load_scenario_parameter_records(path, benchmark_mode=True, expected_split=expected_split) + + +def build_default_splits( + *, + train_dataset: Path, + val_dataset: Path, + holdout_dataset: Path, + include_family_holdout: bool, + include_temporal_holdout: bool, + train_families: list[tuple[str, str, str]], + val_families: list[tuple[str, str, str]], + family_holdout_families: 
list[tuple[str, str, str]], + temporal_holdout_families: list[tuple[str, str, str]], +) -> list[SplitConfig]: + """Build canonical split evaluation configs.""" + splits = [ + SplitConfig( + name="train", + expected_split="train", + dataset_path=train_dataset, + scenario_families=train_families, + ), + SplitConfig( + name="val", + expected_split="val", + dataset_path=val_dataset, + scenario_families=val_families, + ), + ] + if include_family_holdout: + splits.append( + SplitConfig( + name="family_holdout", + expected_split="val", + dataset_path=val_dataset, + scenario_families=family_holdout_families, + ) + ) + if include_temporal_holdout: + splits.append( + SplitConfig( + name="temporal_holdout_diagnostic", + expected_split="holdout", + dataset_path=holdout_dataset, + scenario_families=temporal_holdout_families, + ) + ) + return splits + + +def _nearest_burning_cell(env: WildfireEnv) -> tuple[int, int] | None: + burning_positions = np.argwhere(env.grid == BURNING) + if burning_positions.size == 0: + return None + ar, ac = env.agent_pos + dists = np.abs(burning_positions[:, 0] - ar) + np.abs(burning_positions[:, 1] - ac) + index = int(np.argmin(dists)) + return int(burning_positions[index, 0]), int(burning_positions[index, 1]) + + +def greedy_action(env: WildfireEnv) -> int: + """Simple greedy baseline policy.""" + ar, ac = env.agent_pos + + if env.heli_left > 0 and env.heli_cd == 0: + for dr in range(-1, 2): + for dc in range(-1, 2): + rr, cc = ar + dr, ac + dc + if ( + 0 <= rr < env.grid_size + and 0 <= cc < env.grid_size + and env.grid[rr, cc] == BURNING + ): + return DEPLOY_HELICOPTER + + if env.crew_left > 0 and env.crew_cd == 0 and env.grid[ar, ac] == BURNING: + return DEPLOY_CREW + + target = _nearest_burning_cell(env) + if target is None: + return MOVE_N + + tr, tc = target + if tr < ar: + return MOVE_N + if tr > ar: + return MOVE_S + if tc > ac: + return MOVE_E + if tc < ac: + return MOVE_W + + if env.crew_left > 0 and env.crew_cd == 0: + return DEPLOY_CREW + return MOVE_N + + +def _select_action(agent_name: str, env: WildfireEnv, obs: np.ndarray, model) -> int: + if agent_name == "random": + return int(env.action_space.sample()) + if agent_name == "greedy": + return greedy_action(env) + if agent_name == "non_intervention": + return MOVE_N + action, _ = model.predict(obs, deterministic=True) + return int(action) + + +def rollout_episode( + env: WildfireEnv, *, agent_name: str, model, seed: int +) -> dict[str, float | int | None]: + """Roll one deterministic benchmark episode and return scalar metrics.""" + if agent_name not in ROLLOUT_AGENT_TYPES: + msg = f"Unsupported agent type '{agent_name}'" + raise ValueError(msg) + if agent_name in {"ppo", "a2c", "dqn"} and model is None: + msg = f"Agent '{agent_name}' requires a loaded model" + raise ValueError(msg) + + obs, _ = env.reset(seed=seed) + episode_return = 0.0 + terminated = False + truncated = False + info: dict[str, Any] = {} + + for _ in range(env.max_steps): + action = _select_action(agent_name, env, obs, model) + obs, reward, terminated, truncated, info = env.step(action) + episode_return += float(reward) + if terminated or truncated: + break + + final_burned_cells = int( + np.sum((env.grid == BURNED) | (env.grid == BURNING) | (env.grid == ASSET_BURNED)) + ) + total_cells = float(env.grid_size * env.grid_size) + + total_deployments = int(info.get("total_deployment_attempts", 0)) + successful_deployments = int(info.get("successful_deployments", 0)) + wasted_deployments = int(info.get("wasted_deployment_attempts", 0)) + 
resource_efficiency = ( + float(successful_deployments / total_deployments) if total_deployments > 0 else 0.0 + ) + wasted_deployment_rate = ( + float(wasted_deployments / total_deployments) if total_deployments > 0 else 0.0 + ) + containment_success = bool(terminated and not truncated) + + return { + "return": float(episode_return), + "assets_lost": int(info.get("assets_lost", env.assets_lost)), + "containment_success": 1.0 if containment_success else 0.0, + "burned_area_fraction": float(final_burned_cells / total_cells), + "time_to_containment": int(info.get("step", env.step_count)) + if containment_success + else None, + "resource_efficiency": resource_efficiency, + "wasted_deployment_rate": wasted_deployment_rate, + "final_burned_area_cells": final_burned_cells, + } + + +def summarize_episodes(episode_metrics: list[dict[str, float | int | None]]) -> dict[str, Any]: + """Summarize episode-level benchmark metrics for one split.""" + if not episode_metrics: + raise ValueError("No episode metrics to summarize") + + returns = np.array([float(m["return"]) for m in episode_metrics], dtype=float) + assets_lost = np.array([float(m["assets_lost"]) for m in episode_metrics], dtype=float) + containment = np.array([float(m["containment_success"]) for m in episode_metrics], dtype=float) + burned_fractions = np.array( + [float(m["burned_area_fraction"]) for m in episode_metrics], dtype=float + ) + resource_eff = np.array([float(m["resource_efficiency"]) for m in episode_metrics], dtype=float) + wasted_rates = np.array( + [float(m["wasted_deployment_rate"]) for m in episode_metrics], dtype=float + ) + + containment_steps = [ + float(m["time_to_containment"]) + for m in episode_metrics + if m.get("time_to_containment") is not None + ] + mean_time_to_containment = ( + float(np.mean(np.array(containment_steps, dtype=float))) if containment_steps else None + ) + + summary: dict[str, Any] = { + "episodes": len(episode_metrics), + "mean_return": float(returns.mean()), + "asset_survival_rate": float(np.mean(assets_lost == 0.0)), + "containment_success_rate": float(containment.mean()), + "mean_burned_area_fraction": float(burned_fractions.mean()), + "mean_time_to_containment": mean_time_to_containment, + "mean_resource_efficiency": float(resource_eff.mean()), + "wasted_deployment_rate": float(wasted_rates.mean()), + } + + if "normalized_burn_ratio" in episode_metrics[0]: + normalized = np.array( + [float(m["normalized_burn_ratio"]) for m in episode_metrics], + dtype=float, + ) + summary["mean_normalized_burn_ratio"] = float(normalized.mean()) + + return summary + + +def _mean_and_std(values: list[float]) -> tuple[float | None, float | None]: + if not values: + return None, None + arr = np.array(values, dtype=float) + return float(arr.mean()), float(arr.std()) + + +def aggregate_seed_summaries(seed_summaries: list[dict[str, Any]]) -> dict[str, Any]: + """Aggregate per-seed summaries into mean/std_across_seeds values.""" + if not seed_summaries: + raise ValueError("No seed summaries provided") + + metric_keys = [ + "mean_return", + "asset_survival_rate", + "containment_success_rate", + "mean_burned_area_fraction", + "mean_time_to_containment", + "mean_resource_efficiency", + "wasted_deployment_rate", + ] + if "mean_normalized_burn_ratio" in seed_summaries[0]: + metric_keys.append("mean_normalized_burn_ratio") + + aggregate: dict[str, Any] = { + "episodes_per_seed": int(seed_summaries[0]["episodes"]), + "num_seeds": len(seed_summaries), + "std_across_seeds": {}, + } + for key in metric_keys: + values = 
[float(summary[key]) for summary in seed_summaries if summary.get(key) is not None] + mean_value, std_value = _mean_and_std(values) + aggregate[key] = mean_value + aggregate["std_across_seeds"][key] = std_value + + return aggregate + + +def evaluate_agent_on_split( + *, + agent_name: str, + model, + records: list[dict], + expected_split: str, + scenario_families: list[tuple[str, str, str]], + seeds: list[int], + episodes_per_seed: int, + compute_normalized_burn_ratio: bool, +) -> dict[str, Any]: + """Evaluate one agent on one split over one or more seeds.""" + seed_summaries: list[dict[str, Any]] = [] + + for seed in seeds: + env = create_benchmark_env( + expected_split=expected_split, + scenario_parameter_records=records, + scenario_families=scenario_families, + ) + baseline_env = None + if compute_normalized_burn_ratio: + baseline_env = create_benchmark_env( + expected_split=expected_split, + scenario_parameter_records=records, + scenario_families=scenario_families, + ) + try: + episode_metrics = [] + for ep in range(episodes_per_seed): + eval_seed = seed * 1_000_000 + ep + metrics = rollout_episode(env, agent_name=agent_name, model=model, seed=eval_seed) + if baseline_env is not None: + baseline_metrics = rollout_episode( + baseline_env, + agent_name="non_intervention", + model=None, + seed=eval_seed, + ) + baseline_burned = int(baseline_metrics["final_burned_area_cells"]) + metrics["normalized_burn_ratio"] = float( + float(metrics["final_burned_area_cells"]) / max(1, baseline_burned) + ) + episode_metrics.append(metrics) + + seed_summary = summarize_episodes(episode_metrics) + seed_summary["seed"] = seed + seed_summaries.append(seed_summary) + finally: + env.close() + if baseline_env is not None: + baseline_env.close() + + return { + "seed_metrics": seed_summaries, + "aggregate": aggregate_seed_summaries(seed_summaries), + } + + +def load_model_for_algo(algo: str, model_path: Path): + """Load a Stable-Baselines3 model for the given algorithm name.""" + algo_name = algo.lower() + if not model_path.exists(): + raise FileNotFoundError(f"Model not found at {model_path}") + + if algo_name == "ppo": + from stable_baselines3 import PPO + + return PPO.load(str(model_path)) + if algo_name == "a2c": + from stable_baselines3 import A2C + + return A2C.load(str(model_path)) + if algo_name == "dqn": + from stable_baselines3 import DQN + + return DQN.load(str(model_path)) + + msg = f"Unsupported model algo '{algo_name}'" + raise ValueError(msg) + + +def heldout_performance_drop(train_metric: float, heldout_metric: float) -> float: + """Compute held-out drop for a metric (held-out - train).""" + return float(heldout_metric - train_metric) diff --git a/src/models/evaluate_agents.py b/src/models/evaluate_agents.py index 8e8e127..56ad2b9 100644 --- a/src/models/evaluate_agents.py +++ b/src/models/evaluate_agents.py @@ -1,265 +1,217 @@ -"""General benchmark evaluation interface for RL agents on split datasets.""" +"""Unified benchmark evaluation runner for learned and heuristic agents.""" from __future__ import annotations import argparse import json from pathlib import Path - -import numpy as np - -from src.models.fire_env import ( - ASSET_BURNED, - BURNED, - BURNING, - DEPLOY_CREW, - DEPLOY_HELICOPTER, - MOVE_E, - MOVE_N, - MOVE_S, - MOVE_W, - WildfireEnv, - create_benchmark_env, - load_scenario_parameter_records, +from typing import Any + +from src.models.benchmarking import ( + RUN_LABELS, + build_default_splits, + canonical_eval_preset, + evaluate_agent_on_split, + heldout_performance_drop, + 
load_model_for_algo, + load_records, ) -try: - from tqdm import tqdm -except Exception: # pragma: no cover - optional dependency - - def tqdm(iterable, **_kwargs): - return iterable - - -DEFAULT_TRAIN_DATASET = Path("data/static/scenario_parameter_records_seeded_train.json") -DEFAULT_VAL_DATASET = Path("data/static/scenario_parameter_records_seeded_val.json") -DEFAULT_HOLDOUT_DATASET = Path("data/static/scenario_parameter_records_seeded_holdout.json") -DEFAULT_PPO_MODEL = Path("src/models/tactical_ppo_agent.zip") - - -def _load_ppo_model(path: Path): - from stable_baselines3 import PPO - - if not path.exists(): - raise FileNotFoundError(f"PPO model not found at {path}") - return PPO.load(str(path)) - +SUPPORTED_AGENTS = ("ppo", "a2c", "dqn", "greedy", "random") +LEARNED_AGENTS = {"ppo", "a2c", "dqn"} -def _nearest_burning_cell(env: WildfireEnv) -> tuple[int, int] | None: - burning_positions = np.argwhere(env.grid == BURNING) - if burning_positions.size == 0: - return None - ar, ac = env.agent_pos - dists = np.abs(burning_positions[:, 0] - ar) + np.abs(burning_positions[:, 1] - ac) - idx = int(np.argmin(dists)) - return int(burning_positions[idx, 0]), int(burning_positions[idx, 1]) +def _parse_csv_ints(raw: str) -> list[int]: + values = [int(part.strip()) for part in raw.split(",") if part.strip()] + if not values: + raise ValueError("At least one seed must be provided") + return values -def _greedy_action(env: WildfireEnv) -> int: - ar, ac = env.agent_pos - if env.heli_left > 0 and env.heli_cd == 0: - for dr in range(-1, 2): - for dc in range(-1, 2): - rr, cc = ar + dr, ac + dc - if ( - 0 <= rr < env.grid_size - and 0 <= cc < env.grid_size - and env.grid[rr, cc] == BURNING - ): - return DEPLOY_HELICOPTER +def _parse_agents(raw: str) -> list[str]: + agents = [part.strip().lower() for part in raw.split(",") if part.strip()] + if not agents: + raise ValueError("At least one agent must be provided") + invalid = [agent for agent in agents if agent not in SUPPORTED_AGENTS] + if invalid: + msg = f"Unsupported agents {invalid}; expected any of {sorted(SUPPORTED_AGENTS)}" + raise ValueError(msg) + return agents - if env.crew_left > 0 and env.crew_cd == 0 and env.grid[ar, ac] == BURNING: - return DEPLOY_CREW - target = _nearest_burning_cell(env) - if target is None: - return MOVE_N - - tr, tc = target - if tr < ar: - return MOVE_N - if tr > ar: - return MOVE_S - if tc > ac: - return MOVE_E - if tc < ac: - return MOVE_W - return DEPLOY_CREW if env.crew_left > 0 and env.crew_cd == 0 else MOVE_N - - -def _run_episode(env: WildfireEnv, agent_name: str, model, seed: int) -> dict: - obs, _info = env.reset(seed=seed) - episode_return = 0.0 - terminated = False - truncated = False - info = {} - - for _ in range(env.max_steps): - if agent_name == "random": - action = int(env.action_space.sample()) - elif agent_name == "greedy": - action = _greedy_action(env) - else: - action, _ = model.predict(obs, deterministic=True) - action = int(action) - - obs, reward, terminated, truncated, info = env.step(action) - episode_return += float(reward) - if terminated or truncated: - break - - final_burned_area = int( - np.sum((env.grid == BURNED) | (env.grid == BURNING) | (env.grid == ASSET_BURNED)) - ) - containment_success = 1 if terminated and not truncated else 0 - heli_used = env.heli_budget_init - info.get("heli_left", env.heli_left) - crew_used = env.crew_budget_init - info.get("crew_left", env.crew_left) - - return { - "return": episode_return, - "assets_lost": int(info.get("assets_lost", env.assets_lost)), - 
"containment_success": containment_success, - "final_burned_area": final_burned_area, - "time_to_containment": int(info.get("step", env.step_count)), - "heli_used": int(heli_used), - "crew_used": int(crew_used), - "resource_efficiency": float(final_burned_area / max(1, heli_used + crew_used)), - } +def _resolve_model_paths(args, agents: list[str]) -> dict[str, Path]: + learned_agents = [agent for agent in agents if agent in LEARNED_AGENTS] + model_paths: dict[str, Path] = {} + for agent in learned_agents: + path = getattr(args, f"{agent}_model") + if path is not None: + model_paths[agent] = path -def _evaluate_agent_on_split( - *, - agent_name: str, - records: list[dict], - seeds: list[int], - episodes_per_seed: int, - model, - compute_normalized_burn_ratio: bool, - split_name: str, -) -> dict: - episode_metrics = [] + if args.model_path is not None and len(learned_agents) == 1: + model_paths[learned_agents[0]] = args.model_path - for seed in seeds: - env = create_benchmark_env( - scenario_parameter_records=records, - expected_split=split_name, + missing = [agent for agent in learned_agents if agent not in model_paths] + if missing: + msg = ( + "Missing model paths for learned agents " + f"{missing}. Provide --model-path (single learned agent only) or --ppo-model/--a2c-model/--dqn-model." ) - baseline_env = create_benchmark_env( - scenario_parameter_records=records, - expected_split=split_name, + raise ValueError(msg) + + return model_paths + + +def _compute_performance_drops(agent_result: dict[str, Any]) -> dict[str, float]: + drops: dict[str, float] = {} + train_summary = agent_result.get("train", {}).get("aggregate", {}) + train_asset_survival = train_summary.get("asset_survival_rate") + if train_asset_survival is None: + return drops + + for split in ("val", "family_holdout", "temporal_holdout_diagnostic"): + split_summary = agent_result.get(split, {}).get("aggregate", {}) + heldout_value = split_summary.get("asset_survival_rate") + if heldout_value is None: + continue + drops[f"{split}_asset_survival_drop"] = heldout_performance_drop( + float(train_asset_survival), + float(heldout_value), ) - iterator = tqdm(range(episodes_per_seed), desc=f"{agent_name} seed={seed}", unit="ep") - for ep in iterator: - eval_seed = seed * 10_000 + ep - metrics = _run_episode(env, agent_name, model, seed=eval_seed) - if compute_normalized_burn_ratio: - # Use MOVE_N-only as deterministic no-action surrogate baseline. 
- _obs, _ = baseline_env.reset(seed=eval_seed) - for _ in range(baseline_env.max_steps): - _obs, _reward, done, trunc, _base_info = baseline_env.step(MOVE_N) - if done or trunc: - break - baseline_burned = int( - np.sum( - (baseline_env.grid == BURNED) - | (baseline_env.grid == BURNING) - | (baseline_env.grid == ASSET_BURNED) - ) - ) - metrics["normalized_burn_ratio"] = float( - metrics["final_burned_area"] / max(1, baseline_burned) - ) - episode_metrics.append(metrics) - - arr = { - key: np.array([m[key] for m in episode_metrics], dtype=float) for key in episode_metrics[0] - } - summary = { - "episodes": len(episode_metrics), - "mean_return": float(arr["return"].mean()), - "std_return": float(arr["return"].std()), - "asset_survival_rate": float((arr["assets_lost"] == 0).mean()), - "containment_success_rate": float(arr["containment_success"].mean()), - "mean_final_burned_area": float(arr["final_burned_area"].mean()), - "mean_time_to_containment": float(arr["time_to_containment"].mean()), - "mean_resource_efficiency": float(arr["resource_efficiency"].mean()), - "variance_across_episodes": float(arr["return"].var()), - } - if "normalized_burn_ratio" in arr: - summary["mean_normalized_burn_ratio"] = float(arr["normalized_burn_ratio"].mean()) - return summary - - -def _load_split_records(path: Path | None, *, split_name: str) -> list[dict]: - if path is None or not path.exists(): - return [] - return load_scenario_parameter_records( - path, - benchmark_mode=True, - expected_split=split_name, - ) + return drops def main() -> None: + preset = canonical_eval_preset() + parser = argparse.ArgumentParser( - description="Evaluate benchmark agents on train/val/holdout splits" + description="Evaluate benchmark agents on canonical wildfire splits" + ) + parser.add_argument( + "--benchmark-preset", + type=str, + default="canonical", + choices=("canonical",), + help="Benchmark-safe eval preset", + ) + parser.add_argument("--agents", type=str, default="ppo,a2c,dqn,greedy,random") + parser.add_argument("--model-path", type=Path, default=None) + parser.add_argument("--ppo-model", type=Path, default=None) + parser.add_argument("--a2c-model", type=Path, default=None) + parser.add_argument("--dqn-model", type=Path, default=None) + parser.add_argument("--train-dataset", type=Path, default=preset["train_dataset"]) + parser.add_argument("--val-dataset", type=Path, default=preset["val_dataset"]) + parser.add_argument("--holdout-dataset", type=Path, default=preset["holdout_dataset"]) + parser.add_argument( + "--episodes", + type=int, + default=preset["episodes"], + help="Episodes per seed per split", + ) + parser.add_argument( + "--seeds", + type=str, + default=",".join(str(seed) for seed in preset["seeds"]), + help="Comma-separated seed list", + ) + parser.add_argument( + "--include-family-holdout", + action="store_true", + help="Evaluate validation records with HELD_OUT_FAMILIES", + ) + parser.add_argument( + "--include-temporal-holdout", + action="store_true", + help="Evaluate temporal holdout as diagnostic output", ) - parser.add_argument("--agents", type=str, default="ppo,greedy,random") - parser.add_argument("--train-dataset", type=Path, default=DEFAULT_TRAIN_DATASET) - parser.add_argument("--val-dataset", type=Path, default=DEFAULT_VAL_DATASET) - parser.add_argument("--holdout-dataset", type=Path, default=DEFAULT_HOLDOUT_DATASET) - parser.add_argument("--ppo-model", type=Path, default=DEFAULT_PPO_MODEL) - parser.add_argument("--episodes", type=int, default=20, help="Episodes per seed per split") - 
parser.add_argument("--seeds", type=str, default="42,43,44") parser.add_argument("--no-normalized-burn", action="store_true") + parser.add_argument( + "--run-label", + type=str, + default="final", + choices=RUN_LABELS, + help="Run label attached to output metadata", + ) parser.add_argument("--output", type=Path, default=None) args = parser.parse_args() - seeds = [int(s.strip()) for s in args.seeds.split(",") if s.strip()] - agents = [a.strip().lower() for a in args.agents.split(",") if a.strip()] + del args.benchmark_preset # canonical is currently the only supported preset + + seeds = _parse_csv_ints(args.seeds) + agents = _parse_agents(args.agents) + model_paths = _resolve_model_paths(args, agents) + + split_configs = build_default_splits( + train_dataset=args.train_dataset, + val_dataset=args.val_dataset, + holdout_dataset=args.holdout_dataset, + include_family_holdout=args.include_family_holdout, + include_temporal_holdout=args.include_temporal_holdout, + train_families=preset["train_families"], + val_families=preset["val_families"], + family_holdout_families=preset["family_holdout_families"], + temporal_holdout_families=preset["temporal_holdout_families"], + ) split_records = { - "train": _load_split_records(args.train_dataset, split_name="train"), - "val": _load_split_records(args.val_dataset, split_name="val"), - "holdout": _load_split_records(args.holdout_dataset, split_name="holdout"), + split.name: load_records(split.dataset_path, expected_split=split.expected_split) + for split in split_configs } - results: dict[str, dict] = {} - ppo_model = None - if "ppo" in agents: - ppo_model = _load_ppo_model(args.ppo_model) + results: dict[str, Any] = { + "config": { + "run_label": args.run_label, + "agents": agents, + "seeds": seeds, + "episodes_per_seed": args.episodes, + "compute_normalized_burn_ratio": not args.no_normalized_burn, + "splits": [split.name for split in split_configs], + "datasets": { + "train": str(args.train_dataset), + "val": str(args.val_dataset), + "holdout": str(args.holdout_dataset), + }, + "model_paths": {name: str(path) for name, path in model_paths.items()}, + }, + "results": {}, + } + + loaded_models = {agent: load_model_for_algo(agent, path) for agent, path in model_paths.items()} - for agent_name in agents: - results[agent_name] = {} - for split_name, records in split_records.items(): - if not records: - continue - model = ppo_model if agent_name == "ppo" else None - summary = _evaluate_agent_on_split( - agent_name=agent_name, - records=records, + for agent in agents: + agent_result: dict[str, Any] = {} + model = loaded_models.get(agent) + for split in split_configs: + split_eval = evaluate_agent_on_split( + agent_name=agent, + model=model, + records=split_records[split.name], + expected_split=split.expected_split, + scenario_families=split.scenario_families, seeds=seeds, episodes_per_seed=args.episodes, - model=model, compute_normalized_burn_ratio=not args.no_normalized_burn, - split_name=split_name, ) - results[agent_name][split_name] = summary + agent_result[split.name] = split_eval + agent_result["heldout_performance_drop"] = _compute_performance_drops(agent_result) + results["results"][agent] = agent_result print("\nBenchmark Summary") - print("=" * 72) - for agent_name, split_summaries in results.items(): - for split_name, summary in split_summaries.items(): + print("=" * 92) + for agent, split_payload in results["results"].items(): + for split in split_configs: + aggregate = split_payload[split.name]["aggregate"] + stds = aggregate["std_across_seeds"] 
print( - f"{agent_name:>8} | {split_name:<7} | episodes={summary['episodes']:>4} " - f"| return={summary['mean_return']:.1f} " - f"| assets_survival={summary['asset_survival_rate']:.3f} " - f"| containment={summary['containment_success_rate']:.3f} " - f"| burned={summary['mean_final_burned_area']:.1f}" + f"{agent:>8} | {split.name:<26} " + f"| return={aggregate['mean_return']:.2f}±{stds['mean_return']:.2f} " + f"| asset_survival={aggregate['asset_survival_rate']:.3f}±{stds['asset_survival_rate']:.3f} " + f"| containment={aggregate['containment_success_rate']:.3f}±{stds['containment_success_rate']:.3f} " + f"| burned_frac={aggregate['mean_burned_area_fraction']:.3f}±{stds['mean_burned_area_fraction']:.3f}" ) if args.output: + args.output.parent.mkdir(parents=True, exist_ok=True) args.output.write_text(json.dumps(results, indent=2)) print(f"\nSaved results to {args.output}") diff --git a/src/models/rl_agent.py b/src/models/rl_agent.py deleted file mode 100644 index b1c8561..0000000 --- a/src/models/rl_agent.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -rl_agent.py — PPO tactical agent inference interface. - -Loads the trained PPO model and returns tactical deployment waypoints -for a given fire, converting grid coordinates back to real lat/lon. - -Usage: - from src.models.rl_agent import get_tactical_recommendations - waypoints = get_tactical_recommendations("BC-2026-001", fire_data, spread_output) -""" - -from __future__ import annotations - -import logging -import math -from pathlib import Path - -logger = logging.getLogger(__name__) - -MODEL_PATH = Path(__file__).parent / "tactical_ppo_agent.zip" -GRID_SIZE = 25 - - -def _grid_to_latlon( - grid_pos: list[int], - fire_center_lat: float, - fire_center_lon: float, - spread_radius_m: float, -) -> tuple[float, float]: - """Convert a grid cell position to real-world lat/lon.""" - metres_per_deg_lat = 111_320 - metres_per_deg_lon = 111_320 * math.cos(math.radians(fire_center_lat)) - - cell_size_m = (spread_radius_m * 2) / GRID_SIZE - center_cell = GRID_SIZE // 2 - - delta_row = grid_pos[0] - center_cell - delta_col = grid_pos[1] - center_cell - - lat = fire_center_lat - (delta_row * cell_size_m / metres_per_deg_lat) - lon = fire_center_lon + (delta_col * cell_size_m / metres_per_deg_lon) - - return round(lat, 5), round(lon, 5) - - -def _greedy_fallback( - fire_lat: float, - fire_lon: float, - spread_1h_m: float, - spread_3h_m: float, -) -> list[dict]: - """ - Greedy heuristic recommendations if the PPO model isn't trained yet. - Deploys along the fire perimeter at cardinal/intercardinal points. 
- """ - directions = [ - ("N", -1, 0, "ground_crew", "Establish northern anchor line"), - ("NE", -0.7, 0.7, "helicopter", "Pre-position tanker for NE flank"), - ("E", 0, 1, "ground_crew", "Cut containment line on eastern flank"), - ("S", 1, 0, "ground_crew", "Southern backfire opportunity"), - ("W", 0, -1, "helicopter", "Air attack on advancing western head"), - ] - - metres_per_deg_lat = 111_320 - metres_per_deg_lon = 111_320 * math.cos(math.radians(fire_lat)) - offset_m = spread_1h_m * 0.8 - - waypoints = [] - for name, dlat_factor, dlon_factor, asset, rationale in directions: - lat = fire_lat + (dlat_factor * offset_m / metres_per_deg_lat) - lon = fire_lon + (dlon_factor * offset_m / metres_per_deg_lon) - waypoints.append( - { - "direction": name, - "latitude": round(lat, 5), - "longitude": round(lon, 5), - "asset_type": asset, - "rationale": rationale, - "score": round(0.9 - len(waypoints) * 0.1, 2), - "source": "greedy_heuristic", - } - ) - - return waypoints - - -def get_tactical_recommendations( - fire_id: str, - fire_data: dict | None = None, - spread_output: dict | None = None, - n_inference_steps: int = 60, -) -> list[dict]: - """ - Generate tactical deployment waypoints for a fire. - - If the PPO model is trained and saved, uses it for inference. - Falls back to the greedy heuristic if model isn't available. - """ - fire_lat = float(fire_data.get("latitude", 49.9071)) if fire_data else 49.9071 - fire_lon = float(fire_data.get("longitude", -119.496)) if fire_data else -119.496 - spread_1h = float(spread_output.get("spread_1h_m", 1200)) if spread_output else 1200 - spread_3h = float(spread_output.get("spread_3h_m", 3600)) if spread_output else 3600 - - if not MODEL_PATH.exists(): - logger.info(f"PPO model not found at {MODEL_PATH} — using greedy heuristic") - return _greedy_fallback(fire_lat, fire_lon, spread_1h, spread_3h) - - try: - from stable_baselines3 import PPO - - from src.models.fire_env import WildfireEnv - - spread_rate_m_per_min = spread_1h / 60.0 - env = WildfireEnv( - base_spread_rate_m_per_min=spread_rate_m_per_min, - benchmark_mode=False, - ) - model = PPO.load(str(MODEL_PATH), env=env) - - obs, _ = env.reset() - waypoints = [] - deployment_actions = {4, 5} - - for _ in range(n_inference_steps): - action, _ = model.predict(obs, deterministic=True) - obs, reward, done, truncated, _info = env.step(int(action)) - - if int(action) in deployment_actions: - lat, lon = _grid_to_latlon(env.agent_pos, fire_lat, fire_lon, spread_1h) - asset = "helicopter" if int(action) == 4 else "ground_crew" - waypoints.append( - { - "latitude": lat, - "longitude": lon, - "asset_type": asset, - "rationale": f"PPO recommended {asset} deployment (step {env.step_count})", - "score": round(float(reward), 2), - "source": "ppo_agent", - } - ) - - if done or truncated: - break - - # Deduplicate very close waypoints (within 200m) - unique = [] - for wp in waypoints: - too_close = any( - abs(wp["latitude"] - u["latitude"]) < 0.002 - and abs(wp["longitude"] - u["longitude"]) < 0.002 - for u in unique - ) - if not too_close: - unique.append(wp) - - if not unique: - logger.warning("PPO produced no deployments — falling back to heuristic") - return _greedy_fallback(fire_lat, fire_lon, spread_1h, spread_3h) - - return unique[:8] - - except Exception as e: - logger.error(f"PPO inference failed: {e} — using greedy fallback") - return _greedy_fallback(fire_lat, fire_lon, spread_1h, spread_3h) diff --git a/src/models/train_rl_agent.py b/src/models/train_rl_agent.py index 3c93c33..83d7d3d 100644 --- 
a/src/models/train_rl_agent.py +++ b/src/models/train_rl_agent.py @@ -1,232 +1,426 @@ -""" -train_rl_agent.py — PPO tactical agent training script. +"""Unified wildfire benchmark trainer for PPO/A2C/DQN.""" -Trains a PPO agent on the WildfireEnv gymnasium environment (25×25 grid -with critical assets and finite suppression budgets). - -Run: - uv run python -m src.models.train_rl_agent - uv run python -m src.models.train_rl_agent --timesteps 10000 # quick test -""" +from __future__ import annotations import argparse -import logging +import json import sys from pathlib import Path +from typing import Any -logging.basicConfig(level=logging.WARNING) -logger = logging.getLogger(__name__) - -MODEL_SAVE_PATH = Path(__file__).parent / "tactical_ppo_agent" -DEFAULT_SCENARIO_DATASETS = (Path("data/static/scenario_parameter_records_seeded_train.json"),) - - -def _resolve_dataset_path(path: str | None) -> str | None: - if path: - return path - for candidate in DEFAULT_SCENARIO_DATASETS: - if candidate.exists(): - return str(candidate) - return None - - -def _existing_path(path: str | None) -> str | None: - if path and Path(path).exists(): - return path - return None - - -def _evaluate_model( - model, - dataset_path: str, - seed: int, - episodes: int = 5, - expected_split: str | None = None, -) -> tuple[float, float]: - from src.models.fire_env import create_benchmark_env - - split = expected_split or "train" - eval_env = create_benchmark_env(dataset_path=dataset_path, expected_split=split) - returns = [] - assets_lost_total = [] - for ep in range(episodes): - obs, _ = eval_env.reset(seed=seed + ep + 100) - ep_return = 0.0 - for _ in range(150): - action, _ = model.predict(obs, deterministic=True) - obs, reward, done, truncated, info = eval_env.step(int(action)) - ep_return += reward - if done or truncated: - break - returns.append(ep_return) - assets_lost_total.append(info["assets_lost"]) - return sum(returns) / len(returns), sum(assets_lost_total) / len(assets_lost_total) - - -def train( - total_timesteps: int = 200_000, - spread_rate_m_per_min: float = 15.0, - n_envs: int = 4, - seed: int = 42, - scenario_dataset_path: str | None = None, - val_dataset_path: str | None = None, - holdout_dataset_path: str | None = None, - allow_legacy_dev_fallback: bool = False, -) -> None: - """ - Train the PPO tactical agent. - - Args: - total_timesteps: Total env steps to train for. - spread_rate_m_per_min: Legacy fixed spread rate used only in dev fallback mode. - n_envs: Parallel environments. - seed: Random seed for reproducibility. 
- """ - try: +from src.models.benchmarking import ( + RUN_LABELS, + build_default_splits, + canonical_train_preset, + evaluate_agent_on_split, + load_model_for_algo, + load_records, +) +from src.models.fire_env import benchmark_env_kwargs + +ALGO_CHOICES = ("ppo", "a2c", "dqn") + + +def _default_hyperparameters(algo: str) -> dict[str, Any]: + if algo == "ppo": + return { + "learning_rate": 3e-4, + "n_steps": 512, + "batch_size": 64, + "n_epochs": 10, + "gamma": 0.995, + "gae_lambda": 0.95, + "clip_range": 0.2, + "ent_coef": 0.01, + } + if algo == "a2c": + return { + "learning_rate": 7e-4, + "n_steps": 5, + "gamma": 0.995, + "gae_lambda": 1.0, + "ent_coef": 0.01, + } + if algo == "dqn": + return { + "learning_rate": 1e-4, + "buffer_size": 100_000, + "learning_starts": 1_000, + "batch_size": 64, + "gamma": 0.995, + "train_freq": 4, + "gradient_steps": 1, + "target_update_interval": 1_000, + "exploration_fraction": 0.2, + "exploration_final_eps": 0.05, + } + raise ValueError(f"Unsupported algorithm '{algo}'") + + +def _build_model(*, algo: str, env, seed: int, device: str, hyperparams: dict[str, Any]): + if algo == "ppo": from stable_baselines3 import PPO - from stable_baselines3.common.env_util import make_vec_env - from src.models.fire_env import WildfireEnv, benchmark_env_kwargs - except ImportError as e: - print(f"Missing dependency: {e}") - print(" Run: uv sync") - sys.exit(1) + return PPO("MlpPolicy", env, seed=seed, device=device, verbose=1, **hyperparams) + if algo == "a2c": + from stable_baselines3 import A2C - print("=" * 60) - print(" FireGrid PPO Tactical Agent — Training") - print("=" * 60) - print(f" Timesteps: {total_timesteps:,}") - print(f" Environments: {n_envs} parallel") - print(" Grid: 25×25 with critical assets") - print(" Budgets: heli=8, crew=20") - print() - - scenario_dataset_path = _resolve_dataset_path(scenario_dataset_path) - - env_kwargs: dict = {} - if scenario_dataset_path: - env_kwargs = benchmark_env_kwargs( - dataset_path=scenario_dataset_path, - expected_split="train", - ) - records = env_kwargs["scenario_parameter_records"] - print(" Runtime data: frozen offline scenario records (no live ingestion)") - print(f" Scenario records: {len(records)} from {scenario_dataset_path}") - else: - if not allow_legacy_dev_fallback: - msg = ( - "No training scenario dataset found. Canonical training requires precomputed " - "scenario_parameter_records_seeded_train.json (or explicit equivalent). " - "To run non-canonical dev mode, pass --allow-legacy-dev-fallback with " - "--spread-rate." - ) - raise ValueError(msg) - print( - " No scenario dataset found; running explicit legacy dev mode " - "with --spread-rate fallback." 
- ) - print(f" Legacy spread rate: {spread_rate_m_per_min} m/min") - env_kwargs["benchmark_mode"] = False - env_kwargs["base_spread_rate_m_per_min"] = spread_rate_m_per_min - vec_env = make_vec_env( - WildfireEnv, - n_envs=n_envs, - seed=seed, - env_kwargs=env_kwargs, - ) + return A2C("MlpPolicy", env, seed=seed, device=device, verbose=1, **hyperparams) + if algo == "dqn": + from stable_baselines3 import DQN - model = PPO( - "MlpPolicy", - vec_env, - verbose=1, - learning_rate=3e-4, - n_steps=512, - batch_size=64, - n_epochs=10, - gamma=0.995, - gae_lambda=0.95, - clip_range=0.2, - ent_coef=0.01, - seed=seed, - device="cpu", - ) + return DQN("MlpPolicy", env, seed=seed, device=device, verbose=1, **hyperparams) + raise ValueError(f"Unsupported algorithm '{algo}'") + + +def _create_train_env(*, algo: str, env_kwargs: dict[str, Any], n_envs: int, seed: int): + from src.models.fire_env import WildfireEnv + + if algo in {"ppo", "a2c"}: + from stable_baselines3.common.env_util import make_vec_env - print("Training PPO agent...\n") - model.learn(total_timesteps=total_timesteps) - - model.save(str(MODEL_SAVE_PATH)) - print(f"\nPPO model saved -> {MODEL_SAVE_PATH}.zip") - - # Quick evaluation - print("\nRunning quick evaluation (5 episodes)...") - eval_targets = [("train", scenario_dataset_path)] - if _existing_path(val_dataset_path): - eval_targets.append(("val", val_dataset_path)) - if _existing_path(holdout_dataset_path): - eval_targets.append(("holdout", holdout_dataset_path)) - - for split_name, dataset_path in eval_targets: - if not dataset_path: - continue - mean_return, mean_assets_lost = _evaluate_model( - model, - dataset_path, + return make_vec_env( + WildfireEnv, + n_envs=n_envs, seed=seed, - episodes=5, - expected_split=split_name, + env_kwargs=env_kwargs, ) - print(f" [{split_name}] Mean return: {mean_return:.1f}") - print(f" [{split_name}] Mean assets lost: {mean_assets_lost:.1f}") - print(f"\nTraining complete. 
Model ready at {MODEL_SAVE_PATH}.zip") + return WildfireEnv(**env_kwargs) -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Train PPO wildfire tactical agent") - parser.add_argument( - "--timesteps", type=int, default=200_000, help="Total training timesteps (default: 200000)" - ) - parser.add_argument( - "--spread-rate", - type=float, - default=15.0, - help="Legacy dev-mode fixed spread rate in m/min (default: 15.0)", - ) + +def _selects_better_checkpoint(candidate: dict[str, Any], incumbent: dict[str, Any] | None) -> bool: + if incumbent is None: + return True + cand_val = candidate["splits"]["val"] + inc_val = incumbent["splits"]["val"] + + cand_primary = float(cand_val["asset_survival_rate"]) + inc_primary = float(inc_val["asset_survival_rate"]) + if cand_primary != inc_primary: + return cand_primary > inc_primary + + cand_tie = float(cand_val["mean_return"]) + inc_tie = float(inc_val["mean_return"]) + return cand_tie > inc_tie + + +def _single_seed_split_summary(split_eval: dict[str, Any]) -> dict[str, Any]: + seed_summary = dict(split_eval["seed_metrics"][0]) + seed_summary.pop("seed", None) + return seed_summary + + +def _write_json(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2)) + + +def _resolve_hyperparameters(args, algo: str) -> dict[str, Any]: + hyperparams = _default_hyperparameters(algo) + + if args.learning_rate is not None: + hyperparams["learning_rate"] = args.learning_rate + + if algo in {"ppo", "a2c"} and args.n_steps is not None: + hyperparams["n_steps"] = args.n_steps + if algo in {"ppo", "a2c"} and args.ent_coef is not None: + hyperparams["ent_coef"] = args.ent_coef + + if algo == "dqn": + if args.exploration_fraction is not None: + hyperparams["exploration_fraction"] = args.exploration_fraction + if args.exploration_final_eps is not None: + hyperparams["exploration_final_eps"] = args.exploration_final_eps + if args.target_update_interval is not None: + hyperparams["target_update_interval"] = args.target_update_interval + if args.replay_buffer_size is not None: + hyperparams["buffer_size"] = args.replay_buffer_size + + return hyperparams + + +def _families_to_jsonable(families: list[tuple[str, str, str]]) -> list[list[str]]: + return [list(family) for family in families] + + +def main() -> None: + parser = argparse.ArgumentParser(description="Train benchmark RL agents for wildfire task") + parser.add_argument("--algo", type=str, default="ppo", choices=ALGO_CHOICES) parser.add_argument( - "--allow-legacy-dev-fallback", - action="store_true", - help="Allow non-canonical fallback when no scenario dataset is available", + "--benchmark-preset", + type=str, + default="canonical", + choices=("canonical",), + help="Benchmark-safe training preset", ) - parser.add_argument( - "--envs", type=int, default=4, help="Number of parallel environments (default: 4)" + parser.add_argument("--run-label", type=str, default="smoke", choices=RUN_LABELS) + parser.add_argument("--seed", type=int, default=11) + parser.add_argument("--timesteps", type=int, default=None) + parser.add_argument("--envs", type=int, default=4) + parser.add_argument("--device", type=str, default="cpu") + parser.add_argument("--artifact-root", type=Path, default=Path("artifacts/benchmark")) + + parser.add_argument("--train-dataset", type=Path, default=None) + parser.add_argument("--val-dataset", type=Path, default=None) + parser.add_argument("--holdout-dataset", type=Path, default=None) + + 
+    parser.add_argument("--checkpoint-interval", type=int, default=None)
+    parser.add_argument("--checkpoint-eval-episodes", type=int, default=None)
+    parser.add_argument("--final-eval-episodes", type=int, default=None)
+    parser.add_argument("--include-family-holdout-checkpoints", action="store_true")
+    parser.add_argument("--include-family-holdout-final", action="store_true")
+    parser.add_argument("--include-temporal-holdout-final", action="store_true")
+    parser.add_argument("--no-normalized-burn-final", action="store_true")
+
+    parser.add_argument("--learning-rate", type=float, default=None)
+    parser.add_argument("--n-steps", type=int, default=None)
+    parser.add_argument("--ent-coef", type=float, default=None)
+    parser.add_argument("--exploration-fraction", type=float, default=None)
+    parser.add_argument("--exploration-final-eps", type=float, default=None)
+    parser.add_argument("--target-update-interval", type=int, default=None)
+    parser.add_argument("--replay-buffer-size", type=int, default=None)
+
+    args = parser.parse_args()
+
+    try:
+        preset = canonical_train_preset(args.algo)
+    except Exception as exc:
+        print(f"Failed to load benchmark preset: {exc}")
+        sys.exit(1)
+
+    del args.benchmark_preset  # canonical is currently the only supported preset
+
+    total_timesteps = args.timesteps or int(preset["total_timesteps"])
+    checkpoint_interval = args.checkpoint_interval or int(preset["checkpoint_interval_steps"])
+    checkpoint_eval_episodes = args.checkpoint_eval_episodes or int(
+        preset["checkpoint_eval_episodes"]
     )
-    parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
-    parser.add_argument(
-        "--scenario-dataset",
-        type=str,
-        default=None,
-        help="Path to cached training scenario parameter JSON dataset",
+    final_eval_episodes = args.final_eval_episodes or int(preset["final_eval_episodes"])
+
+    train_dataset = args.train_dataset or Path(preset["train_dataset"])
+    val_dataset = args.val_dataset or Path(preset["val_dataset"])
+    holdout_dataset = args.holdout_dataset or Path(preset["holdout_dataset"])
+
+    train_families = list(preset["train_families"])
+    val_families = list(preset["val_families"])
+    family_holdout_families = list(preset["family_holdout_families"])
+
+    train_records = load_records(train_dataset, expected_split="train")
+    val_records = load_records(val_dataset, expected_split="val")
+    holdout_records = load_records(holdout_dataset, expected_split="holdout")
+
+    run_dir = args.artifact_root / args.run_label / args.algo / f"seed_{args.seed}"
+    run_dir.mkdir(parents=True, exist_ok=True)
+
+    hyperparams = _resolve_hyperparameters(args, args.algo)
+
+    train_env_kwargs = benchmark_env_kwargs(
+        expected_split="train",
+        scenario_parameter_records=train_records,
     )
-    parser.add_argument(
-        "--val-dataset",
-        type=str,
-        default="data/static/scenario_parameter_records_seeded_val.json",
-        help="Path to cached validation scenario parameter JSON dataset",
+    train_env_kwargs["scenario_families"] = train_families
+
+    print("=" * 72)
+    print(f"Wildfire benchmark training | algo={args.algo} | seed={args.seed}")
+    print("=" * 72)
+    print(f"run_label={args.run_label}")
+    print(f"timesteps={total_timesteps:,}")
+    print(f"checkpoint_interval={checkpoint_interval:,}")
+    print(f"checkpoint_eval_episodes={checkpoint_eval_episodes}")
+    print(f"final_eval_episodes={final_eval_episodes}")
+    print(f"artifacts={run_dir}")
+
+    try:
+        train_env = _create_train_env(
+            algo=args.algo,
+            env_kwargs=train_env_kwargs,
+            n_envs=args.envs,
+            seed=args.seed,
+        )
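+        # Build the learner for the requested algorithm on the freshly seeded
+        # env, using the CLI-resolved hyperparameters.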
+        model = _build_model(
+            algo=args.algo,
+            env=train_env,
+            seed=args.seed,
+            device=args.device,
+            hyperparams=hyperparams,
+        )
+    except ImportError as exc:
+        print(f"Missing dependency: {exc}")
+        print("Run: uv sync")
+        sys.exit(1)
+
+    checkpoint_splits = build_default_splits(
+        train_dataset=train_dataset,
+        val_dataset=val_dataset,
+        holdout_dataset=holdout_dataset,
+        include_family_holdout=args.include_family_holdout_checkpoints,
+        include_temporal_holdout=False,
+        train_families=train_families,
+        val_families=val_families,
+        family_holdout_families=family_holdout_families,
+        temporal_holdout_families=train_families,
     )
-    parser.add_argument(
-        "--holdout-dataset",
-        type=str,
-        default="data/static/scenario_parameter_records_seeded_holdout.json",
-        help="Path to cached holdout scenario parameter JSON dataset",
+
+    # The family-holdout split evaluates val-dataset records restricted to the
+    # held-out families; the temporal holdout stays a final-only diagnostic.
+    records_by_split = {
+        "train": train_records,
+        "val": val_records,
+        "family_holdout": val_records,
+        "temporal_holdout_diagnostic": holdout_records,
+    }
+
+    config_payload = {
+        "algo": args.algo,
+        "run_label": args.run_label,
+        "seed": args.seed,
+        "timesteps": total_timesteps,
+        "n_envs": args.envs,
+        "device": args.device,
+        "datasets": {
+            "train": str(train_dataset),
+            "val": str(val_dataset),
+            "holdout": str(holdout_dataset),
+        },
+        "record_counts": {
+            "train": len(train_records),
+            "val": len(val_records),
+            "holdout": len(holdout_records),
+        },
+        "families": {
+            "train": _families_to_jsonable(train_families),
+            "val": _families_to_jsonable(val_families),
+            "family_holdout": _families_to_jsonable(family_holdout_families),
+        },
+        "checkpoint": {
+            "interval_steps": checkpoint_interval,
+            "episodes": checkpoint_eval_episodes,
+            "visible_splits": [split.name for split in checkpoint_splits],
+            "selection_metric": "val.asset_survival_rate",
+            "tie_breaker": "val.mean_return",
+            "temporal_holdout_visible": False,
+        },
+        "final_evaluation": {
+            "episodes": final_eval_episodes,
+            "include_family_holdout": args.include_family_holdout_final,
+            "include_temporal_holdout_diagnostic": args.include_temporal_holdout_final,
+            "compute_normalized_burn_ratio": not args.no_normalized_burn_final,
+        },
+        "hyperparameters": hyperparams,
+    }
+    _write_json(run_dir / "config.json", config_payload)
+
+    checkpoint_entries: list[dict[str, Any]] = []
+    best_entry: dict[str, Any] | None = None
+    best_index: int | None = None
+
+    # With reset_num_timesteps=False each learn() call adds learn_chunk steps
+    # on top of model.num_timesteps, so the loop stops at total_timesteps.
+    while int(model.num_timesteps) < total_timesteps:
+        remaining = total_timesteps - int(model.num_timesteps)
+        learn_chunk = min(checkpoint_interval, remaining)
+        model.learn(total_timesteps=learn_chunk, reset_num_timesteps=False, progress_bar=False)
+        current_steps = int(model.num_timesteps)
+
+        split_metrics = {}
+        for split in checkpoint_splits:
+            split_eval = evaluate_agent_on_split(
+                agent_name=args.algo,
+                model=model,
+                records=records_by_split[split.name],
+                expected_split=split.expected_split,
+                scenario_families=split.scenario_families,
+                seeds=[args.seed],
+                episodes_per_seed=checkpoint_eval_episodes,
+                compute_normalized_burn_ratio=False,
+            )
+            split_metrics[split.name] = _single_seed_split_summary(split_eval)
+
+        entry = {
+            "algo": args.algo,
+            "seed": args.seed,
+            "train_steps": current_steps,
+            "selected_for_best": False,
+            "splits": split_metrics,
+        }
+        checkpoint_entries.append(entry)
+
+        if _selects_better_checkpoint(entry, best_entry):
+            best_entry = entry
+            best_index = len(checkpoint_entries) - 1
+            model.save(str(run_dir / "best_model"))
+
+        print(
+            f"checkpoint step={current_steps:,} "
+            f"val.asset_survival={entry['splits']['val']['asset_survival_rate']:.3f} "
+            f"val.return={entry['splits']['val']['mean_return']:.2f}"
+        )
+
+    model.save(str(run_dir / "last_model"))
+
+    if best_index is None:
+        msg = "No checkpoint evaluations were produced; cannot select best checkpoint"
+        raise RuntimeError(msg)
+
+    checkpoint_entries[best_index]["selected_for_best"] = True
+    best_entry = checkpoint_entries[best_index]
+
+    _write_json(run_dir / "checkpoint_metrics.json", checkpoint_entries)
+    _write_json(
+        run_dir / "best_checkpoint.json",
+        {
+            "algo": args.algo,
+            "seed": args.seed,
+            "selected_train_steps": best_entry["train_steps"],
+            "selection_metric": "val.asset_survival_rate",
+            "tie_breaker": "val.mean_return",
+            "val_metrics": best_entry["splits"]["val"],
+            "best_checkpoint_entry": best_entry,
+        },
     )
-    args = parser.parse_args()
-    train(
-        total_timesteps=args.timesteps,
-        spread_rate_m_per_min=args.spread_rate,
-        n_envs=args.envs,
-        seed=args.seed,
-        scenario_dataset_path=args.scenario_dataset,
-        val_dataset_path=args.val_dataset,
-        holdout_dataset_path=args.holdout_dataset,
-        allow_legacy_dev_fallback=args.allow_legacy_dev_fallback,
+    best_model = load_model_for_algo(args.algo, run_dir / "best_model.zip")
+    final_splits = build_default_splits(
+        train_dataset=train_dataset,
+        val_dataset=val_dataset,
+        holdout_dataset=holdout_dataset,
+        include_family_holdout=args.include_family_holdout_final,
+        include_temporal_holdout=args.include_temporal_holdout_final,
+        train_families=train_families,
+        val_families=val_families,
+        family_holdout_families=family_holdout_families,
+        temporal_holdout_families=train_families,
     )
+
+    final_split_metrics = {}
+    for split in final_splits:
+        final_eval = evaluate_agent_on_split(
+            agent_name=args.algo,
+            model=best_model,
+            records=records_by_split[split.name],
+            expected_split=split.expected_split,
+            scenario_families=split.scenario_families,
+            seeds=[args.seed],
+            episodes_per_seed=final_eval_episodes,
+            compute_normalized_burn_ratio=not args.no_normalized_burn_final,
+        )
+        final_split_metrics[split.name] = _single_seed_split_summary(final_eval)
+
+    final_eval_payload: dict[str, Any] = {
+        "algo": args.algo,
+        "seed": args.seed,
+        "model_artifact": str(run_dir / "best_model.zip"),
+        "episodes_per_split": final_eval_episodes,
+        "splits": final_split_metrics,
+    }
+    if args.include_temporal_holdout_final and len(holdout_records) <= 1:
+        final_eval_payload["temporal_holdout_note"] = (
+            "Temporal holdout contains one record and is reported as diagnostic-only evidence."
+        )
+
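+    # Persist the split-wise evaluation of the selected best checkpoint; this
+    # is the benchmark-facing final artifact for this seed.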
+    _write_json(run_dir / "final_eval_best.json", final_eval_payload)
+
+    print("\nTraining complete")
+    print(f"best checkpoint step={best_entry['train_steps']:,}")
+    print(f"artifacts written to {run_dir}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/models/test_benchmarking_metrics.py b/tests/models/test_benchmarking_metrics.py
new file mode 100644
index 0000000..8f609d9
--- /dev/null
+++ b/tests/models/test_benchmarking_metrics.py
@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+from src.models.benchmarking import (
+    CANONICAL_CHECKPOINT_EVAL_EPISODES,
+    CANONICAL_CHECKPOINT_INTERVAL_STEPS,
+    CANONICAL_FINAL_EVAL_EPISODES,
+    CANONICAL_TIMESTEPS_BY_ALGO,
+    aggregate_seed_summaries,
+    canonical_train_preset,
+    evaluate_agent_on_split,
+    summarize_episodes,
+)
+
+
+def _record(**overrides):
+    record = {
+        "record_id": "AB-2020-001__20200101",
+        "split": "train",
+        "base_spread_prob": 0.14,
+        "severity_bucket": "medium",
+        "wind_direction": "E",
+        "wind_strength": 0.35,
+        "ignition_seed": 101,
+        "layout_seed": 202,
+    }
+    record.update(overrides)
+    return record
+
+
+def test_canonical_train_preset_matches_frozen_protocol():
+    preset = canonical_train_preset("ppo")
+
+    assert preset["total_timesteps"] == CANONICAL_TIMESTEPS_BY_ALGO["ppo"]
+    assert preset["checkpoint_interval_steps"] == CANONICAL_CHECKPOINT_INTERVAL_STEPS
+    assert preset["checkpoint_eval_episodes"] == CANONICAL_CHECKPOINT_EVAL_EPISODES
+    assert preset["final_eval_episodes"] == CANONICAL_FINAL_EVAL_EPISODES
+    assert Path(preset["train_dataset"]).name == "scenario_parameter_records_seeded_train.json"
+    assert preset["holdout_checkpoint_visibility"] is False
+
+
+def test_summarize_episodes_handles_absent_containment_times():
+    summary = summarize_episodes(
+        [
+            {
+                "return": 10.0,
+                "assets_lost": 1,
+                "containment_success": 0.0,
+                "burned_area_fraction": 0.20,
+                "time_to_containment": None,
+                "resource_efficiency": 0.0,
+                "wasted_deployment_rate": 1.0,
+                "final_burned_area_cells": 125,
+            },
+            {
+                "return": -5.0,
+                "assets_lost": 2,
+                "containment_success": 0.0,
+                "burned_area_fraction": 0.25,
+                "time_to_containment": None,
+                "resource_efficiency": 0.5,
+                "wasted_deployment_rate": 0.25,
+                "final_burned_area_cells": 150,
+            },
+        ]
+    )
+
+    assert summary["episodes"] == 2
+    assert summary["mean_return"] == 2.5
+    assert summary["containment_success_rate"] == 0.0
+    assert summary["mean_time_to_containment"] is None
+    assert summary["mean_resource_efficiency"] == 0.25
+
+
+def test_evaluate_agent_on_split_reports_seed_and_aggregate_schema():
+    result = evaluate_agent_on_split(
+        agent_name="random",
+        model=None,
+        records=[_record()],
+        expected_split="train",
+        scenario_families=[("center", "medium", "A")],
+        seeds=[11, 22],
+        episodes_per_seed=2,
+        compute_normalized_burn_ratio=True,
+    )
+
+    assert len(result["seed_metrics"]) == 2
+    assert "aggregate" in result
+    aggregate = result["aggregate"]
+    assert aggregate["episodes_per_seed"] == 2
+    assert aggregate["num_seeds"] == 2
+    assert "mean_return" in aggregate
+    assert "asset_survival_rate" in aggregate
+    assert "mean_burned_area_fraction" in aggregate
+    assert "mean_normalized_burn_ratio" in aggregate
+    assert "std_across_seeds" in aggregate
+    assert "mean_return" in aggregate["std_across_seeds"]
+
+
+def test_aggregate_seed_summaries_handles_optional_none_metrics():
+    aggregate = aggregate_seed_summaries(
+        [
+            {
+                "seed": 1,
+                "episodes": 3,
+                "mean_return": 10.0,
+                "asset_survival_rate": 0.5,
"containment_success_rate": 0.25, + "mean_burned_area_fraction": 0.3, + "mean_time_to_containment": None, + "mean_resource_efficiency": 0.2, + "wasted_deployment_rate": 0.7, + }, + { + "seed": 2, + "episodes": 3, + "mean_return": 20.0, + "asset_survival_rate": 0.75, + "containment_success_rate": 0.5, + "mean_burned_area_fraction": 0.1, + "mean_time_to_containment": None, + "mean_resource_efficiency": 0.4, + "wasted_deployment_rate": 0.6, + }, + ] + ) + + assert aggregate["mean_return"] == 15.0 + assert aggregate["std_across_seeds"]["mean_return"] == 5.0 + assert aggregate["mean_time_to_containment"] is None + assert aggregate["std_across_seeds"]["mean_time_to_containment"] is None