diff --git a/.ci-smoke/results.json b/.ci-smoke/results.json new file mode 100644 index 0000000..07760d9 --- /dev/null +++ b/.ci-smoke/results.json @@ -0,0 +1,78 @@ +{ + "greedy": { + "train": { + "episodes": 1, + "mean_return": -139.8, + "std_return": 0.0, + "asset_survival_rate": 0.0, + "containment_success_rate": 0.0, + "mean_final_burned_area": 557.0, + "mean_time_to_containment": 150.0, + "mean_resource_efficiency": 19.892857142857142, + "variance_across_episodes": 0.0, + "mean_normalized_burn_ratio": 0.911620294599018 + }, + "val": { + "episodes": 1, + "mean_return": -139.8, + "std_return": 0.0, + "asset_survival_rate": 0.0, + "containment_success_rate": 0.0, + "mean_final_burned_area": 557.0, + "mean_time_to_containment": 150.0, + "mean_resource_efficiency": 19.892857142857142, + "variance_across_episodes": 0.0, + "mean_normalized_burn_ratio": 0.911620294599018 + }, + "holdout": { + "episodes": 1, + "mean_return": -139.8, + "std_return": 0.0, + "asset_survival_rate": 0.0, + "containment_success_rate": 0.0, + "mean_final_burned_area": 557.0, + "mean_time_to_containment": 150.0, + "mean_resource_efficiency": 19.892857142857142, + "variance_across_episodes": 0.0, + "mean_normalized_burn_ratio": 0.911620294599018 + } + }, + "random": { + "train": { + "episodes": 1, + "mean_return": -297.3999999999999, + "std_return": 0.0, + "asset_survival_rate": 0.0, + "containment_success_rate": 0.0, + "mean_final_burned_area": 595.0, + "mean_time_to_containment": 150.0, + "mean_resource_efficiency": 21.25, + "variance_across_episodes": 0.0, + "mean_normalized_burn_ratio": 0.9738134206219312 + }, + "val": { + "episodes": 1, + "mean_return": -257.60000000000014, + "std_return": 0.0, + "asset_survival_rate": 0.0, + "containment_success_rate": 0.0, + "mean_final_burned_area": 581.0, + "mean_time_to_containment": 150.0, + "mean_resource_efficiency": 20.75, + "variance_across_episodes": 0.0, + "mean_normalized_burn_ratio": 0.9509001636661211 + }, + "holdout": { + "episodes": 1, + "mean_return": -266.00000000000006, + "std_return": 0.0, + "asset_survival_rate": 0.0, + "containment_success_rate": 0.0, + "mean_final_burned_area": 578.0, + "mean_time_to_containment": 150.0, + "mean_resource_efficiency": 20.642857142857142, + "variance_across_episodes": 0.0, + "mean_normalized_burn_ratio": 0.9459901800327333 + } + } +} \ No newline at end of file diff --git a/.ci-smoke/scenario_parameter_records_seeded_holdout.json b/.ci-smoke/scenario_parameter_records_seeded_holdout.json new file mode 100644 index 0000000..071c966 --- /dev/null +++ b/.ci-smoke/scenario_parameter_records_seeded_holdout.json @@ -0,0 +1 @@ +{"schema_version": 3, "split": "holdout", "record_count": 1, "records": [{"record_id": "ci-holdout", "base_spread_prob": 0.14, "severity_bucket": "medium", "wind_direction": "E", "wind_strength": 0.35, "ignition_seed": 101, "layout_seed": 202, "split": "holdout"}]} \ No newline at end of file diff --git a/.ci-smoke/scenario_parameter_records_seeded_train.json b/.ci-smoke/scenario_parameter_records_seeded_train.json new file mode 100644 index 0000000..cfd8ce8 --- /dev/null +++ b/.ci-smoke/scenario_parameter_records_seeded_train.json @@ -0,0 +1 @@ +{"schema_version": 3, "split": "train", "record_count": 1, "records": [{"record_id": "ci-train", "base_spread_prob": 0.14, "severity_bucket": "medium", "wind_direction": "E", "wind_strength": 0.35, "ignition_seed": 101, "layout_seed": 202, "split": "train"}]} \ No newline at end of file diff --git a/.ci-smoke/scenario_parameter_records_seeded_val.json b/.ci-smoke/scenario_parameter_records_seeded_val.json new file mode 100644 index 0000000..1854fea --- /dev/null +++ b/.ci-smoke/scenario_parameter_records_seeded_val.json @@ -0,0 +1 @@ +{"schema_version": 3, "split": "val", "record_count": 1, "records": [{"record_id": "ci-val", "base_spread_prob": 0.14, "severity_bucket": "medium", "wind_direction": "E", "wind_strength": 0.35, "ignition_seed": 101, "layout_seed": 202, "split": "val"}]} \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 918d656..97a9266 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,14 +19,45 @@ jobs: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v4 - run: uv sync - - name: Verify env imports and runs + - name: Build tiny seeded split datasets run: | uv run python -c " - from src.models.fire_env import WildfireEnv - env = WildfireEnv() - obs, _ = env.reset(seed=42) - assert obs.shape == (636,) - for _ in range(10): - obs, r, done, trunc, info = env.step(env.action_space.sample()) - print('smoke test passed') + import json + from pathlib import Path + + out = Path('.ci-smoke') + out.mkdir(exist_ok=True) + + base = { + 'record_id': 'ci-record', + 'base_spread_prob': 0.14, + 'severity_bucket': 'medium', + 'wind_direction': 'E', + 'wind_strength': 0.35, + 'ignition_seed': 101, + 'layout_seed': 202, + } + + for split in ('train', 'val', 'holdout'): + payload = { + 'schema_version': 3, + 'split': split, + 'record_count': 1, + 'records': [{**base, 'record_id': f'ci-{split}', 'split': split}], + } + (out / f'scenario_parameter_records_seeded_{split}.json').write_text( + json.dumps(payload) + ) + + print('tiny seeded datasets ready') " + - name: Tiny evaluator smoke test + run: | + uv run python -m src.models.evaluate_agents \ + --agents greedy,random \ + --episodes 1 \ + --seeds 42 \ + --train-dataset .ci-smoke/scenario_parameter_records_seeded_train.json \ + --val-dataset .ci-smoke/scenario_parameter_records_seeded_val.json \ + --holdout-dataset .ci-smoke/scenario_parameter_records_seeded_holdout.json \ + --output .ci-smoke/results.json diff --git a/README.md b/README.md index a0e9d16..b80e33a 100644 --- a/README.md +++ b/README.md @@ -44,8 +44,11 @@ We build the static dataset at `src/ingestion/static_dataset.py`. The script: - computes offline environment variables and writes `scenario_parameter_records.json` plus split files in `data/static`. The environment variables written are: - `base_spread_prob` - `severity_bucket` - - `wind_dir_deg` + - `wind_direction` (8-direction string) - `wind_strength` + - `ignition_seed` + - `layout_seed` +- writes seeded benchmark variants (`scenario_parameter_records_seeded.json` and `scenario_parameter_records_seeded_{train|val|holdout}.json`) for reproducible initialization; holdout seeded export is currently a single unique held-out record. - With the following extra fields stored: - `spread_rate_1h_m` - `spread_score` @@ -93,6 +96,8 @@ We run this command run to ingest our dataset (with a large cap to avoid split t uv run python -m src.ingestion.static_dataset --target-count 50000 --raw-alberta-csv data/static/fp-historical-wildfire-data-2006-2025.csv ``` +This command builds data from the CSV and generates initialization seeds for ignition and asset layout for the corresponding environment. CFFDRS was not used to reduce confounding variables and any bias introduced due to incomplete CFFDRS data ingested for some specific fires. + If CFFDRS for the selected year is sparse, the builder still runs and writes records without supplementary CFFDRS enrichment. Optionally, test with a smaller target count: @@ -112,10 +117,10 @@ uv run python -m src.ingestion.static_dataset --fire-records path/to/fire_record After building the dataset, you can train by running: ```bash -uv run python -m src.models.train_rl_agent --scenario-dataset data/static/scenario_parameter_records_train.json --val-dataset data/static/scenario_parameter_records_val.json --holdout-dataset data/static/scenario_parameter_records_holdout.json +uv run python -m src.models.train_rl_agent --scenario-dataset data/static/scenario_parameter_records_seeded_train.json --val-dataset data/static/scenario_parameter_records_seeded_val.json --holdout-dataset data/static/scenario_parameter_records_seeded_holdout.json ``` -The scenario parameter file can then be consumed by `FireEnv` and PPO training. +The seeded scenario parameter files are the canonical benchmark inputs for `FireEnv` and PPO training. The builder also writes year-based split files for the benchmark: @@ -126,13 +131,13 @@ The builder also writes year-based split files for the benchmark: Training command: ```bash -uv run python -m src.models.train_rl_agent --scenario-dataset data/static/scenario_parameter_records_train.json --val-dataset data/static/scenario_parameter_records_val.json --holdout-dataset data/static/scenario_parameter_records_holdout.json +uv run python -m src.models.train_rl_agent --scenario-dataset data/static/scenario_parameter_records_seeded_train.json --val-dataset data/static/scenario_parameter_records_seeded_val.json --holdout-dataset data/static/scenario_parameter_records_seeded_holdout.json ``` General split benchmark evaluation (PPO + baselines): ```bash -uv run python -m src.models.evaluate_agents --agents ppo,greedy,random --train-dataset data/static/scenario_parameter_records_train.json --val-dataset data/static/scenario_parameter_records_val.json --holdout-dataset data/static/scenario_parameter_records_holdout.json --episodes 20 --seeds 42,43,44 +uv run python -m src.models.evaluate_agents --agents ppo,greedy,random --train-dataset data/static/scenario_parameter_records_seeded_train.json --val-dataset data/static/scenario_parameter_records_seeded_val.json --holdout-dataset data/static/scenario_parameter_records_seeded_holdout.json --episodes 20 --seeds 42,43,44 ``` The dataset builder prints cleaning/drop summaries to stdout and uses progress bars when `tqdm` is available. diff --git a/docs/data-pipeline.md b/docs/data-pipeline.md index 17da66a..147e8bf 100644 --- a/docs/data-pipeline.md +++ b/docs/data-pipeline.md @@ -1,6 +1,6 @@ # Data Pipeline -This document describes the current benchmark data pipeline after the move away from live CWFIS-centered ingestion and XGBoost. +This document describes the current benchmark data pipeline for building frozen `FireEnv` datasets. The canonical path now uses the Alberta historical wildfire dataset stored under `data/static/` as the primary source for building `FireEnv` scenario records. @@ -13,13 +13,12 @@ The benchmark pipeline has two stages: 1. normalize historical wildfire incidents into frozen snapshot records 2. compute offline environment-variable records for `FireEnv` -Training and evaluation should then use only the cached parameter dataset plus seeded RNG. +Downstream benchmark consumers should use the seeded split parameter datasets (`scenario_parameter_records_seeded_{train|val|holdout}.json`) as frozen runtime inputs. Primary source hierarchy: - primary: Alberta historical wildfire dataset -- supplementary: CFFDRS fire-danger indices when an annual station file is available and usable -- non-canonical: CWFIS live active fires and FIRMS hotspots +- supplementary: CFFDRS fire-danger indices when an annual station file is available and date-matchable --- @@ -65,30 +64,11 @@ This module downloads annual CWFIS weather-station CSV data and parses: Current role: - supplementary enrichment only -- if `--cffdrs-year` is passed and usable observations exist, the builder joins the nearest station by both distance and snapshot date -- the benchmark no longer depends on CFFDRS being available to build records +- if `--cffdrs-year` is passed and usable observations exist, the builder joins the nearest station by distance, with date alignment to each fire snapshot (`max_date_offset_days=1`) +- the benchmark does not require CFFDRS availability to build records +- practical implication: one run fetches a single annual station file, so date-matched enrichment is usually concentrated in that selected year -### 2.3 `src/ingestion/cwfis.py` - -This module still downloads live active fires from CWFIS. - -Current role: - -- legacy / non-canonical -- useful for live experiments or future non-Alberta extensions -- not part of the canonical Alberta historical benchmark build - -### 2.4 `src/ingestion/firms.py` - -This module still fetches NASA FIRMS hotspots. - -Current role: - -- supplementary / non-canonical -- not used in the canonical Alberta historical benchmark build -- may still be useful for exploratory validation or future data discovery - -### 2.5 `src/ingestion/weather.py` +### 2.3 `src/ingestion/weather.py` This module fetches current-hour weather from Open-Meteo. @@ -107,12 +87,12 @@ Alberta historical wildfire CSV -> optional CFFDRS date-and-distance enrichment -> snapshot_records.json -> offline env-variable builder --> scenario_parameter_records.json +-> scenario_parameter_records.json (unseeded build artifact) +-> scenario_parameter_records_seeded_{split}.json (benchmark runtime artifact) -> FireEnv reset sampling --> RL train/eval from cached records only ``` -This path does not use FIRMS or Open-Meteo in the canonical benchmark build. +This path does not use FIRMS/CWFIS or Open-Meteo in the canonical benchmark build. The builder logs cleaning and drop diagnostics directly to stdout (with progress bars if `tqdm` is available) instead of writing a separate report artifact. @@ -139,7 +119,17 @@ Row-level cleaning behavior: - `ASSESSMENT_HECTARES` - `CURRENT_SIZE` -Normalization-time vetting in `src/ingestion/static_dataset.py` additionally drops rows that fail parsing or mapping, such as invalid datetimes, non-numeric required values, and unresolved wind direction values. +Normalization-time vetting in `src/ingestion/static_dataset.py` additionally drops rows that fail parsing or mapping, such as invalid datetimes, non-numeric required values, unresolved wind direction values, or years outside the frozen split strategy. + +### 3.2 Candidate selection and truncation behavior + +After normalization, candidate fires are selected in this order: + +- deduplicate by `fire_id` (keep the first occurrence encountered for each fire id) +- rank remaining candidates by descending `(observed_spread_rate_m_min, assessment_hectares/area_hectares, year, fire_id)` +- apply `--target-count` as a per-split cap (`train`, `val`, `holdout`) + +This means `--target-count 100` exports up to `100` records per split, not `100` total. Current drop diagnostics printed to stdout include: @@ -168,8 +158,16 @@ Generated split files: - `data/static/scenario_parameter_records_train.json` - `data/static/scenario_parameter_records_val.json` - `data/static/scenario_parameter_records_holdout.json` +- `data/static/scenario_parameter_records_seeded.json` +- `data/static/scenario_parameter_records_seeded_train.json` +- `data/static/scenario_parameter_records_seeded_val.json` +- `data/static/scenario_parameter_records_seeded_holdout.json` + +Seeded parameter files include deterministic `ignition_seed` and `layout_seed` for reproducible environment initialization. +For the current benchmark setup, `scenario_parameter_records_seeded_holdout.json` is intentionally reduced to one unique held-out record. +In benchmark mode, `FireEnv` expects these seed fields to be present on all loaded records. -Each snapshot record represents one Alberta wildfire incident anchored at the initial assessment time. +Each snapshot record represents one selected Alberta wildfire incident row after deduplication/ranking. Core stored fields: @@ -183,25 +181,46 @@ Core stored fields: - response timing metadata: `detection_delay_h`, `report_delay_h`, `dispatch_delay_h`, `ia_travel_delay_h` - optional supplementary enrichment: `fwi`, `isi`, `bui`, `dc`, `dmc`, `ffmc`, station metadata +Additional metadata currently written: + +- split and lifecycle metadata: `split`, `status`, `record_quality_flag`, `snapshot_generated_at` +- CFFDRS alignment metadata: `cffdrs_station_distance_km`, `cffdrs_station_id`, `cffdrs_station_name`, `cffdrs_observation_date`, `cffdrs_date_offset_days`, `temporal_alignment_status` + Important notes: - `snapshot_date` is anchored to `ASSESSMENT_DATETIME` - `area_hectares` prefers `ASSESSMENT_HECTARES`, with `CURRENT_SIZE` as fallback - `precipitation_mm` is estimated from `WEATHER_CONDITIONS_OVER_FIRE` -- CFFDRS fields may be `null` if supplementary enrichment is unavailable +- CFFDRS fields may be `null` if supplementary enrichment is unavailable or no station/date match is found +- `temporal_alignment_status` is one of: `aligned`, `near_aligned`, `not_joined` + +Top-level JSON payload shape for output files: + +- `schema_version` +- `generated_at` +- `record_count` +- `records` --- ## 5) Environment-Variable Builder -The builder computes `data/static/scenario_parameter_records.json` from each snapshot record. +The builder computes `data/static/scenario_parameter_records.json` from each snapshot record, then writes seeded benchmark variants in `data/static/scenario_parameter_records_seeded*.json`. Canonical env-facing fields: - `base_spread_prob` - `severity_bucket` -- `wind_dir_deg` +- `wind_direction` (8-direction string: `N`, `NE`, `E`, `SE`, `S`, `SW`, `W`, `NW`) - `wind_strength` +- `ignition_seed` +- `layout_seed` + +Canonical integration note: + +- ignition family and asset layout remain simulator-side controls +- seeded parameter records do not store explicit ignition/layout labels +- `ignition_seed` and `layout_seed` make those simulator-side initializations reproducible Stored audit fields: @@ -237,11 +256,19 @@ This is not a full Rothermel implementation. It is a benchmark-oriented, physics | Stored env field | Source fields | Builder logic | Used by environment | |---|---|---|---| | `base_spread_prob` | `observed_spread_rate_m_min`, weather, size, optional CFFDRS dryness, `fire_type`, `fuel_type` | derived from blended `spread_score` | primary spread probability in `_spread_fire()` | -| `severity_bucket` | same fields as `base_spread_prob` | derived from `spread_score` thresholds | severity one-hot in observation and family matching | -| `wind_dir_deg` | `wind_direction_deg` | pass-through from Alberta assessment weather | converted to `(wx, wy)` wind bias | +| `severity_bucket` | same fields as `base_spread_prob` | derived from `spread_score` thresholds | severity one-hot in observation | +| `wind_direction` | `wind_direction_deg` | mapped to 8-direction bins (`N`, `NE`, `E`, `SE`, `S`, `SW`, `W`, `NW`) | converted to `(wx, wy)` wind bias | | `wind_strength` | `wind_speed_km_h` | normalized and clipped from assessment wind speed | sets wind-bias magnitude | +| `ignition_seed` | `record_id`, `split` | deterministic stable hash | seeds ignition initialization RNG | +| `layout_seed` | `record_id`, `split` | deterministic stable hash | seeds asset-layout initialization RNG | | `spread_rate_1h_m` | `observed_spread_rate_m_min` | direct conversion to `m/hour` for audit/logging | optional logging only | +Benchmark runtime file contract: + +- canonical train/eval inputs are split-specific seeded files (`scenario_parameter_records_seeded_{split}.json`) +- benchmark loaders enforce split consistency from both filename hints and per-record `split` values +- mixed-split datasets are rejected in benchmark mode + Audit-only intermediates: | Stored audit field | Source fields | Purpose | @@ -278,6 +305,8 @@ Build with optional supplementary CFFDRS enrichment: uv run python -m src.ingestion.static_dataset --target-count 100 --cffdrs-year 2025 ``` +`--cffdrs-year` downloads one annual CFFDRS station file for that year and attempts snapshot-date alignment (within one day) for each candidate fire. + Canonical variant with CFFDRS enrichment: ```bash @@ -296,17 +325,17 @@ Override the raw Alberta CSV path if needed: uv run python -m src.ingestion.static_dataset --raw-alberta-csv path/to/fp-historical-wildfire-data.csv --target-count 100 ``` -Then train from the cached parameter file: +Write outputs to a custom directory: ```bash -uv run python -m src.models.train_rl_agent --scenario-dataset data/static/scenario_parameter_records_train.json --val-dataset data/static/scenario_parameter_records_val.json --holdout-dataset data/static/scenario_parameter_records_holdout.json +uv run python -m src.ingestion.static_dataset --output-dir path/to/output --target-count 100 ``` -Recommended benchmark training/eval uses the split files directly: +Canonical benchmark consumers should point training/evaluation envs at seeded split files, for example: -```bash -uv run python -m src.models.train_rl_agent --scenario-dataset data/static/scenario_parameter_records_train.json --val-dataset data/static/scenario_parameter_records_val.json --holdout-dataset data/static/scenario_parameter_records_holdout.json -``` +- `data/static/scenario_parameter_records_seeded_train.json` +- `data/static/scenario_parameter_records_seeded_val.json` +- `data/static/scenario_parameter_records_seeded_holdout.json` --- @@ -314,7 +343,8 @@ uv run python -m src.models.train_rl_agent --scenario-dataset data/static/scenar - Alberta historical data is Alberta-only, so the canonical benchmark is currently province-scoped rather than Canada-wide. - CFFDRS annual station files may be sparse or unavailable for some years; the builder treats them as optional. -- FIRMS and CWFIS remain available in the repo but are no longer part of the canonical benchmark build path. +- In one run, CFFDRS enrichment is sourced from a single requested annual file; historical records from other years are unlikely to date-align. +- Canonical ingestion does not use FIRMS or CWFIS live-fire modules. - The benchmark still does not use terrain rasters, perimeter replay, or a full operational spread model. -That is acceptable for the current paper because the goal is a reproducible tactical RL benchmark, not an operational wildfire decision-support system. +That is acceptable for the current paper because the goal is a reproducible tactical benchmark dataset, not an operational wildfire decision-support system. diff --git a/docs/envspec.md b/docs/envspec.md index 88257a8..2d2f69f 100644 --- a/docs/envspec.md +++ b/docs/envspec.md @@ -1,327 +1,235 @@ -# Frozen Environment Spec: Wildfire Simulator +# Fire Environment Specification (Current Implementation) -This document is the frozen canonical specification for the wildfire RL benchmark and its static scenario-parameter interface. +This document describes the currently implemented RL environment in `src/models/fire_env.py` and its training/evaluation usage in `src/models/train_rl_agent.py` and `src/models/evaluate_agents.py`. -It is aligned with `docs/planning/impl-plan.md` and is intended to remove ambiguity before coding, benchmarking, and reporting. +It is intentionally code-first and only documents behavior that exists in the current codebase. --- -## 1) What the Agent Represents +## 1) Environment Definition -The agent is a single tactical controller operating on a grid map. +`WildfireEnv` is a single-agent, discrete-action tactical wildfire suppression environment. -- It decides movement and suppression actions each step. -- It has limited suppression resources (helicopter and crew budgets). -- Its mission is not to minimize all fire everywhere; its mission is to protect critical assets under budget. - -Think of it as a simplified incident-response decision unit. - ---- - -## 2) Core Environment Definition - -## 2.1 Grid and Episode Constants (canonical) - -- Grid size: `25 x 25` +- Agent count: `1` +- Grid size: `25 x 25` (`625` cells) - Episode horizon: `150` steps -- Per-episode budgets: - - `heli_left = 8` - - `crew_left = 20` -- Cooldowns: - - helicopter cooldown: `5` steps - - crew cooldown: `2` steps +- Action space: `6` discrete actions (`MOVE_N`, `MOVE_S`, `MOVE_E`, `MOVE_W`, `DEPLOY_HELICOPTER`, `DEPLOY_CREW`) +- Fire process: one evolving fire field per episode (multiple burning cells can exist simultaneously) -## 2.2 Cell encoding +Cell encodings: - `0`: unburned - `1`: burning - `2`: burned -- `3`: suppressed (firebreak or suppressed burn) -- `4`: critical asset (unburned) -- `5`: critical asset damaged/burning (internal bookkeeping can be separate, but report this event explicitly) - ---- - -## 3) Observation (State) +- `3`: suppressed +- `4`: critical asset +- `5`: critical asset burned/lost -At each step, the policy receives: +Per-episode resource limits: -1. Fire grid (`25x25`, encoded cells above) -2. Agent position `(row, col)` -3. Remaining resources: `heli_left`, `crew_left` -4. Cooldowns: `heli_cd`, `crew_cd` -5. Severity one-hot: `[low, medium, high]` -6. Wind bias vector: `(wx, wy)` +- helicopter budget: `8`, cooldown: `5` +- crew budget: `20`, cooldown: `2` -Canonical observation rule: - -- The benchmark observation is the single encoded grid plus scalar features listed above. -- Multi-channel observation variants are allowed only as ablations or future work and must not replace the canonical benchmark interface in the main comparison. - -This state lets the policy reason about: - -- where the fire is, -- what must be protected, -- what resources are still available, -- how spread is likely to move spatially. +About "how many states": this environment has a very large combinatorial state space (not a small enumerable finite-state MDP in practice). The policy input vector shape is fixed at `636`. --- -## 4) Action Semantics (Hard Rules) - -Action categories: +## 2) How the Environment Gets and Uses Data from the Pipeline -- Mobility actions: `MOVE_N`, `MOVE_S`, `MOVE_E`, `MOVE_W` -- Intervention actions: `DEPLOY_HELICOPTER`, `DEPLOY_CREW` +Canonical benchmark mode consumes seeded split scenario records produced by the data pipeline: -Action IDs: +- `data/static/scenario_parameter_records_seeded_train.json` +- `data/static/scenario_parameter_records_seeded_val.json` +- `data/static/scenario_parameter_records_seeded_holdout.json` -- `0`: `MOVE_N` -- `1`: `MOVE_S` -- `2`: `MOVE_E` -- `3`: `MOVE_W` -- `4`: `DEPLOY_HELICOPTER` -- `5`: `DEPLOY_CREW` +Loader and integration flow: -Canonical action rule: +1. `load_scenario_parameter_records(...)` validates each record. +2. `create_benchmark_env(...)` creates `WildfireEnv` with strict benchmark settings. +3. `reset()` samples one cached record and maps it to `ScenarioConfig`. +4. The sampled record stays fixed for that episode. -- The canonical benchmark action set contains exactly these 6 actions. -- `WAIT` is not part of the frozen benchmark and may be introduced only in ablations. +Required benchmark fields: -Rules: +- `record_id`, `split` +- `base_spread_prob`, `severity_bucket` +- `wind_direction` (8-direction string), `wind_strength` +- `ignition_seed`, `layout_seed` -- Movement changes position by one cell if in bounds; otherwise no movement. -- Deployment actions act at the **current agent cell**. -- Helicopter footprint: **3x3** neighborhood centered at agent. -- Crew footprint: **1x1** current cell only. -- Both can target burning or non-burning cells. -- Burning cells in footprint become `suppressed`. -- Unburned cells in footprint become `suppressed` firebreaks. +How fields are used in env runtime: -Budget/cooldown gating: +- `base_spread_prob`: baseline spread probability +- `severity_bucket`: severity one-hot in observation +- `wind_direction` + `wind_strength`: wind-bias vector for spread +- `ignition_seed` / `layout_seed`: reproducible initialization RNGs for ignition and asset placement -- Helicopter requires `heli_left > 0` and `heli_cd == 0`. -- Crew requires `crew_left > 0` and `crew_cd == 0`. -- Successful helicopter use: `heli_left -= 1`, `heli_cd = 5`. -- Successful crew use: `crew_left -= 1`, `crew_cd = 2`. -- Cooldowns decrement by 1 each step down to 0. +Important boundary: -Wasted action definition: - -An action is wasted if either: - -1. Deployment attempted while blocked by cooldown/budget, or -2. Deployment causes zero state change in its footprint. +- ignition family and asset layout labels remain simulator-side controls +- seeded records do not store explicit ignition/layout labels +- seeds make simulator-side initialization reproducible --- -## 5) Fire Dynamics and Transition +## 3) Implementation Details (Variable Updates and Core Functions) -Each step after action execution: +Record loading and benchmark setup: -1. Apply suppression effects. -2. Spread fire stochastically from burning cells to neighbors. -3. Apply burn progression/burnout rules. -4. Update asset-loss counters when fire reaches asset cells. +- `load_scenario_parameter_records`: schema/range/split validation +- `benchmark_env_kwargs`, `create_benchmark_env`: canonical benchmark env factory +- `scenario_from_parameter_record`: maps one record into `ScenarioConfig` -Spread probability is episode-parameterized: +Episode construction and parameter sampling: -- baseline from `base_spread_prob` -- adjusted by wind bias `(wx, wy)` relative to neighbor direction +- `reset`: samples scenario family + parameter record, resets budgets/cooldowns/state, places assets, ignites fire +- `_sample_parameter_record`: seed-stable shuffled sampling over loaded records +- `_configure_initialization_rngs`: configures ignition/layout RNGs from record seeds -Canonical heterogeneity rule: +State transition internals: -- Wind bias is the only mandatory heterogeneity mechanism in canonical runs. -- Additional local modifiers such as flammability maps are ablations or future work and must be reported separately. -- Control-tick versus fire-tick cadence changes are also deferred to ablations or future work. +- `_execute_action`: movement/suppression effects + immediate action reward terms +- `_spread_fire`: wind-biased stochastic spread + asset-loss accounting + burnout +- `_ignite`: ignition pattern initialization (`center`, `edge`, `corner`, `multi_cluster`) +- `_place_assets`: layout initialization (`A`, `B`) -Episode termination: +Observation assembly: -- success if no burning cells remain, or -- horizon reached at step 150. +- `_get_obs`: builds the flat `636`-length policy input vector --- -## 6) Reward Function (Single-Objective) +## 4) Observations (What the Agent Sees) -Per-step reward: +The policy receives a single flat vector with shape `636`: -```text -r_t = - - 75.0 * asset_cells_lost_t - - 0.4 * new_burned_cells_t - + 3.0 * burning_cells_suppressed_t - - 1.5 * heli_used_t - - 0.5 * crew_used_t - - 1.0 * wasted_action_t -``` - -Terminal shaping: +1. flattened grid: `25 * 25 = 625` +2. agent position (normalized row, col): `2` +3. resources/cooldowns (normalized): `4` + - `heli_left`, `crew_left`, `heli_cd`, `crew_cd` +4. severity one-hot: `3` +5. wind bias vector: `2` (`wx`, `wy`) -- `+100` if fire extinguished and no asset loss. -- `+40` if episode ends with all assets intact. +Total: `625 + 2 + 4 + 3 + 2 = 636` -Interpretation: - -- Asset protection is dominant objective. -- Burn suppression and burn growth provide dense learning signal. -- Resource costs prevent degenerate spam strategies. +The environment returns `obs, reward, terminated, truncated, info` at each step. --- -## 7) Static Scenario Parameter Interface - -The benchmark uses cached scenario records with environment variables computed offline before training and evaluation. These variables are not inferred live during benchmark runs. - -## 7.1 Snapshot inputs used during preprocessing - -Required canonical fields available to the preprocessing pipeline: +## 5) Fire Dynamics and Transition -- weather: `wind_speed_km_h`, `wind_direction_deg`, `temperature_c`, `relative_humidity_pct`, `precipitation_mm` -- danger: `fwi`, `isi`, `bui` -- incident: `area_hectares`, `latitude`, `longitude`, `province` +Per-step order: -Optional retained metadata: +1. decrement cooldowns +2. execute action +3. apply resource-use penalties (`-1.5` heli, `-0.5` crew when used) +4. advance spread via `_spread_fire` +5. apply asset-loss and burn-growth penalties +6. evaluate termination/truncation and terminal bonuses -- `frp_mw` -- `cffdrs_station_distance_km` -- `dmc`, `dc`, `ffmc` +Spread rule: -## 7.2 Stored parameter record contract +- base spread: `base_spread_prob` +- wind adjustment: `base + 0.15 * wind_dot` +- clipped spread probability: `[0.01, 0.95]` +- neighborhood: 4-connected (`N/S/E/W`) +- burning cells have `0.05` burnout chance each step -For each cached scenario record, store: +Wind handling: -1. `base_spread_prob` -2. `severity_bucket` -3. `wind_dir_deg` -4. `wind_strength` -5. optional logging fields such as `spread_rate_1h_m` +- `wind_direction` is discrete 8-direction (`N`, `NE`, `E`, `SE`, `S`, `SW`, `W`, `NW`) +- each direction maps to a unit/directional vector, scaled by `wind_strength` -Deterministic env mapping: +Termination: -- severity is encoded one-hot in the observation from `severity_bucket` -- wind vector: - - `wx = wind_strength * cos(wind_dir_deg)` - - `wy = wind_strength * sin(wind_dir_deg)` -- `base_spread_prob` is consumed directly by the environment spread rule - -Episode rule: - -- Sample one cached parameter record at reset. -- Keep it fixed for the full episode in canonical runs. +- `terminated=True` when no burning cells remain +- `truncated=True` when step count reaches `150` --- -## 8) Mandatory Snapshot Pipeline for Reproducibility +## 6) The Reward Function -Training/evaluation must never depend on live API calls. +Reward is assembled from multiple terms in `step()` and `_execute_action()`. -Required workflow: +Main terms: -1. Collect and normalize ingestion data. -2. Write versioned snapshot cache file(s). -3. Compute environment variables offline from snapshot records. -4. Produce cached env-parameter records. -5. Train/evaluate RL only from cached records + seeded RNG. +- `-75.0 * asset_cells_lost` +- `-0.4 * new_burned` where `new_burned = max(0, burning_now - prev_burning)` +- helicopter use cost: `-1.5` +- crew use cost: `-0.5` +- wasted/blocked deployment penalty: `-1.0` +- suppression bonuses: + - helicopter: `+3.0` per suppressed affected cell in `3x3` + - crew on burning cell: `+3.0` + - crew firebreak on unburned cell: `+2.0` -Fail-fast rule: - -- If required fields are missing in benchmark mode, error out. -- Do not silently inject hidden defaults during benchmark runs. +Terminal shaping: ---- +- `+100` if fire is extinguished and assets lost is `0` +- `+40` if episode ends (terminated or truncated) with assets lost still `0` -## 9) Training Process for the Agent +Optimization intent: -## 9.1 Algorithms +- primary objective is minimizing asset damage/loss +- other terms provide dense tactical shaping for learning stability -- DQN -- A2C -- PPO -- Baselines: greedy, random +--- -## 9.2 Fixed protocol +## 7) Training Process (Current Code) -- Train steps: `200,000` per algorithm per seed -- Seeds: `11, 22, 33, 44, 55` -- Eval cadence: every `20,000` steps -- Eval episodes/checkpoint: `20` -- Final eval episodes per seed: `100` +Current training script: -## 9.3 Scenario families +- `src/models/train_rl_agent.py` +- algorithm currently implemented in this script: `PPO` (Stable-Baselines3) -Asset layout definitions: +Canonical training flow: -- Layout `A`: one dense high-value asset cluster placed near moderate exposure to common ignition zones. -- Layout `B`: two smaller separated asset clusters with different distances from common ignition zones. +1. load seeded train split dataset +2. create vectorized benchmark envs (`n_envs`) +3. train PPO for configured timesteps +4. save model to `src/models/tactical_ppo_agent.zip` +5. run quick evaluation on train and optional val/holdout datasets -Train families: +Current benchmark evaluation script: -- ignition in `{center, edge, multi_cluster}` -- severity in `{low, medium, high}` -- asset layout `A` +- `src/models/evaluate_agents.py` +- evaluates agents across splits (`train`, `val`, `holdout`) +- supported evaluated agents: `ppo`, `greedy`, `random` +- can output JSON summary via `--output` -Held-out families: +Transparency outputs from current code: -- ignition `corner` x all severities x layout `A` -- ignition in `{center, edge, multi_cluster}` x severity `medium` x layout `B` +- training console output: timesteps, env count, dataset path/count, quick split metrics +- model artifact: `tactical_ppo_agent.zip` +- evaluation console summary per split/agent +- optional evaluation JSON with aggregate metrics -## 9.4 Training loop (conceptual) +Recommended transparency plots (from saved eval JSON/logs): -```text -for seed in [11,22,33,44,55]: - set global RNG seed - for algorithm in [DQN, A2C, PPO]: - init agent + env factory - for step in training_steps: - collect transitions - update policy/value per algorithm - if step % eval_interval == 0: - evaluate on fixed eval set (no fallback) - log metrics - run final 100-episode evaluation -aggregate results across seeds -compare against greedy and random baselines -``` +- split-wise mean return (`train` vs `val` vs `holdout`) +- split-wise asset survival and containment rates +- final burned area distribution by split/agent +- seed variability/error bars for key metrics --- -## 10) What to Report - -Primary metrics: - -1. mean episodic return -2. asset survival rate -3. containment success rate -4. final burned area -5. variance across seeds +## 8) Reporting Metrics -Secondary metrics: +Primary optimization target/what the agent is trained to do: **Minimize assets damaged/lost** -- time to containment -- resource efficiency -- wasted deployment rate -- held-out performance drop -- normalized burn ratio +Additional reported metrics (already computed or directly derivable from current eval): -Normalized burn ratio definition: +- mean episodic return +- standard deviation / variance across episodes +- asset survival rate +- containment success rate +- mean final burned area +- mean time to containment +- mean resource efficiency +- mean normalized burn ratio (optional in evaluator) -- `normalized_burn_ratio = final_burned_area_with_policy / final_burned_area_no_action_same_scenario` -- For each evaluation scenario, run a no-action baseline with the same initial scenario record and RNG seed. -- Report this as an evaluation-only metric; do not include it in the training reward. - -Interpretability checks: - -- which assets are protected first, -- deployment timing under low vs high severity, -- behavior under held-out corner ignition scenarios. - ---- +Report these diagnostics during training: -## 11) Minimal Sanity Checklist Before Full Runs +- train/val/holdout gap for each metric +- per-seed summary tables +- baseline comparisons (`greedy`, `random`) against PPO -1. Reward sanity pass (20k PPO steps, 1 seed). -2. Confirm non-zero asset-loss and suppression events in logs. -3. Confirm budgets/cooldowns are enforced. -4. Confirm no live API access during train/eval. -5. Confirm held-out scenario IDs are excluded from training. diff --git a/docs/planning/env-checklist.md b/docs/planning/env-checklist.md index 85b5fcf..2f53477 100644 --- a/docs/planning/env-checklist.md +++ b/docs/planning/env-checklist.md @@ -1,86 +1,204 @@ -# Environment Checklist +# Environment Changelog (Pipeline Alignment) -This checklist captures the remaining environment changes needed to align `src/models/fire_env.py` with the current benchmark design in `docs/envspec.md` and `docs/planning/impl-plan.md`. +This document records what was implemented for environment alignment with the frozen static data pipeline. + +Primary implementation files: +- `src/models/fire_env.py` +- `src/models/train_rl_agent.py` +- `src/models/evaluate_agents.py` +- `tests/models/test_fire_env_setup_contract.py` --- -## 1) Replace remaining heuristic-only environment defaults +## 1) Frozen scenario records became the canonical setup path + +- [x] Add an explicit benchmark/frozen mode flag in `WildfireEnv`. + - Implemented in `WildfireEnv.__init__` via `benchmark_mode` (`src/models/fire_env.py`). + - Benchmark mode is now the canonical constructor path for train/eval. +- [x] In benchmark mode, require `scenario_parameter_records` and fail fast if missing or empty. + - Added hard checks in `WildfireEnv.__init__` that raise `ValueError` when records are missing. + - Added additional reset-time hard guard for impossible benchmark state. +- [x] Remove silent fallback to `random_scenario()` in benchmark mode. + - `WildfireEnv.reset` now raises in benchmark mode if no cached records are available. + - Silent fallback to `random_scenario()` remains available only when `benchmark_mode=False`. +- [x] Keep heuristic/random scenario generation as explicit dev/ablation mode only. + - Dev path kept through `random_scenario()` and non-benchmark env setup. + - Error messages explicitly state benchmark mode cannot use dev fallback paths. +- [x] Keep or remove `base_spread_rate_m_per_min` legacy path intentionally. + - Legacy spread-rate path retained but explicitly dev-only. + - Benchmark mode rejects `base_spread_rate_m_per_min`. + - Training script now requires explicit `--allow-legacy-dev-fallback` for this path. + +Changed functions/files: +- `WildfireEnv.__init__`, `WildfireEnv.reset`, `random_scenario` in `src/models/fire_env.py` +- `train` and CLI args in `src/models/train_rl_agent.py` +- benchmark env creation usage in `src/models/evaluate_agents.py` -- [ ] Remove the old assumption that severity alone determines spread through `SEVERITY_SPREAD_PROB` when a cached `base_spread_prob` record is available. -- [ ] Make the canonical path use static scenario parameter records first, with severity acting as reporting/observation metadata rather than the main spread heuristic. -- [ ] Keep severity heuristics only as a fallback dev mode, not as the benchmark-default path. -- [ ] Decide whether canonical benchmark mode should hard-fail if no cached parameter records are provided. +--- -## 2) Make reset-time episode construction more dataset-driven +## 2) Scenario-record schema validation was enforced + +- [x] Validate required fields on load (`record_id`, `split`, `base_spread_prob`, `severity_bucket`, `wind_direction`, `wind_strength`, `ignition_seed`, `layout_seed` in benchmark mode). + - Added required-field checks in `load_scenario_parameter_records`. + - Normalizes validated records (typed numeric fields, lowercase enum/split). +- [x] Validate value domains (severity enum, numeric ranges, finite floats). + - Added severity and split enum checks. + - Added finite/range checks for spread and wind numeric values. + - Added optional seed validation for `ignition_seed` and `layout_seed`. +- [x] Fail in benchmark mode; warn-and-skip in dev mode. + - Benchmark mode raises actionable `ValueError` with sampled invalid rows. + - Dev mode logs warnings and skips invalid records. +- [x] Hard-reject missing/invalid `split` in benchmark mode. + - Decision implemented at loader and env-construction levels. + +Changed functions/files: +- `load_scenario_parameter_records` in `src/models/fire_env.py` +- benchmark load call sites in `src/models/train_rl_agent.py` and `src/models/evaluate_agents.py` -- [ ] Decide what belongs in the cached scenario record versus what remains randomized inside the simulator. -- [ ] If desired, extend cached records to include reset-time metadata such as ignition family, asset layout, and optional size/dryness tags. -- [ ] Stop sampling scenario families and parameter records independently if that can create inconsistent pairings. -- [ ] Replace the current severity-only record matching with a stronger record-selection rule tied to the frozen train/held-out split. +--- -## 3) Freeze the canonical benchmark mode more strictly +## 3) Reset-time selection moved from severity matching to record sampling -- [ ] Add an explicit benchmark mode flag to `FireEnv` so train/eval runs cannot silently fall back to ad hoc random scenario generation. -- [ ] In benchmark mode, fail fast on missing or malformed parameter records. -- [ ] Keep legacy `base_spread_rate_m_per_min` support only for backward compatibility or remove it entirely once the static dataset path is stable. -- [ ] Ensure benchmark mode never depends on runtime live ingestion. +- [x] Stop selecting cached records by `severity_bucket` only. + - Removed severity-only filtering from `WildfireEnv.reset`. + - Reset now samples from full validated records. +- [x] Use deterministic/seed-stable record sampling. + - Added `_sample_parameter_record` with shuffled index order and cursor. + - Sampling uses env RNG and is reshuffled on seed-driven resets. +- [x] Keep `severity_bucket` as metadata, not primary selector. + - `severity_bucket` is consumed from selected record in `scenario_from_parameter_record`. + - It now affects scenario state only through the selected record. +- [x] Track sampled `record_id` per episode. + - Added `_active_record_id` and included `record_id` in reset and step `info`. -## 4) Align environment parameters with the static dataset builder +Changed functions/files: +- `WildfireEnv.reset`, `_sample_parameter_record`, `scenario_from_parameter_record` in `src/models/fire_env.py` -- [ ] Confirm the cached parameter schema used by `FireEnv` matches the output of `src/ingestion/static_dataset.py`. -- [ ] Use `base_spread_prob`, `wind_dir_deg`, and `wind_strength` directly from cached records in the canonical path. -- [ ] Keep extra builder fields such as `spread_rate_1h_m`, `spread_score`, `dryness_score`, and `record_quality_flag` for logging/debugging only unless promoted into the canonical env contract. -- [ ] Decide whether `severity_bucket` should be fully precomputed offline rather than inferred from old hard-coded spread heuristics. +--- -## 5) Tighten the reward and transition accounting +## 4) Environment variable ingestion aligned with pipeline contract -- [ ] Check whether `new_burned` is currently measuring the intended quantity; it now tracks the change in burning cells rather than newly burned cells strictly. -- [ ] Verify the reward matches the frozen coefficients and intended semantics in `docs/envspec.md`. -- [ ] Confirm wasted-action logic matches the benchmark wording for blocked and zero-effect deployments. -- [ ] Confirm asset-loss accounting is correct when assets transition into burning cells. +- [x] Canonical runtime uses cached `base_spread_prob`, `wind_direction` (8-direction string), `wind_strength`, and `severity_bucket` directly. + - `scenario_from_parameter_record` now maps required cached fields directly into `ScenarioConfig`. + - Removed permissive fallback defaults for canonical record mapping. +- [x] Keep audit fields as logging/debug unless promoted. + - Added `PARAMETER_AUDIT_FIELDS` and exposed them via `parameter_audit` in `info`. + - Audit fields are not used in transition dynamics. +- [x] Surface selected metadata in episode info. + - Added `PARAMETER_METADATA_FIELDS` and surfaced `parameter_record_meta`. + - `record_id` and `split` are included directly in reset and step `info`. +- [x] Ensure optional CFFDRS fields do not imply runtime fetches. + - Benchmark runtime only consumes cached records. + - Added explicit module-level note that benchmark env does not fetch FIRMS/CWFIS/Open-Meteo/CFFDRS at runtime. -## 6) Decide what remains randomized inside the simulator +Changed functions/files: +- `scenario_from_parameter_record`, `WildfireEnv.reset`, `WildfireEnv.step`, `_parameter_metadata`, `_parameter_audit` in `src/models/fire_env.py` -- [ ] Keep ignition coordinates randomized within a frozen family if that is the intended benchmark design. -- [ ] Keep asset coordinates randomized within layout `A` and `B` if that is the intended benchmark design. -- [ ] If more reproducibility is needed, precompute reset seeds or exact placements in the cached scenario dataset. -- [ ] Document clearly that the benchmark is a fixed environment family with randomized episode instances, not one single fixed map. +--- -## 7) Improve scenario-family integration +## 5) Train/val/holdout setup semantics were tightened + +- [x] Standardize env construction by split. + - Train path uses expected split `train`. + - Eval path uses expected split `train`/`val`/`holdout` per dataset. +- [x] Add internal guardrails against split mixing. + - Added split checks in loader and `WildfireEnv.__init__`. + - Benchmark mode rejects mixed-split records unless a single split is enforced. +- [x] Decision: trust both dataset filenames and record `split` values. + - Added `_split_hint_from_path` and cross-check logic against `expected_split` and record payload. +- [x] Add split-consistency checks at env creation. + - Added `expected_split` to env constructor. + - Constructor validates record splits and fails fast in benchmark mode. + - Canonical train/eval defaults now use seeded split artifacts (`scenario_parameter_records_seeded_{split}.json`). + +Changed functions/files: +- `_split_hint_from_path`, `load_scenario_parameter_records`, `WildfireEnv.__init__` in `src/models/fire_env.py` +- `_load_split_records`, `_evaluate_agent_on_split` in `src/models/evaluate_agents.py` +- train/eval load paths in `src/models/train_rl_agent.py` -- [ ] Ensure train families and held-out families are sampled exactly as frozen in `docs/planning/impl-plan.md`. -- [ ] Prevent held-out family leakage during training. -- [ ] Consider storing a `split` field or family tag directly in cached scenario records. -- [ ] Confirm layout `A` and `B` generation in code really matches the written definitions. +--- + +## 6) Stale live-ingestion assumptions were removed from env setup -## 8) Clean up legacy or transitional code paths +- [x] Canonical env no longer assumes FIRMS/CWFIS live-fire ingestion. + - Clarified benchmark runtime behavior in `fire_env` module docstring. +- [x] Canonical env no longer depends on Open-Meteo/runtime CFFDRS fetches. + - Benchmark runtime now strictly requires frozen records unless explicit dev fallback is enabled. +- [x] Spread/weather features treated as precomputed offline inputs. + - Training canonical path requires scenario dataset. + - Legacy spread-rate path requires explicit `--allow-legacy-dev-fallback`. -- [ ] Audit whether `random_scenario()` should remain part of the canonical benchmark path or only support smoke tests and ablations. -- [ ] Remove or isolate older spread-rate override code once the static parameter dataset path is fully working. -- [ ] Clarify whether `ScenarioConfig` should remain the main reset object or become a thin wrapper over cached parameter records. -- [ ] Remove comments or naming that still imply the older XGBoost-centered flow. +Changed functions/files: +- module docs and benchmark checks in `src/models/fire_env.py` +- fallback gating in `src/models/train_rl_agent.py` -## 9) Add validation and tests +--- -- [ ] Add tests for loading cached scenario parameter records. -- [ ] Add tests that benchmark-mode reset uses cached parameters and does not fall back silently. -- [ ] Add tests that observation shape remains stable at `636` unless the observation contract intentionally changes. -- [ ] Add tests that severity one-hot and wind bias in the observation match the active cached parameter record. -- [ ] Add tests that train/held-out family filtering works as intended. +## 7) Targeted tests were added for the setup contract -## 10) Nice-to-have improvements after canonical alignment +- [x] Add tests for schema validation. + - Added tests for missing required fields and invalid numeric ranges. +- [x] Add tests for benchmark fail-fast behavior. + - Added test ensuring env creation fails in benchmark mode without records. +- [x] Add tests that active scenario parameters match selected record. + - Added test asserting severity/wind/spread values match cached record. +- [x] Add tests for split isolation. + - Added tests for loader expected split mismatch, filename hint mismatch, and env split mismatch. +- [x] Add tests for `record_id` and split metadata in `info`. + - Added reset/step info assertions for `record_id`, `split`, metadata, and audit payloads. -- [ ] Add richer info logging so each episode returns the active `record_id` and scenario family tags. -- [ ] Add optional per-record diagnostics for spread calibration sanity checks. -- [ ] Consider adding a cached `ignition_seed` or `layout_seed` field for exact episode replay. -- [ ] Add a dedicated benchmark env factory that always builds the environment from frozen cached records. +Changed functions/files: +- New suite in `tests/models/test_fire_env_setup_contract.py` +- Test path bootstrap in `tests/conftest.py` --- -## Suggested order +## 8) Cleanup and naming consistency were completed + +- [x] Remove/rename heuristic-first identifiers. + - Renamed `SEVERITY_SPREAD_PROB` to `LEGACY_SEVERITY_SPREAD_PROB`. + - Updated comments/docstrings to mark dev/ablation semantics. +- [x] Keep dev and benchmark paths clearly separated. + - Benchmark mode stays strict. + - Dev fallback remains explicit and opt-in in training. +- [x] Add a single benchmark env factory/helper. + - Added `benchmark_env_kwargs` and `create_benchmark_env`. + - Updated train/eval code paths to use this centralized helper. + +Changed functions/files: +- constants and helpers in `src/models/fire_env.py` +- helper adoption in `src/models/train_rl_agent.py` and `src/models/evaluate_agents.py` + +--- + +## 9) Final decision: keep ignition/layout simulator-side, add replay seeds + +- [x] Decide on ignition controls in dataset. + - Decision: do not move ignition controls into cached dataset schema now. +- [x] Decide on asset layout controls in dataset. + - Decision: do not move asset layout controls into cached dataset schema now. +- [x] If moved, define dataset fields and reset logic. + - Not applied in this phase due to ROI decision. +- [x] Document simulator-side ignition/layout controls. + - Added reset-time comment clarifying that ignition/layout remain simulator-side. +- [x] Add optional `ignition_seed` / `layout_seed` for replayability. + - Added seeded parameter artifact generation in pipeline (`scenario_parameter_records_seeded*.json`). + - Benchmark mode now requires per-record `ignition_seed` and `layout_seed`. + - Added `_stable_seed` deterministic fallback from `record_id + reset_seed` for non-benchmark/dev compatibility. + - Added `_configure_initialization_rngs` and separate ignition/layout RNGs. + - `_ignite` and `_place_assets` now use those RNGs. + - Seeds are surfaced in reset and step `info`. + - Seeded holdout artifact is intentionally reduced to a single unique held-out record for now. + +Changed functions/files: +- `_stable_seed`, `_configure_initialization_rngs`, `_ignite`, `_place_assets`, loader validation, metadata fields in `src/models/fire_env.py` +- Replayability tests in `tests/models/test_fire_env_setup_contract.py` + +--- + +## Verification status + +- `uv run python -m py_compile src/models/fire_env.py src/models/train_rl_agent.py src/models/evaluate_agents.py` +- `uv run pytest tests/models/test_fire_env_setup_contract.py` -1. Make cached parameter records the canonical reset path. -2. Remove silent fallback behavior in benchmark mode. -3. Align reward/transition accounting with the written spec. -4. Freeze family sampling and held-out split handling. -5. Add tests around cached-record loading and reset behavior. +Current targeted setup-contract test status: passing. diff --git a/docs/planning/impl-plan.md b/docs/planning/impl-plan.md index d48d1a5..00f0f4b 100644 --- a/docs/planning/impl-plan.md +++ b/docs/planning/impl-plan.md @@ -283,9 +283,11 @@ For each scenario record, store: 1. `base_spread_prob` 2. `severity_bucket` in `{low, medium, high}` -3. `wind_dir_deg` +3. `wind_direction` in `{N, NE, E, SE, S, SW, W, NW}` 4. `wind_strength` in `[0, 1]` -5. optional logging fields such as `spread_rate_1h_m` if produced during preprocessing +5. `ignition_seed` +6. `layout_seed` +7. optional logging fields such as `spread_rate_1h_m` if produced during preprocessing Episode sampling rule: diff --git a/src/ingestion/static_dataset.py b/src/ingestion/static_dataset.py index bcccb0e..f49ed1d 100644 --- a/src/ingestion/static_dataset.py +++ b/src/ingestion/static_dataset.py @@ -22,6 +22,7 @@ from collections import Counter from dataclasses import dataclass from datetime import UTC, datetime +from hashlib import blake2b from pathlib import Path from src.ingestion.clean_historical import clean_raw_historical_row_with_reason @@ -64,6 +65,8 @@ def tqdm(iterable, **_kwargs): "crown": 1.18, } +WIND_DIRECTIONS_8 = ("N", "NE", "E", "SE", "S", "SW", "W", "NW") + @dataclass class SnapshotBuildResult: @@ -130,6 +133,39 @@ def _parse_wind_direction(value: object) -> float | None: return WIND_DIR_TO_DEG.get(text.upper()) +def _wind_direction_8_from_deg(value: float) -> str: + idx = int((value % 360.0) / 45.0 + 0.5) % 8 + return WIND_DIRECTIONS_8[idx] + + +def _stable_seed(*parts: object) -> int: + payload = "|".join(str(part) for part in parts).encode("utf-8") + digest = blake2b(payload, digest_size=8).digest() + return int.from_bytes(digest, byteorder="little", signed=False) + + +def _with_initialization_seeds(record: dict) -> dict: + seeded = dict(record) + record_id = str(seeded.get("record_id") or "unknown") + split = str(seeded.get("split") or "unknown") + seeded["ignition_seed"] = _stable_seed(record_id, split, "ignition") + seeded["layout_seed"] = _stable_seed(record_id, split, "layout") + return seeded + + +def _single_unique_record(records: list[dict]) -> list[dict]: + if not records: + return [] + seen: set[str] = set() + for record in records: + record_id = str(record.get("record_id") or "") + if not record_id or record_id in seen: + continue + seen.add(record_id) + return [record] + return [] + + def _estimate_precipitation_mm(condition: str | None) -> float: if condition is None: return 0.0 @@ -225,7 +261,7 @@ def _normalize_alberta_row(row: dict) -> dict | None: spread_rate = _parse_float(cleaned.get("FIRE_SPREAD_RATE")) temp_c = _parse_float(cleaned.get("TEMPERATURE")) rh_pct = _parse_float(cleaned.get("RELATIVE_HUMIDITY")) - wind_dir_deg = _parse_wind_direction(cleaned.get("WIND_DIRECTION")) + wind_direction_deg = _parse_wind_direction(cleaned.get("WIND_DIRECTION")) wind_speed = _parse_float(cleaned.get("WIND_SPEED")) if not all([year, fire_number]) or lat is None or lon is None or assessment_dt is None: @@ -234,7 +270,7 @@ def _normalize_alberta_row(row: dict) -> dict | None: area_hectares = assessment_hectares if assessment_hectares not in (None, 0.0) else current_size if area_hectares is None or spread_rate is None or temp_c is None or rh_pct is None: return None - if wind_dir_deg is None or wind_speed is None: + if wind_direction_deg is None or wind_speed is None: return None started_at = _parse_datetime(cleaned.get("FIRE_START_DATE")) @@ -276,7 +312,7 @@ def _normalize_alberta_row(row: dict) -> dict | None: "observed_spread_rate_m_min": spread_rate, "temperature_c": temp_c, "relative_humidity_pct": rh_pct, - "wind_direction_deg": wind_dir_deg, + "wind_direction_deg": wind_direction_deg, "wind_speed_km_h": wind_speed, "precipitation_mm": _estimate_precipitation_mm(weather_over_fire), "fire_type": fire_type.lower(), @@ -504,7 +540,8 @@ def compute_environment_parameters(snapshot: dict) -> dict: """Map one snapshot record into deterministic FireEnv parameter fields.""" observed_spread = float(snapshot["observed_spread_rate_m_min"]) wind_speed = float(snapshot["wind_speed_km_h"]) - wind_dir_deg = float(snapshot["wind_direction_deg"]) + wind_direction_deg = float(snapshot["wind_direction_deg"]) + wind_direction = _wind_direction_8_from_deg(wind_direction_deg) temp_c = float(snapshot["temperature_c"]) rh_pct = float(snapshot["relative_humidity_pct"]) precip_mm = float(snapshot.get("precipitation_mm") or 0.0) @@ -577,7 +614,7 @@ def compute_environment_parameters(snapshot: dict) -> dict: "split": snapshot.get("split"), "base_spread_prob": base_spread_prob, "severity_bucket": severity_bucket, - "wind_dir_deg": round(wind_dir_deg, 2), + "wind_direction": wind_direction, "wind_strength": wind_strength, "spread_rate_1h_m": spread_rate_1h_m, "spread_score": round(spread_score, 4), @@ -649,21 +686,35 @@ def build_static_datasets( "records": snapshots, } params_payload = { - "schema_version": 2, + "schema_version": 3, "generated_at": datetime.now(UTC).isoformat(), "record_count": len(parameter_records), "records": parameter_records, } + seeded_parameter_records = [_with_initialization_seeds(record) for record in parameter_records] + seeded_params_payload = { + "schema_version": 3, + "generated_at": datetime.now(UTC).isoformat(), + "record_count": len(seeded_parameter_records), + "records": seeded_parameter_records, + } snapshot_path = output_dir / "snapshot_records.json" params_path = output_dir / "scenario_parameter_records.json" + seeded_params_path = output_dir / "scenario_parameter_records_seeded.json" snapshot_path.write_text(json.dumps(snapshot_payload, indent=2)) params_path.write_text(json.dumps(params_payload, indent=2)) + seeded_params_path.write_text(json.dumps(seeded_params_payload, indent=2)) split_names = ("train", "val", "holdout") for split_name in split_names: split_snapshots = [record for record in snapshots if record.get("split") == split_name] split_params = [record for record in parameter_records if record.get("split") == split_name] + split_seeded_params = [ + record for record in seeded_parameter_records if record.get("split") == split_name + ] + if split_name == "holdout": + split_seeded_params = _single_unique_record(split_seeded_params) (output_dir / f"snapshot_records_{split_name}.json").write_text( json.dumps( { @@ -679,7 +730,7 @@ def build_static_datasets( (output_dir / f"scenario_parameter_records_{split_name}.json").write_text( json.dumps( { - "schema_version": 2, + "schema_version": 3, "generated_at": datetime.now(UTC).isoformat(), "split": split_name, "record_count": len(split_params), @@ -688,9 +739,26 @@ def build_static_datasets( indent=2, ) ) + (output_dir / f"scenario_parameter_records_seeded_{split_name}.json").write_text( + json.dumps( + { + "schema_version": 3, + "generated_at": datetime.now(UTC).isoformat(), + "split": split_name, + "record_count": len(split_seeded_params), + "records": split_seeded_params, + }, + indent=2, + ) + ) logger.info("Wrote %s snapshot records to %s", len(snapshots), snapshot_path) logger.info("Wrote %s scenario parameter records to %s", len(parameter_records), params_path) + logger.info( + "Wrote %s seeded scenario parameter records to %s", + len(seeded_parameter_records), + seeded_params_path, + ) for split_name in split_names: logger.info( "Split %s: %s records", diff --git a/src/models/evaluate_agents.py b/src/models/evaluate_agents.py index b55076d..8e8e127 100644 --- a/src/models/evaluate_agents.py +++ b/src/models/evaluate_agents.py @@ -19,6 +19,7 @@ MOVE_S, MOVE_W, WildfireEnv, + create_benchmark_env, load_scenario_parameter_records, ) @@ -30,9 +31,9 @@ def tqdm(iterable, **_kwargs): return iterable -DEFAULT_TRAIN_DATASET = Path("data/static/scenario_parameter_records_train.json") -DEFAULT_VAL_DATASET = Path("data/static/scenario_parameter_records_val.json") -DEFAULT_HOLDOUT_DATASET = Path("data/static/scenario_parameter_records_holdout.json") +DEFAULT_TRAIN_DATASET = Path("data/static/scenario_parameter_records_seeded_train.json") +DEFAULT_VAL_DATASET = Path("data/static/scenario_parameter_records_seeded_val.json") +DEFAULT_HOLDOUT_DATASET = Path("data/static/scenario_parameter_records_seeded_holdout.json") DEFAULT_PPO_MODEL = Path("src/models/tactical_ppo_agent.zip") @@ -135,12 +136,19 @@ def _evaluate_agent_on_split( episodes_per_seed: int, model, compute_normalized_burn_ratio: bool, + split_name: str, ) -> dict: episode_metrics = [] for seed in seeds: - env = WildfireEnv(scenario_parameter_records=records, randomize_scenario=True) - baseline_env = WildfireEnv(scenario_parameter_records=records, randomize_scenario=True) + env = create_benchmark_env( + scenario_parameter_records=records, + expected_split=split_name, + ) + baseline_env = create_benchmark_env( + scenario_parameter_records=records, + expected_split=split_name, + ) iterator = tqdm(range(episodes_per_seed), desc=f"{agent_name} seed={seed}", unit="ep") for ep in iterator: eval_seed = seed * 10_000 + ep @@ -183,10 +191,14 @@ def _evaluate_agent_on_split( return summary -def _load_split_records(path: Path | None) -> list[dict]: +def _load_split_records(path: Path | None, *, split_name: str) -> list[dict]: if path is None or not path.exists(): return [] - return load_scenario_parameter_records(path) + return load_scenario_parameter_records( + path, + benchmark_mode=True, + expected_split=split_name, + ) def main() -> None: @@ -208,9 +220,9 @@ def main() -> None: agents = [a.strip().lower() for a in args.agents.split(",") if a.strip()] split_records = { - "train": _load_split_records(args.train_dataset), - "val": _load_split_records(args.val_dataset), - "holdout": _load_split_records(args.holdout_dataset), + "train": _load_split_records(args.train_dataset, split_name="train"), + "val": _load_split_records(args.val_dataset, split_name="val"), + "holdout": _load_split_records(args.holdout_dataset, split_name="holdout"), } results: dict[str, dict] = {} @@ -231,6 +243,7 @@ def main() -> None: episodes_per_seed=args.episodes, model=model, compute_normalized_burn_ratio=not args.no_normalized_burn, + split_name=split_name, ) results[agent_name][split_name] = summary diff --git a/src/models/fire_env.py b/src/models/fire_env.py index 3c7d0ab..202aa28 100644 --- a/src/models/fire_env.py +++ b/src/models/fire_env.py @@ -12,9 +12,15 @@ The agent must protect critical assets under a finite suppression budget. Scenarios vary by ignition pattern, severity, asset layout, and wind bias. +Canonical benchmark mode consumes precomputed offline scenario parameter records. +It does not fetch FIRMS/CWFIS/Open-Meteo/CFFDRS data at runtime. + Usage: from src.models.fire_env import WildfireEnv, ScenarioConfig - env = WildfireEnv(scenario=ScenarioConfig(ignition="edge", severity="high")) + env = WildfireEnv( + scenario=ScenarioConfig(ignition="edge", severity="high"), + benchmark_mode=False, + ) obs, info = env.reset() obs, reward, done, truncated, info = env.step(action) """ @@ -22,14 +28,18 @@ from __future__ import annotations import json +import logging import math from dataclasses import dataclass +from hashlib import blake2b from pathlib import Path import gymnasium as gym import numpy as np from gymnasium import spaces +logger = logging.getLogger(__name__) + # Cell types UNBURNED = 0 BURNING = 1 @@ -48,14 +58,51 @@ GRID_SIZE = 25 -# Severity -> base spread probability (from impl-plan section 9.2) -SEVERITY_SPREAD_PROB = { +# Legacy fallback only (dev/ablation): severity -> base spread probability. +LEGACY_SEVERITY_SPREAD_PROB = { "low": 0.04 + 0.18 * 0.17, # spread_intensity ~ 0.17 "medium": 0.04 + 0.18 * 0.50, # spread_intensity ~ 0.50 "high": 0.04 + 0.18 * 0.83, # spread_intensity ~ 0.83 } SEVERITY_INDEX = {"low": 0, "medium": 1, "high": 2} +VALID_SPLITS = ("train", "val", "holdout") +WIND_DIRECTIONS_8 = ("N", "NE", "E", "SE", "S", "SW", "W", "NW") +_DIAG = 0.70710678118 +WIND_VECTOR_BY_DIR = { + "N": (0.0, -1.0), + "NE": (_DIAG, -_DIAG), + "E": (1.0, 0.0), + "SE": (_DIAG, _DIAG), + "S": (0.0, 1.0), + "SW": (-_DIAG, _DIAG), + "W": (-1.0, 0.0), + "NW": (-_DIAG, -_DIAG), +} + +PARAMETER_METADATA_FIELDS = ( + "record_id", + "split", + "fire_id", + "year", + "source", + "province", + "record_quality_flag", + "ignition_seed", + "layout_seed", +) + +PARAMETER_AUDIT_FIELDS = ( + "spread_rate_1h_m", + "spread_score", + "weather_score", + "cffdrs_dryness_score", + "size_factor", + "fire_type_factor", + "fuel_factor", + "rain_factor", + "record_quality_flag", +) # ── Scenario families ──────────────────────────────────────────────────────── @@ -83,7 +130,7 @@ class ScenarioConfig: ignition: str = "center" severity: str = "medium" asset_layout: str = "A" - wind_dir_deg: float = 0.0 # 0 = wind blowing north->south + wind_direction: str = "N" wind_strength: float = 0.3 # [0, 1] base_spread_prob: float | None = None record_id: str | None = None @@ -92,21 +139,21 @@ def __post_init__(self): assert self.ignition in IGNITION_TYPES, f"Unknown ignition: {self.ignition}" assert self.severity in SEVERITY_LEVELS, f"Unknown severity: {self.severity}" assert self.asset_layout in ASSET_LAYOUTS, f"Unknown asset layout: {self.asset_layout}" + assert self.wind_direction in WIND_DIRECTIONS_8, ( + f"Unknown wind direction: {self.wind_direction}" + ) @property def spread_prob(self) -> float: if self.base_spread_prob is not None: return float(self.base_spread_prob) - return SEVERITY_SPREAD_PROB[self.severity] + return LEGACY_SEVERITY_SPREAD_PROB[self.severity] @property def wind_bias(self) -> tuple[float, float]: """Wind bias vector (wx, wy) for directional spread.""" - rad = math.radians(self.wind_dir_deg) - return ( - self.wind_strength * math.cos(rad), - self.wind_strength * math.sin(rad), - ) + wx, wy = WIND_VECTOR_BY_DIR[self.wind_direction] + return (self.wind_strength * wx, self.wind_strength * wy) @property def severity_onehot(self) -> list[float]: @@ -119,7 +166,7 @@ def random_scenario( rng: np.random.Generator, families: list[tuple[str, str, str]] | None = None, ) -> ScenarioConfig: - """Sample a random scenario from the given families (default: train).""" + """Sample a random scenario for dev/ablation runs (default: train families).""" if families is None: families = TRAIN_FAMILIES ign, sev, layout = families[rng.integers(len(families))] @@ -127,20 +174,283 @@ def random_scenario( ignition=ign, severity=sev, asset_layout=layout, - wind_dir_deg=float(rng.uniform(0, 360)), + wind_direction=WIND_DIRECTIONS_8[int(rng.integers(len(WIND_DIRECTIONS_8)))], wind_strength=float(rng.uniform(0.1, 0.6)), ) -def load_scenario_parameter_records(path: str | Path) -> list[dict]: - """Load cached scenario parameter records from a JSON file.""" +def _split_hint_from_path(path: Path) -> str | None: + stem = path.stem.lower() + for split in VALID_SPLITS: + token = f"_{split}" + if stem.endswith(token) or token in stem: + return split + return None + + +def _stable_seed(*parts: object) -> int: + payload = "|".join(str(part) for part in parts).encode("utf-8") + digest = blake2b(payload, digest_size=8).digest() + return int.from_bytes(digest, byteorder="little", signed=False) + + +def load_scenario_parameter_records( + path: str | Path, + *, + benchmark_mode: bool = True, + expected_split: str | None = None, +) -> list[dict]: + """Load and validate precomputed scenario parameter records from a JSON file.""" records_path = Path(path) payload = json.loads(records_path.read_text()) records = payload.get("records", []) if isinstance(payload, dict) else payload if not isinstance(records, list): msg = f"Invalid scenario parameter dataset: {records_path}" raise ValueError(msg) - return [record for record in records if isinstance(record, dict)] + + valid_splits = set(VALID_SPLITS) + valid_severities = set(SEVERITY_LEVELS) + validated: list[dict] = [] + errors: list[str] = [] + + if expected_split is not None: + expected_split = expected_split.strip().lower() + if expected_split not in valid_splits: + msg = ( + f"Invalid expected_split '{expected_split}' for dataset {records_path}; " + f"expected one of {sorted(valid_splits)}" + ) + raise ValueError(msg) + + for idx, record in enumerate(records): + if not isinstance(record, dict): + errors.append(f"record[{idx}]: expected object, got {type(record).__name__}") + continue + + missing = [ + field + for field in ( + "record_id", + "split", + "base_spread_prob", + "severity_bucket", + "wind_direction", + "wind_strength", + *(("ignition_seed", "layout_seed") if benchmark_mode else ()), + ) + if record.get(field) is None + ] + if missing: + errors.append(f"record[{idx}]: missing required fields {missing}") + continue + + record_id = str(record.get("record_id", "")).strip() + if not record_id: + errors.append(f"record[{idx}]: record_id must be non-empty") + continue + + split = str(record.get("split", "")).strip().lower() + if split not in valid_splits: + errors.append( + f"record[{idx}]: invalid split '{record.get('split')}' (expected one of {sorted(valid_splits)})" + ) + continue + + severity = str(record.get("severity_bucket", "")).strip().lower() + if severity not in valid_severities: + errors.append( + f"record[{idx}]: invalid severity_bucket " + f"'{record.get('severity_bucket')}' (expected one of {sorted(valid_severities)})" + ) + continue + + try: + base_spread_prob = float(record["base_spread_prob"]) + wind_strength = float(record["wind_strength"]) + except (TypeError, ValueError) as exc: + errors.append(f"record[{idx}]: numeric parse failed ({exc})") + continue + + if not (math.isfinite(base_spread_prob) and math.isfinite(wind_strength)): + errors.append(f"record[{idx}]: numeric fields must be finite floats") + continue + + if not 0.0 <= base_spread_prob <= 1.0: + errors.append( + f"record[{idx}]: base_spread_prob {base_spread_prob} out of range [0.0, 1.0]" + ) + continue + + wind_direction = str(record.get("wind_direction", "")).strip().upper() + if wind_direction not in WIND_DIRECTIONS_8: + errors.append( + f"record[{idx}]: invalid wind_direction '{record.get('wind_direction')}' " + f"(expected one of {list(WIND_DIRECTIONS_8)})" + ) + continue + + if not 0.0 <= wind_strength <= 1.0: + errors.append(f"record[{idx}]: wind_strength {wind_strength} out of range [0.0, 1.0]") + continue + + seed_invalid = False + for seed_key in ("ignition_seed", "layout_seed"): + if record.get(seed_key) is None: + continue + try: + seed_value = int(record[seed_key]) + except (TypeError, ValueError) as exc: + errors.append(f"record[{idx}]: {seed_key} parse failed ({exc})") + seed_invalid = True + continue + if seed_value < 0: + errors.append(f"record[{idx}]: {seed_key} must be >= 0") + seed_invalid = True + continue + + if seed_invalid: + continue + + normalized = dict(record) + normalized["record_id"] = record_id + normalized["split"] = split + normalized["severity_bucket"] = severity + normalized["base_spread_prob"] = base_spread_prob + normalized["wind_direction"] = wind_direction + normalized["wind_strength"] = wind_strength + if normalized.get("ignition_seed") is not None: + normalized["ignition_seed"] = int(normalized["ignition_seed"]) + if normalized.get("layout_seed") is not None: + normalized["layout_seed"] = int(normalized["layout_seed"]) + validated.append(normalized) + + path_split_hint = _split_hint_from_path(records_path) + if ( + path_split_hint is not None + and expected_split is not None + and path_split_hint != expected_split + ): + msg = ( + f"Split mismatch for {records_path}: expected_split='{expected_split}' " + f"but filename suggests '{path_split_hint}'" + ) + if benchmark_mode: + raise ValueError(msg) + logger.warning(msg) + + effective_expected_split = expected_split or path_split_hint + if benchmark_mode and effective_expected_split is None and validated: + record_splits = sorted({str(record["split"]) for record in validated}) + if len(record_splits) == 1: + effective_expected_split = record_splits[0] + else: + msg = ( + f"Could not infer a single split for {records_path}; found mixed splits {record_splits}. " + "Provide expected_split explicitly or use split-specific datasets." + ) + raise ValueError(msg) + + if effective_expected_split is not None and validated: + split_mismatch = [ + f"record[{idx}]: split '{record['split']}' != expected '{effective_expected_split}'" + for idx, record in enumerate(validated) + if str(record["split"]) != effective_expected_split + ] + if split_mismatch and benchmark_mode: + sample = "\n - " + "\n - ".join(split_mismatch[:10]) + if len(split_mismatch) > 10: + sample += f"\n - ... and {len(split_mismatch) - 10} more" + msg = ( + f"Split consistency check failed for {records_path}: " + f"{len(split_mismatch)} record(s) do not match expected split " + f"'{effective_expected_split}'.{sample}" + ) + raise ValueError(msg) + if split_mismatch and not benchmark_mode: + logger.warning( + "Scenario dataset %s has %s split-mismatched record(s); skipping them in dev mode.", + records_path, + len(split_mismatch), + ) + for detail in split_mismatch[:10]: + logger.warning(" %s", detail) + if len(split_mismatch) > 10: + logger.warning(" ... and %s more", len(split_mismatch) - 10) + validated = [ + record for record in validated if str(record["split"]) == effective_expected_split + ] + + if errors and benchmark_mode: + sample = "\n - " + "\n - ".join(errors[:10]) + if len(errors) > 10: + sample += f"\n - ... and {len(errors) - 10} more" + msg = ( + f"Invalid scenario parameter dataset at {records_path}: " + f"{len(errors)} invalid record(s) found.{sample}" + ) + raise ValueError(msg) + + if errors and not benchmark_mode: + logger.warning( + "Scenario dataset %s has %s invalid record(s); skipping them in dev mode.", + records_path, + len(errors), + ) + for detail in errors[:10]: + logger.warning(" %s", detail) + if len(errors) > 10: + logger.warning(" ... and %s more", len(errors) - 10) + + if benchmark_mode and not validated: + msg = ( + f"Scenario dataset {records_path} has no usable records after validation. " + "Benchmark mode requires a non-empty validated dataset." + ) + raise ValueError(msg) + + return validated + + +def benchmark_env_kwargs( + *, + expected_split: str, + scenario_parameter_records: list[dict] | None = None, + dataset_path: str | Path | None = None, +) -> dict: + """Build canonical benchmark env kwargs from validated frozen records.""" + if scenario_parameter_records is None: + if dataset_path is None: + msg = "Provide scenario_parameter_records or dataset_path for benchmark env creation" + raise ValueError(msg) + scenario_parameter_records = load_scenario_parameter_records( + dataset_path, + benchmark_mode=True, + expected_split=expected_split, + ) + + return { + "scenario_parameter_records": scenario_parameter_records, + "benchmark_mode": True, + "expected_split": expected_split, + "randomize_scenario": True, + } + + +def create_benchmark_env( + *, + expected_split: str, + scenario_parameter_records: list[dict] | None = None, + dataset_path: str | Path | None = None, + **env_overrides, +) -> WildfireEnv: + """Create a canonical benchmark env instance from frozen records.""" + kwargs = benchmark_env_kwargs( + expected_split=expected_split, + scenario_parameter_records=scenario_parameter_records, + dataset_path=dataset_path, + ) + kwargs.update(env_overrides) + return WildfireEnv(**kwargs) def scenario_from_parameter_record( @@ -150,17 +460,15 @@ def scenario_from_parameter_record( asset_layout: str, ) -> ScenarioConfig: """Build a ScenarioConfig from a cached parameter record.""" - severity = str(record.get("severity_bucket", "medium")).lower() + severity = str(record["severity_bucket"]).lower() return ScenarioConfig( ignition=ignition, - severity=severity if severity in SEVERITY_LEVELS else "medium", + severity=severity, asset_layout=asset_layout, - wind_dir_deg=float(record.get("wind_dir_deg", 0.0) or 0.0), - wind_strength=float(record.get("wind_strength", 0.3) or 0.3), - base_spread_prob=float(record.get("base_spread_prob")) - if record.get("base_spread_prob") is not None - else None, - record_id=str(record.get("record_id")) if record.get("record_id") is not None else None, + wind_direction=str(record["wind_direction"]), + wind_strength=float(record["wind_strength"]), + base_spread_prob=float(record["base_spread_prob"]), + record_id=str(record["record_id"]), ) @@ -194,6 +502,8 @@ def __init__( randomize_scenario: bool = True, scenario_families: list[tuple[str, str, str]] | None = None, scenario_parameter_records: list[dict] | None = None, + expected_split: str | None = None, + benchmark_mode: bool = True, # Legacy compat -- ignored if scenario is provided base_spread_rate_m_per_min: float | None = None, ): @@ -208,7 +518,86 @@ def __init__( self.randomize_scenario = randomize_scenario self.scenario_families = scenario_families self.scenario_parameter_records = scenario_parameter_records or [] + self.expected_split = expected_split.lower() if expected_split is not None else None + self.benchmark_mode = benchmark_mode self._active_parameter_record: dict | None = None + self._active_record_id: str | None = None + self._record_order: list[int] = [] + self._record_cursor: int = 0 + + if self.benchmark_mode: + if self.expected_split is not None and self.expected_split not in VALID_SPLITS: + msg = ( + f"Invalid expected_split '{self.expected_split}' for WildfireEnv; " + f"expected one of {list(VALID_SPLITS)}" + ) + raise ValueError(msg) + if not self.scenario_parameter_records: + msg = ( + "benchmark_mode=True requires non-empty scenario_parameter_records. " + "Load frozen scenario_parameter_records_*.json and pass them to WildfireEnv." + ) + raise ValueError(msg) + if scenario is not None: + msg = ( + "benchmark_mode=True does not accept fixed ScenarioConfig input. " + "Use scenario_parameter_records for record-driven resets." + ) + raise ValueError(msg) + if not self.randomize_scenario: + msg = ( + "benchmark_mode=True requires randomize_scenario=True for record-driven resets." + ) + raise ValueError(msg) + if base_spread_rate_m_per_min is not None: + msg = ( + "base_spread_rate_m_per_min is a legacy dev-mode path and cannot be used " + "with benchmark_mode=True." + ) + raise ValueError(msg) + record_splits = { + str(record.get("split", "")).strip().lower() + for record in self.scenario_parameter_records + if isinstance(record, dict) + } + if not record_splits: + msg = "benchmark_mode=True requires records with valid split fields" + raise ValueError(msg) + invalid_splits = [s for s in record_splits if s not in VALID_SPLITS] + if invalid_splits: + msg = ( + f"benchmark_mode=True found invalid split values {sorted(invalid_splits)}; " + f"expected one of {list(VALID_SPLITS)}" + ) + raise ValueError(msg) + if self.expected_split is not None and any( + s != self.expected_split for s in record_splits + ): + msg = ( + f"benchmark_mode=True expected split '{self.expected_split}' but got record splits " + f"{sorted(record_splits)}" + ) + raise ValueError(msg) + if self.expected_split is None and len(record_splits) != 1: + msg = ( + "benchmark_mode=True requires a single split dataset when expected_split is not " + f"provided; got splits {sorted(record_splits)}" + ) + raise ValueError(msg) + if self.expected_split is None: + self.expected_split = next(iter(record_splits)) + missing_init = [ + idx + for idx, record in enumerate(self.scenario_parameter_records) + if record.get("ignition_seed") is None or record.get("layout_seed") is None + ] + if missing_init: + msg = ( + "benchmark_mode=True requires initialization seeds on all records " + f"(missing ignition_seed/layout_seed in {len(missing_init)} record(s)). " + "Use scenario_parameter_records_seeded_*.json artifacts." + ) + raise ValueError(msg) # Scenario (may be overridden each reset if randomize_scenario=True) if scenario is not None: @@ -250,6 +639,10 @@ def __init__( self.crew_cd: int = 0 self.assets_lost: int = 0 self.initial_asset_count: int = 0 + self._ignition_seed_used: int | None = None + self._layout_seed_used: int | None = None + self._ignition_rng: np.random.Generator | None = None + self._layout_rng: np.random.Generator | None = None @property def scenario(self) -> ScenarioConfig: @@ -262,29 +655,49 @@ def reset(self, seed: int | None = None, options: dict | None = None): self.grid = np.zeros((self.grid_size, self.grid_size), dtype=np.int32) self.step_count = 0 self.assets_lost = 0 + self._active_parameter_record = None + self._active_record_id = self._scenario.record_id + self._ignition_seed_used = None + self._layout_seed_used = None + self._ignition_rng = None + self._layout_rng = None # Optionally sample a new scenario if self.randomize_scenario: + # Ignition and asset layout remain simulator-side controls. + # Cached records provide spread/weather conditions; optional seeds + # can pin reproducible ignition/layout realizations. families = self.scenario_families or TRAIN_FAMILIES - ign, sev, layout = families[int(self.np_random.integers(len(families)))] + ign, _sev, layout = families[int(self.np_random.integers(len(families)))] if self.scenario_parameter_records: - matching_records = [ - record - for record in self.scenario_parameter_records - if str(record.get("severity_bucket", "")).lower() == sev - ] - source_records = matching_records or self.scenario_parameter_records - record = source_records[int(self.np_random.integers(len(source_records)))] + # Canonical path: consume precomputed offline parameters only. + record = self._sample_parameter_record(reshuffle=seed is not None) self._active_parameter_record = record + self._active_record_id = ( + str(record.get("record_id")) if record.get("record_id") else None + ) self._scenario = scenario_from_parameter_record( record, ignition=ign, asset_layout=layout, ) + self._configure_initialization_rngs(record=record, reset_seed=seed) else: + if self.benchmark_mode: + msg = ( + "benchmark_mode=True cannot reset without scenario_parameter_records. " + "Disable benchmark_mode only for explicit dev/ablation runs." + ) + raise RuntimeError(msg) self._active_parameter_record = None + self._active_record_id = None self._scenario = random_scenario(self.np_random, families) + if self._ignition_rng is None: + self._ignition_rng = self.np_random + if self._layout_rng is None: + self._layout_rng = self.np_random + # Reset budgets and cooldowns self.heli_left = self.heli_budget_init self.crew_left = self.crew_budget_init @@ -303,6 +716,14 @@ def reset(self, seed: int | None = None, options: dict | None = None): return self._get_obs(), { "scenario": self._scenario, + "record_id": self._active_record_id, + "split": self._active_parameter_record.get("split") + if self._active_parameter_record + else None, + "ignition_seed": self._ignition_seed_used, + "layout_seed": self._layout_seed_used, + "parameter_record_meta": self._parameter_metadata(), + "parameter_audit": self._parameter_audit(), "parameter_record": self._active_parameter_record, } @@ -356,6 +777,14 @@ def step(self, action: int): "heli_left": self.heli_left, "crew_left": self.crew_left, "scenario": self._scenario, + "record_id": self._active_record_id, + "split": self._active_parameter_record.get("split") + if self._active_parameter_record + else None, + "ignition_seed": self._ignition_seed_used, + "layout_seed": self._layout_seed_used, + "parameter_record_meta": self._parameter_metadata(), + "parameter_audit": self._parameter_audit(), "parameter_record": self._active_parameter_record, } @@ -366,6 +795,70 @@ def step(self, action: int): def _in_bounds(self, r: int, c: int) -> bool: return 0 <= r < self.grid_size and 0 <= c < self.grid_size + def _configure_initialization_rngs( + self, + *, + record: dict, + reset_seed: int | None, + ) -> None: + record_id = str(record.get("record_id") or "unknown") + + ignition_seed = record.get("ignition_seed") + layout_seed = record.get("layout_seed") + + if self.benchmark_mode and (ignition_seed is None or layout_seed is None): + msg = ( + "benchmark_mode=True requires per-record ignition_seed and layout_seed " + "for reproducible initialization." + ) + raise RuntimeError(msg) + + if ignition_seed is None and reset_seed is not None: + ignition_seed = _stable_seed(record_id, reset_seed, "ignition") + if layout_seed is None and reset_seed is not None: + layout_seed = _stable_seed(record_id, reset_seed, "layout") + + self._ignition_seed_used = int(ignition_seed) if ignition_seed is not None else None + self._layout_seed_used = int(layout_seed) if layout_seed is not None else None + + self._ignition_rng = ( + np.random.default_rng(self._ignition_seed_used) + if self._ignition_seed_used is not None + else self.np_random + ) + self._layout_rng = ( + np.random.default_rng(self._layout_seed_used) + if self._layout_seed_used is not None + else self.np_random + ) + + def _parameter_metadata(self) -> dict: + record = self._active_parameter_record + if not record: + return {} + return {key: record.get(key) for key in PARAMETER_METADATA_FIELDS} + + def _parameter_audit(self) -> dict: + record = self._active_parameter_record + if not record: + return {} + return {key: record.get(key) for key in PARAMETER_AUDIT_FIELDS if key in record} + + def _sample_parameter_record(self, *, reshuffle: bool = False) -> dict: + if not self.scenario_parameter_records: + msg = "No scenario_parameter_records available for sampling" + raise RuntimeError(msg) + + if reshuffle or not self._record_order or self._record_cursor >= len(self._record_order): + self._record_order = [ + int(i) for i in self.np_random.permutation(len(self.scenario_parameter_records)) + ] + self._record_cursor = 0 + + index = self._record_order[self._record_cursor] + self._record_cursor += 1 + return self.scenario_parameter_records[index] + def _get_obs(self) -> np.ndarray: # Normalize grid: 6 cell types -> [0, 1] flat_grid = self.grid.flatten().astype(np.float32) / 5.0 @@ -390,6 +883,7 @@ def _get_obs(self) -> np.ndarray: def _ignite(self): """Set initial fire cells based on scenario ignition pattern.""" + rng = self._ignition_rng or self.np_random gs = self.grid_size cx, cy = gs // 2, gs // 2 pattern = self._scenario.ignition @@ -398,7 +892,7 @@ def _ignite(self): seeds = [(cx, cy), (cx - 1, cy), (cx + 1, cy), (cx, cy - 1), (cx, cy + 1)] elif pattern == "edge": # Fire starts along a random edge - edge = int(self.np_random.integers(4)) + edge = int(rng.integers(4)) if edge == 0: # top seeds = [(0, cy - 1), (0, cy), (0, cy + 1)] elif edge == 1: # bottom @@ -408,7 +902,7 @@ def _ignite(self): else: # right seeds = [(cx - 1, gs - 1), (cx, gs - 1), (cx + 1, gs - 1)] elif pattern == "corner": - corner = int(self.np_random.integers(4)) + corner = int(rng.integers(4)) offsets = [(0, 0), (0, gs - 1), (gs - 1, 0), (gs - 1, gs - 1)] cr, cc = offsets[corner] seeds = [(cr, cc)] @@ -419,11 +913,11 @@ def _ignite(self): seeds.append((nr, nc)) elif pattern == "multi_cluster": # 2-3 small fire clusters scattered across the grid - n_clusters = int(self.np_random.integers(2, 4)) + n_clusters = int(rng.integers(2, 4)) seeds = [] for _ in range(n_clusters): - r = int(self.np_random.integers(2, gs - 2)) - c = int(self.np_random.integers(2, gs - 2)) + r = int(rng.integers(2, gs - 2)) + c = int(rng.integers(2, gs - 2)) seeds.append((r, c)) seeds.append((r + 1, c)) seeds.append((r, c + 1)) @@ -437,6 +931,7 @@ def _ignite(self): def _place_assets(self): """Place critical asset cells based on scenario asset layout.""" + rng = self._layout_rng or self.np_random gs = self.grid_size cx, cy = gs // 2, gs // 2 min_dist = gs // 4 @@ -447,8 +942,8 @@ def _place_assets(self): placed = 0 cluster_r, cluster_c = 0, 0 for _ in range(100): - cluster_r = int(self.np_random.integers(0, gs)) - cluster_c = int(self.np_random.integers(0, gs)) + cluster_r = int(rng.integers(0, gs)) + cluster_c = int(rng.integers(0, gs)) if abs(cluster_r - cx) + abs(cluster_c - cy) >= min_dist: break @@ -459,7 +954,7 @@ def _place_assets(self): if dr == 0 and dc == 0: continue candidates.append((cluster_r + dr, cluster_c + dc)) - self.np_random.shuffle(candidates) + rng.shuffle(candidates) for r, c in candidates: if placed >= self.n_assets: @@ -480,15 +975,15 @@ def _place_assets(self): for _ in range(2): cluster_r, cluster_c = 0, 0 for _ in range(100): - cluster_r = int(self.np_random.integers(0, gs)) - cluster_c = int(self.np_random.integers(0, gs)) + cluster_r = int(rng.integers(0, gs)) + cluster_c = int(rng.integers(0, gs)) if abs(cluster_r - cx) + abs(cluster_c - cy) >= min_dist: break candidates = [ (cluster_r + dr, cluster_c + dc) for dr in range(-1, 2) for dc in range(-1, 2) ] - self.np_random.shuffle(candidates) + rng.shuffle(candidates) cluster_placed = 0 for r, c in candidates: diff --git a/src/models/rl_agent.py b/src/models/rl_agent.py index af51486..b1c8561 100644 --- a/src/models/rl_agent.py +++ b/src/models/rl_agent.py @@ -56,9 +56,9 @@ def _greedy_fallback( directions = [ ("N", -1, 0, "ground_crew", "Establish northern anchor line"), ("NE", -0.7, 0.7, "helicopter", "Pre-position tanker for NE flank"), - ("E", 0, 1, "ground_crew", "Cut containment line on eastern flank"), - ("S", 1, 0, "ground_crew", "Southern backfire opportunity"), - ("W", 0, -1, "helicopter", "Air attack on advancing western head"), + ("E", 0, 1, "ground_crew", "Cut containment line on eastern flank"), + ("S", 1, 0, "ground_crew", "Southern backfire opportunity"), + ("W", 0, -1, "helicopter", "Air attack on advancing western head"), ] metres_per_deg_lat = 111_320 @@ -69,15 +69,17 @@ def _greedy_fallback( for name, dlat_factor, dlon_factor, asset, rationale in directions: lat = fire_lat + (dlat_factor * offset_m / metres_per_deg_lat) lon = fire_lon + (dlon_factor * offset_m / metres_per_deg_lon) - waypoints.append({ - "direction": name, - "latitude": round(lat, 5), - "longitude": round(lon, 5), - "asset_type": asset, - "rationale": rationale, - "score": round(0.9 - len(waypoints) * 0.1, 2), - "source": "greedy_heuristic", - }) + waypoints.append( + { + "direction": name, + "latitude": round(lat, 5), + "longitude": round(lon, 5), + "asset_type": asset, + "rationale": rationale, + "score": round(0.9 - len(waypoints) * 0.1, 2), + "source": "greedy_heuristic", + } + ) return waypoints @@ -94,8 +96,8 @@ def get_tactical_recommendations( If the PPO model is trained and saved, uses it for inference. Falls back to the greedy heuristic if model isn't available. """ - fire_lat = float(fire_data.get("latitude", 49.9071)) if fire_data else 49.9071 - fire_lon = float(fire_data.get("longitude", -119.496)) if fire_data else -119.496 + fire_lat = float(fire_data.get("latitude", 49.9071)) if fire_data else 49.9071 + fire_lon = float(fire_data.get("longitude", -119.496)) if fire_data else -119.496 spread_1h = float(spread_output.get("spread_1h_m", 1200)) if spread_output else 1200 spread_3h = float(spread_output.get("spread_3h_m", 3600)) if spread_output else 3600 @@ -111,6 +113,7 @@ def get_tactical_recommendations( spread_rate_m_per_min = spread_1h / 60.0 env = WildfireEnv( base_spread_rate_m_per_min=spread_rate_m_per_min, + benchmark_mode=False, ) model = PPO.load(str(MODEL_PATH), env=env) @@ -123,18 +126,18 @@ def get_tactical_recommendations( obs, reward, done, truncated, _info = env.step(int(action)) if int(action) in deployment_actions: - lat, lon = _grid_to_latlon( - env.agent_pos, fire_lat, fire_lon, spread_1h - ) + lat, lon = _grid_to_latlon(env.agent_pos, fire_lat, fire_lon, spread_1h) asset = "helicopter" if int(action) == 4 else "ground_crew" - waypoints.append({ - "latitude": lat, - "longitude": lon, - "asset_type": asset, - "rationale": f"PPO recommended {asset} deployment (step {env.step_count})", - "score": round(float(reward), 2), - "source": "ppo_agent", - }) + waypoints.append( + { + "latitude": lat, + "longitude": lon, + "asset_type": asset, + "rationale": f"PPO recommended {asset} deployment (step {env.step_count})", + "score": round(float(reward), 2), + "source": "ppo_agent", + } + ) if done or truncated: break diff --git a/src/models/train_rl_agent.py b/src/models/train_rl_agent.py index d71adec..3c93c33 100644 --- a/src/models/train_rl_agent.py +++ b/src/models/train_rl_agent.py @@ -18,14 +18,15 @@ logger = logging.getLogger(__name__) MODEL_SAVE_PATH = Path(__file__).parent / "tactical_ppo_agent" -DEFAULT_SCENARIO_DATASET = Path("data/static/scenario_parameter_records_train.json") +DEFAULT_SCENARIO_DATASETS = (Path("data/static/scenario_parameter_records_seeded_train.json"),) def _resolve_dataset_path(path: str | None) -> str | None: if path: return path - if DEFAULT_SCENARIO_DATASET.exists(): - return str(DEFAULT_SCENARIO_DATASET) + for candidate in DEFAULT_SCENARIO_DATASETS: + if candidate.exists(): + return str(candidate) return None @@ -35,11 +36,17 @@ def _existing_path(path: str | None) -> str | None: return None -def _evaluate_model(model, dataset_path: str, seed: int, episodes: int = 5) -> tuple[float, float]: - from src.models.fire_env import WildfireEnv, load_scenario_parameter_records +def _evaluate_model( + model, + dataset_path: str, + seed: int, + episodes: int = 5, + expected_split: str | None = None, +) -> tuple[float, float]: + from src.models.fire_env import create_benchmark_env - records = load_scenario_parameter_records(dataset_path) - eval_env = WildfireEnv(scenario_parameter_records=records) + split = expected_split or "train" + eval_env = create_benchmark_env(dataset_path=dataset_path, expected_split=split) returns = [] assets_lost_total = [] for ep in range(episodes): @@ -64,13 +71,14 @@ def train( scenario_dataset_path: str | None = None, val_dataset_path: str | None = None, holdout_dataset_path: str | None = None, + allow_legacy_dev_fallback: bool = False, ) -> None: """ Train the PPO tactical agent. Args: total_timesteps: Total env steps to train for. - spread_rate_m_per_min: Fire spread rate (from XGBoost) for training. + spread_rate_m_per_min: Legacy fixed spread rate used only in dev fallback mode. n_envs: Parallel environments. seed: Random seed for reproducibility. """ @@ -78,7 +86,7 @@ def train( from stable_baselines3 import PPO from stable_baselines3.common.env_util import make_vec_env - from src.models.fire_env import WildfireEnv, load_scenario_parameter_records + from src.models.fire_env import WildfireEnv, benchmark_env_kwargs except ImportError as e: print(f"Missing dependency: {e}") print(" Run: uv sync") @@ -88,7 +96,6 @@ def train( print(" FireGrid PPO Tactical Agent — Training") print("=" * 60) print(f" Timesteps: {total_timesteps:,}") - print(f" Spread rate: {spread_rate_m_per_min} m/min") print(f" Environments: {n_envs} parallel") print(" Grid: 25×25 with critical assets") print(" Budgets: heli=8, crew=20") @@ -98,10 +105,28 @@ def train( env_kwargs: dict = {} if scenario_dataset_path: - records = load_scenario_parameter_records(scenario_dataset_path) - env_kwargs["scenario_parameter_records"] = records + env_kwargs = benchmark_env_kwargs( + dataset_path=scenario_dataset_path, + expected_split="train", + ) + records = env_kwargs["scenario_parameter_records"] + print(" Runtime data: frozen offline scenario records (no live ingestion)") print(f" Scenario records: {len(records)} from {scenario_dataset_path}") else: + if not allow_legacy_dev_fallback: + msg = ( + "No training scenario dataset found. Canonical training requires precomputed " + "scenario_parameter_records_seeded_train.json (or explicit equivalent). " + "To run non-canonical dev mode, pass --allow-legacy-dev-fallback with " + "--spread-rate." + ) + raise ValueError(msg) + print( + " No scenario dataset found; running explicit legacy dev mode " + "with --spread-rate fallback." + ) + print(f" Legacy spread rate: {spread_rate_m_per_min} m/min") + env_kwargs["benchmark_mode"] = False env_kwargs["base_spread_rate_m_per_min"] = spread_rate_m_per_min vec_env = make_vec_env( WildfireEnv, @@ -143,7 +168,13 @@ def train( for split_name, dataset_path in eval_targets: if not dataset_path: continue - mean_return, mean_assets_lost = _evaluate_model(model, dataset_path, seed=seed, episodes=5) + mean_return, mean_assets_lost = _evaluate_model( + model, + dataset_path, + seed=seed, + episodes=5, + expected_split=split_name, + ) print(f" [{split_name}] Mean return: {mean_return:.1f}") print(f" [{split_name}] Mean assets lost: {mean_assets_lost:.1f}") print(f"\nTraining complete. Model ready at {MODEL_SAVE_PATH}.zip") @@ -155,7 +186,15 @@ def train( "--timesteps", type=int, default=200_000, help="Total training timesteps (default: 200000)" ) parser.add_argument( - "--spread-rate", type=float, default=15.0, help="Fire spread rate in m/min (default: 15.0)" + "--spread-rate", + type=float, + default=15.0, + help="Legacy dev-mode fixed spread rate in m/min (default: 15.0)", + ) + parser.add_argument( + "--allow-legacy-dev-fallback", + action="store_true", + help="Allow non-canonical fallback when no scenario dataset is available", ) parser.add_argument( "--envs", type=int, default=4, help="Number of parallel environments (default: 4)" @@ -170,13 +209,13 @@ def train( parser.add_argument( "--val-dataset", type=str, - default="data/static/scenario_parameter_records_val.json", + default="data/static/scenario_parameter_records_seeded_val.json", help="Path to cached validation scenario parameter JSON dataset", ) parser.add_argument( "--holdout-dataset", type=str, - default="data/static/scenario_parameter_records_holdout.json", + default="data/static/scenario_parameter_records_seeded_holdout.json", help="Path to cached holdout scenario parameter JSON dataset", ) args = parser.parse_args() @@ -189,4 +228,5 @@ def train( scenario_dataset_path=args.scenario_dataset, val_dataset_path=args.val_dataset, holdout_dataset_path=args.holdout_dataset, + allow_legacy_dev_fallback=args.allow_legacy_dev_fallback, ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..f96b4d8 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) diff --git a/tests/models/test_fire_env_setup_contract.py b/tests/models/test_fire_env_setup_contract.py new file mode 100644 index 0000000..7f74846 --- /dev/null +++ b/tests/models/test_fire_env_setup_contract.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import json + +import pytest + +from src.models.fire_env import WildfireEnv, load_scenario_parameter_records + + +def _record(**overrides): + base = { + "record_id": "AB-2020-001__20200101", + "split": "train", + "base_spread_prob": 0.14, + "severity_bucket": "medium", + "wind_direction": "E", + "wind_strength": 0.35, + "ignition_seed": 101, + "layout_seed": 202, + "fire_id": "AB-2020-001", + "year": 2020, + "source": "AB_HISTORICAL_WILDFIRE", + "province": "AB", + "record_quality_flag": "measured", + "spread_rate_1h_m": 120.0, + "spread_score": 0.51, + "weather_score": 0.42, + "cffdrs_dryness_score": 0.3, + "size_factor": 1.02, + "fire_type_factor": 1.0, + "fuel_factor": 1.08, + "rain_factor": 0.96, + } + base.update(overrides) + return base + + +def _write_records(path, records): + path.write_text(json.dumps({"records": records})) + + +def test_schema_validation_requires_core_fields(tmp_path): + path = tmp_path / "records.json" + _write_records(path, [_record(record_id=None)]) + + with pytest.raises(ValueError, match="missing required fields"): + load_scenario_parameter_records(path, benchmark_mode=True, expected_split="train") + + +def test_schema_validation_checks_domains_and_ranges(tmp_path): + path = tmp_path / "records.json" + _write_records(path, [_record(wind_strength=1.5)]) + + with pytest.raises(ValueError, match="wind_strength"): + load_scenario_parameter_records(path, benchmark_mode=True, expected_split="train") + + +def test_schema_validation_rejects_invalid_wind_direction(tmp_path): + path = tmp_path / "records.json" + _write_records(path, [_record(wind_direction="NORTH")]) + + with pytest.raises(ValueError, match="invalid wind_direction"): + load_scenario_parameter_records(path, benchmark_mode=True, expected_split="train") + + +def test_dev_mode_warns_and_skips_invalid_records(tmp_path, caplog): + path = tmp_path / "records.json" + _write_records(path, [_record(), _record(record_id="bad", split="unexpected")]) + + records = load_scenario_parameter_records(path, benchmark_mode=False, expected_split="train") + + assert len(records) == 1 + assert records[0]["record_id"] == "AB-2020-001__20200101" + assert "invalid record" in caplog.text.lower() or "split-mismatched" in caplog.text.lower() + + +def test_benchmark_mode_requires_records_on_env_creation(): + with pytest.raises(ValueError, match="requires non-empty scenario_parameter_records"): + WildfireEnv(scenario_parameter_records=[], benchmark_mode=True) + + +def test_benchmark_mode_reset_keeps_record_driven_path(): + env = WildfireEnv( + scenario_parameter_records=[_record(record_id="AB-2020-keep")], + benchmark_mode=True, + expected_split="train", + ) + + _obs, info = env.reset(seed=11) + + assert info["parameter_record"] is not None + assert info["record_id"] == "AB-2020-keep" + + +def test_active_scenario_uses_cached_parameter_values(): + record = _record( + severity_bucket="high", + wind_direction="SW", + wind_strength=0.57, + base_spread_prob=0.2, + ) + env = WildfireEnv( + scenario_parameter_records=[record], + scenario_families=[("center", "medium", "A")], + benchmark_mode=True, + expected_split="train", + ) + + env.reset(seed=7) + + assert env.scenario.severity == "high" + assert env.scenario.wind_direction == "SW" + assert env.scenario.wind_strength == pytest.approx(0.57) + assert env.scenario.spread_prob == pytest.approx(0.2) + + +def test_benchmark_mode_requires_initialization_seeds(): + with pytest.raises(ValueError, match="requires initialization seeds"): + WildfireEnv( + scenario_parameter_records=[_record(ignition_seed=None)], + benchmark_mode=True, + expected_split="train", + ) + + +def test_split_isolation_on_loader_expected_split(tmp_path): + path = tmp_path / "records.json" + _write_records(path, [_record(split="val")]) + + with pytest.raises(ValueError, match="Split consistency check failed"): + load_scenario_parameter_records(path, benchmark_mode=True, expected_split="train") + + +def test_split_isolation_from_filename_hint(tmp_path): + path = tmp_path / "scenario_parameter_records_train.json" + _write_records(path, [_record(split="val")]) + + with pytest.raises(ValueError, match="Split consistency check failed"): + load_scenario_parameter_records(path, benchmark_mode=True) + + +def test_split_isolation_on_env_creation_expected_split_mismatch(): + with pytest.raises(ValueError, match="expected split 'train'"): + WildfireEnv( + scenario_parameter_records=[_record(split="val")], + benchmark_mode=True, + expected_split="train", + ) + + +def test_reset_and_step_info_include_record_and_split_metadata(): + record = _record( + record_id="AB-2020-info", + split="train", + fire_id="AB-2020-777", + year=2020, + source="AB_HISTORICAL_WILDFIRE", + province="AB", + record_quality_flag="measured", + ) + env = WildfireEnv( + scenario_parameter_records=[record], + scenario_families=[("center", "medium", "A")], + benchmark_mode=True, + expected_split="train", + ) + + _obs, reset_info = env.reset(seed=21) + _obs, _reward, _done, _trunc, step_info = env.step(0) + + assert reset_info["record_id"] == "AB-2020-info" + assert reset_info["split"] == "train" + assert reset_info["ignition_seed"] is not None + assert reset_info["layout_seed"] is not None + assert reset_info["parameter_record_meta"]["fire_id"] == "AB-2020-777" + assert "spread_score" in reset_info["parameter_audit"] + + assert step_info["record_id"] == "AB-2020-info" + assert step_info["split"] == "train" + assert step_info["ignition_seed"] == reset_info["ignition_seed"] + assert step_info["layout_seed"] == reset_info["layout_seed"] + assert step_info["parameter_record_meta"]["record_quality_flag"] == "measured" + assert "cffdrs_dryness_score" in step_info["parameter_audit"] + + +def test_record_provided_initialization_seeds_make_spatial_setup_replayable(): + record = _record(record_id="AB-2020-seeded", ignition_seed=12345, layout_seed=54321) + env = WildfireEnv( + scenario_parameter_records=[record], + scenario_families=[("center", "medium", "A")], + benchmark_mode=True, + expected_split="train", + ) + + env.reset(seed=1) + grid_a = env.grid.copy() + first_info_seed_pair = (env._ignition_seed_used, env._layout_seed_used) + + env.reset(seed=999) + grid_b = env.grid.copy() + second_info_seed_pair = (env._ignition_seed_used, env._layout_seed_used) + + assert first_info_seed_pair == (12345, 54321) + assert second_info_seed_pair == (12345, 54321) + assert (grid_a == grid_b).all()