diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..d7e0f30
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,5 @@
+NASA_FIRMS_API_KEY=
+
+# Optional builder controls
+# STATIC_DATA_TARGET_COUNT=100
+# CFFDRS_YEAR=
diff --git a/.gitignore b/.gitignore
index 366262a..4e951d0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,9 @@ wheels/
 
 # Virtual environments
 .venv
+.env
+
+data/
 
 #logs
 wandb
diff --git a/README.md b/README.md
index bd12cf3..a0e9d16 100644
--- a/README.md
+++ b/README.md
@@ -4,13 +4,14 @@ Empirical RL benchmark for wildfire tactical suppression. Compares DQN, A2C, PPO
 
 ## Setup
 
-```bash
-uv sync
-```
+Requirements: [uv](https://docs.astral.sh/uv/getting-started/installation/)
 
-### Pre-commit hooks (optional)
+1. clone the repo
+2. in the project root, run: `uv venv && source .venv/bin/activate && uv sync`
 
-Install [lefthook](https://github.com/evilmartians/lefthook) for local lint/format checks on commit:
+### Pre-commit hooks
+
+This project uses pre-commit hooks for linting and formatting checks. Install [lefthook](https://github.com/evilmartians/lefthook) to run them locally on commit:
 
 ```bash
 # pick one
@@ -21,15 +22,117 @@ npm i -g lefthook
 lefthook install
 ```
 
-## Usage
+## Usage: Data Pipeline, Training and Eval
+
+The data pipeline now uses the Alberta historical wildfire dataset in `data/static/` as its primary source.
+
+Data sources:
+
+- Alberta historical wildfire dataset: primary incident, weather, spread-rate, and assessment-time source
+- CFFDRS: optional supplementary fire-danger enrichment
+- CWFIS and FIRMS: retained in the repo for legacy/live experiments, not part of the canonical build path
+
+Refer to `docs/data-pipeline.md` for the exact fields and data we ingest.
+
+We build the static dataset with `src/ingestion/static_dataset.py`. 
The script:
+
+- loads historical incident rows from `data/static/fp-historical-wildfire-data-2006-2025.csv`
+- normalizes them into snapshot records anchored at assessment time
+- applies lightweight cleaning (strips blank strings and drops unusable rows with missing required assessment-time fields)
+- optionally enriches with CFFDRS fields when `--cffdrs-year` is provided and usable
+- writes frozen and normalized `snapshot_records.json` and split snapshot files in `data/static`
+- computes offline environment variables and writes `scenario_parameter_records.json` plus split files in `data/static`. The environment variables written are:
+  - `base_spread_prob`
+  - `severity_bucket`
+  - `wind_dir_deg`
+  - `wind_strength`
+- stores the following extra audit fields alongside them:
+  - `spread_rate_1h_m`
+  - `spread_score`
+  - `weather_score`
+  - `cffdrs_dryness_score`
+  - `rain_factor`
+  - `size_factor`
+  - `fire_type_factor`
+  - `fuel_factor`
+  - `observed_spread_rate_m_min`
+  - `assessment_hectares`
+  - `fire_type`
+  - `fuel_type`
+  - `record_quality_flag`
+
+> NOTE: the stored extra fields exist to verify that the pipeline computes the primary metrics correctly and to explain why a record received a high or low spread setting. Their influence has already been collapsed into `base_spread_prob`, `severity_bucket`, and `wind_strength` for the canonical environment. They are not directly consumed by the current `FireEnv` dynamics, which keeps the initial benchmark simple and reduces overfitting/confounding risk.
+
+For future improvements, consider using `cffdrs_dryness_score` to influence burnout probability, `rain_factor` to damp spread for the whole episode, and `size_factor` if we agree incident size should affect spread dynamics. For now, these remain audit fields rather than direct transition inputs.
+
+Check `docs/data-pipeline.md` for how these variables are computed. 
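For orientation, a single scenario parameter record might look like the sketch below. All values are hypothetical placeholders; only the field names come from the lists above, and the real records contain the full audit-field set:

```python
# Hypothetical scenario parameter record; values are illustrative only.
example_record = {
    # canonical environment-facing variables
    "base_spread_prob": 0.12,
    "severity_bucket": "medium",
    "wind_dir_deg": 225.0,
    "wind_strength": 0.4,
    # audit-only extras (already collapsed into the variables above,
    # not read by the current FireEnv dynamics)
    "spread_rate_1h_m": 90.0,
    "observed_spread_rate_m_min": 1.5,
    "assessment_hectares": 12.0,
    "record_quality_flag": "ok",
}

# The environment only needs the four canonical fields at reset time.
env_inputs = {k: example_record[k]
              for k in ("base_spread_prob", "severity_bucket",
                        "wind_dir_deg", "wind_strength")}
```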
+
+
+### How data is collected
+
+```
+Alberta historical wildfire CSV -> normalized snapshot records -> scenario parameter records.
+```
+
+Each record is anchored at `ASSESSMENT_DATETIME`. The builder uses observed spread rate, assessment weather, incident size, fire type, and fuel type to compute benchmark environment variables. If `--cffdrs-year` is passed and a usable station file exists, the builder also joins supplementary CFFDRS danger indices by both distance and date.
+
+```
+fire record -> snapshot record (`data/static/snapshot_records.json`) -> scenario (environment) parameter record (`data/static/scenario_parameter_records.json`).
+```
+
+- CFFDRS is supplementary. If the requested year is sparse or unavailable, the builder still works without it.
+- The raw Alberta CSV already contains the main weather and spread fields used for the benchmark.
+
+For more details, check `docs/data-pipeline.md`.
+
+### Usage from project root
+
+Run this command to ingest the dataset (with a large cap to avoid split truncation):

```bash
+uv run python -m src.ingestion.static_dataset --target-count 50000 --raw-alberta-csv data/static/fp-historical-wildfire-data-2006-2025.csv
+```
+
+If CFFDRS for the selected year is sparse, the builder still runs and writes records without supplementary CFFDRS enrichment. 
+
+Optionally, test with a smaller target count:
+
+```bash
+uv run python -m src.ingestion.static_dataset --target-count 100 --cffdrs-year 2025 --raw-alberta-csv data/static/fp-historical-wildfire-data-2006-2025.csv
+```
+
+If you have your own normalized historical fire records JSON:
+
+```bash
+uv run python -m src.ingestion.static_dataset --fire-records path/to/fire_records.json --target-count 100
+```
+
+### Training
+
+After building the dataset, you can train by running:
 
-# Quick test (10k steps)
-uv run python -m src.models.train_rl_agent --timesteps 10000
+```bash
+uv run python -m src.models.train_rl_agent --scenario-dataset data/static/scenario_parameter_records_train.json --val-dataset data/static/scenario_parameter_records_val.json --holdout-dataset data/static/scenario_parameter_records_holdout.json
+```
+
+The scenario parameter files are then consumed by `FireEnv` during PPO training.
+
+The builder also writes year-based split files for the benchmark:
+
+- `train`: `2006-2022`
+- `val`: `2023`
+- `holdout`: `2024-2025`
 
-# Train XGBoost spread model
-uv run python -m src.models.spread_model
+
+General split benchmark evaluation (PPO + baselines):
+
+```bash
+uv run python -m src.models.evaluate_agents --agents ppo,greedy,random --train-dataset data/static/scenario_parameter_records_train.json --val-dataset data/static/scenario_parameter_records_val.json --holdout-dataset data/static/scenario_parameter_records_holdout.json --episodes 20 --seeds 42,43,44
+```
+
+The dataset builder prints cleaning/drop summaries to stdout and uses progress bars when `tqdm` is available. 
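The year-based split rule above can be sketched as a small helper. This is an illustrative stand-in; the real split logic lives in `src/ingestion/static_dataset.py`:

```python
def split_for_year(year: int) -> str:
    """Map an incident year to its frozen benchmark split.

    Sketch of the documented rule:
    train 2006-2022, val 2023, holdout 2024-2025.
    """
    if 2006 <= year <= 2022:
        return "train"
    if year == 2023:
        return "val"
    if 2024 <= year <= 2025:
        return "holdout"
    raise ValueError(f"year {year} is outside the benchmark range")
```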
diff --git a/docs/data-pipeline.md b/docs/data-pipeline.md new file mode 100644 index 0000000..17da66a --- /dev/null +++ b/docs/data-pipeline.md @@ -0,0 +1,320 @@ +# Data Pipeline + +This document describes the current benchmark data pipeline after the move away from live CWFIS-centered ingestion and XGBoost. + +The canonical path now uses the Alberta historical wildfire dataset stored under `data/static/` as the primary source for building `FireEnv` scenario records. + +--- + +## 1) Overview + +The benchmark pipeline has two stages: + +1. normalize historical wildfire incidents into frozen snapshot records +2. compute offline environment-variable records for `FireEnv` + +Training and evaluation should then use only the cached parameter dataset plus seeded RNG. + +Primary source hierarchy: + +- primary: Alberta historical wildfire dataset +- supplementary: CFFDRS fire-danger indices when an annual station file is available and usable +- non-canonical: CWFIS live active fires and FIRMS hotspots + +--- + +## 2) Input Data Sources + +### 2.1 Alberta historical wildfire dataset + +Raw file path: + +- `data/static/fp-historical-wildfire-data-2006-2025.csv` + +This dataset is now the default primary input to `src/ingestion/static_dataset.py`. 
+ +Important fields used directly by the pipeline: + +- incident identity: `YEAR`, `FIRE_NUMBER`, `FIRE_NAME` +- location: `LATITUDE`, `LONGITUDE` +- timing: `FIRE_START_DATE`, `ASSESSMENT_DATETIME`, `DISCOVERED_DATE`, `REPORTED_DATE`, `DISPATCH_DATE`, `IA_ARRIVAL_AT_FIRE_DATE`, `FIRE_FIGHTING_START_DATE` +- fire state: `ASSESSMENT_HECTARES`, `CURRENT_SIZE`, `SIZE_CLASS` +- spread/weather: `FIRE_SPREAD_RATE`, `TEMPERATURE`, `RELATIVE_HUMIDITY`, `WIND_DIRECTION`, `WIND_SPEED`, `WEATHER_CONDITIONS_OVER_FIRE` +- fire context: `FIRE_TYPE`, `FUEL_TYPE`, `FIRE_POSITION_ON_SLOPE`, `FIRE_ORIGIN` +- optional response context: `INITIAL_ACTION_BY`, `IA_ACCESS`, `BUCKETING_ON_FIRE`, `DISTANCE_FROM_WATER_SOURCE` + +Why this is now primary: + +- historical instead of live-only +- provides assessment-time weather directly +- provides observed spread rate directly +- provides assessment-time size directly +- removes dependence on ad hoc live weather reconstruction for canonical builds + +### 2.2 `src/ingestion/cffdrs.py` + +This module downloads annual CWFIS weather-station CSV data and parses: + +- `fwi` +- `isi` +- `bui` +- `dc` +- `dmc` +- `ffmc` + +Current role: + +- supplementary enrichment only +- if `--cffdrs-year` is passed and usable observations exist, the builder joins the nearest station by both distance and snapshot date +- the benchmark no longer depends on CFFDRS being available to build records + +### 2.3 `src/ingestion/cwfis.py` + +This module still downloads live active fires from CWFIS. + +Current role: + +- legacy / non-canonical +- useful for live experiments or future non-Alberta extensions +- not part of the canonical Alberta historical benchmark build + +### 2.4 `src/ingestion/firms.py` + +This module still fetches NASA FIRMS hotspots. 
+ +Current role: + +- supplementary / non-canonical +- not used in the canonical Alberta historical benchmark build +- may still be useful for exploratory validation or future data discovery + +### 2.5 `src/ingestion/weather.py` + +This module fetches current-hour weather from Open-Meteo. + +Current role: + +- legacy / non-canonical for Alberta historical builds +- assessment-time weather now comes directly from the Alberta dataset + +--- + +## 3) Canonical Build Flow + +```text +Alberta historical wildfire CSV +-> normalized historical fire records +-> optional CFFDRS date-and-distance enrichment +-> snapshot_records.json +-> offline env-variable builder +-> scenario_parameter_records.json +-> FireEnv reset sampling +-> RL train/eval from cached records only +``` + +This path does not use FIRMS or Open-Meteo in the canonical benchmark build. + +The builder logs cleaning and drop diagnostics directly to stdout (with progress bars if `tqdm` is available) instead of writing a separate report artifact. + +### 3.1 Cleaning and vetting specification + +Cleaning is intentionally lightweight and is implemented in `src/ingestion/clean_historical.py`. + +Row-level cleaning behavior: + +- strip leading/trailing whitespace from all string fields +- convert blank strings to `null` +- drop rows missing any required core field: + - `YEAR` + - `FIRE_NUMBER` + - `LATITUDE` + - `LONGITUDE` + - `ASSESSMENT_DATETIME` + - `FIRE_SPREAD_RATE` + - `TEMPERATURE` + - `RELATIVE_HUMIDITY` + - `WIND_DIRECTION` + - `WIND_SPEED` +- drop rows where both size fields are missing: + - `ASSESSMENT_HECTARES` + - `CURRENT_SIZE` + +Normalization-time vetting in `src/ingestion/static_dataset.py` additionally drops rows that fail parsing or mapping, such as invalid datetimes, non-numeric required values, and unresolved wind direction values. 
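The row-level rules above can be sketched as follows. This is a simplified stand-in for `src/ingestion/clean_historical.py`, not its actual API; the field tuples mirror the required-field lists in this section:

```python
REQUIRED_FIELDS = (
    "YEAR", "FIRE_NUMBER", "LATITUDE", "LONGITUDE", "ASSESSMENT_DATETIME",
    "FIRE_SPREAD_RATE", "TEMPERATURE", "RELATIVE_HUMIDITY",
    "WIND_DIRECTION", "WIND_SPEED",
)
SIZE_FIELDS = ("ASSESSMENT_HECTARES", "CURRENT_SIZE")

def clean_row(row):
    """Return a cleaned row dict, or None if the row should be dropped."""
    cleaned = {}
    for key, value in row.items():
        if isinstance(value, str):
            value = value.strip() or None  # strip whitespace, blank -> None
        cleaned[key] = value
    if any(cleaned.get(field) is None for field in REQUIRED_FIELDS):
        return None  # drop: missing a required core field
    if all(cleaned.get(field) is None for field in SIZE_FIELDS):
        return None  # drop: both size fields missing
    return cleaned
```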
+ +Current drop diagnostics printed to stdout include: + +- total rows, kept rows, dropped rows +- top drop reasons (for example `missing_fire_spread_rate`, `normalization_failed`) +- per-year kept/total counts +- per-split built record counts + +--- + +## 4) Snapshot Schema + +The builder writes `data/static/snapshot_records.json`. + +It also writes per-split files using the frozen year strategy: + +- `train`: `2006-2022` +- `val`: `2023` +- `holdout`: `2024-2025` + +Generated split files: + +- `data/static/snapshot_records_train.json` +- `data/static/snapshot_records_val.json` +- `data/static/snapshot_records_holdout.json` +- `data/static/scenario_parameter_records_train.json` +- `data/static/scenario_parameter_records_val.json` +- `data/static/scenario_parameter_records_holdout.json` + +Each snapshot record represents one Alberta wildfire incident anchored at the initial assessment time. + +Core stored fields: + +- identity: `record_id`, `fire_id`, `year`, `name`, `province`, `source` +- timing: `snapshot_date`, `snapshot_datetime`, `started_at`, `updated_at` +- location: `latitude`, `longitude` +- size: `area_hectares`, `assessment_hectares`, `current_size`, `size_class` +- observed spread/weather: `observed_spread_rate_m_min`, `temperature_c`, `relative_humidity_pct`, `wind_direction_deg`, `wind_speed_km_h`, `precipitation_mm` +- fire context: `fire_type`, `fuel_type`, `weather_conditions_over_fire`, `fire_position_on_slope`, `fire_origin` +- cause/admin metadata: `general_cause`, `activity_class`, `true_cause` +- response timing metadata: `detection_delay_h`, `report_delay_h`, `dispatch_delay_h`, `ia_travel_delay_h` +- optional supplementary enrichment: `fwi`, `isi`, `bui`, `dc`, `dmc`, `ffmc`, station metadata + +Important notes: + +- `snapshot_date` is anchored to `ASSESSMENT_DATETIME` +- `area_hectares` prefers `ASSESSMENT_HECTARES`, with `CURRENT_SIZE` as fallback +- `precipitation_mm` is estimated from `WEATHER_CONDITIONS_OVER_FIRE` +- CFFDRS fields may 
be `null` if supplementary enrichment is unavailable + +--- + +## 5) Environment-Variable Builder + +The builder computes `data/static/scenario_parameter_records.json` from each snapshot record. + +Canonical env-facing fields: + +- `base_spread_prob` +- `severity_bucket` +- `wind_dir_deg` +- `wind_strength` + +Stored audit fields: + +- `spread_rate_1h_m` +- `spread_score` +- `weather_score` +- `cffdrs_dryness_score` +- `size_factor` +- `fire_type_factor` +- `fuel_factor` +- `rain_factor` +- `observed_spread_rate_m_min` +- `assessment_hectares` +- `fire_type` +- `fuel_type` +- `record_quality_flag` + +### Builder logic + +The current builder uses a blended physics-informed rule: + +- dominant term: observed `fire_spread_rate` +- supporting terms: wind, temperature, relative humidity, estimated precipitation, assessment size +- optional supplementary term: CFFDRS dryness score from `ISI/FWI/BUI/FFMC` +- modifiers: `fire_type` and `fuel_type` + +This is not a full Rothermel implementation. It is a benchmark-oriented, physics-informed calibration rule that keeps the simulator simple while grounding episode conditions in historical assessment data. 
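To make the shape of that rule concrete, here is a deliberately simplified sketch of the blend. Every weight, scale, and threshold below is a hypothetical placeholder, not a calibrated value from the actual builder:

```python
def blend_spread_score(observed_rate_m_min, weather_score, size_factor,
                       fire_type_factor, fuel_factor, rain_factor,
                       cffdrs_dryness_score=None):
    """Illustrative blend: dominant observed-spread term plus modifiers."""
    # Dominant term: squash the observed spread rate into [0, 1].
    spread_term = min(observed_rate_m_min / 10.0, 1.0)  # placeholder scale
    score = 0.7 * spread_term + 0.2 * weather_score + 0.1 * size_factor
    if cffdrs_dryness_score is not None:
        # Optional supplementary dryness term when CFFDRS data is usable.
        score = 0.9 * score + 0.1 * cffdrs_dryness_score
    # Categorical and precipitation modifiers, then clamp.
    score *= fire_type_factor * fuel_factor * rain_factor
    return max(0.0, min(score, 1.0))
```

Under this sketch, `severity_bucket` would be derived by bucketing the resulting score and `base_spread_prob` by rescaling it; the builder's actual thresholds are defined in code, not here.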
+ +--- + +## 6) Mapping From Data to Environment Variables + +| Stored env field | Source fields | Builder logic | Used by environment | +|---|---|---|---| +| `base_spread_prob` | `observed_spread_rate_m_min`, weather, size, optional CFFDRS dryness, `fire_type`, `fuel_type` | derived from blended `spread_score` | primary spread probability in `_spread_fire()` | +| `severity_bucket` | same fields as `base_spread_prob` | derived from `spread_score` thresholds | severity one-hot in observation and family matching | +| `wind_dir_deg` | `wind_direction_deg` | pass-through from Alberta assessment weather | converted to `(wx, wy)` wind bias | +| `wind_strength` | `wind_speed_km_h` | normalized and clipped from assessment wind speed | sets wind-bias magnitude | +| `spread_rate_1h_m` | `observed_spread_rate_m_min` | direct conversion to `m/hour` for audit/logging | optional logging only | + +Audit-only intermediates: + +| Stored audit field | Source fields | Purpose | +|---|---|---| +| `spread_score` | spread + weather + size + optional CFFDRS + type/fuel modifiers | blended benchmark calibration score | +| `weather_score` | wind, temperature, RH | weather contribution summary | +| `cffdrs_dryness_score` | `ISI`, `FWI`, `BUI`, `FFMC` | supplementary dryness context | +| `size_factor` | `assessment_hectares` | weak size modifier | +| `fire_type_factor` | `fire_type` | fire-behavior modifier | +| `fuel_factor` | `fuel_type` | fuel-based modifier | +| `rain_factor` | `WEATHER_CONDITIONS_OVER_FIRE` -> `precipitation_mm` | precipitation damping | + +--- + +## 7) Usage + +Build the canonical dataset from the Alberta historical CSV: + +```bash +uv run python -m src.ingestion.static_dataset --target-count 100 +``` + +Here, `--target-count 100` means up to `100` records per split, not `100` total records overall. 
+
+For canonical benchmark builds, use a high cap to avoid truncating available records:
+
+```bash
+uv run python -m src.ingestion.static_dataset --target-count 50000 --raw-alberta-csv data/static/fp-historical-wildfire-data-2006-2025.csv
+```
+
+Build with optional supplementary CFFDRS enrichment:
+
+```bash
+uv run python -m src.ingestion.static_dataset --target-count 100 --cffdrs-year 2025
+```
+
+Canonical variant with CFFDRS enrichment:
+
+```bash
+uv run python -m src.ingestion.static_dataset --target-count 50000 --cffdrs-year 2025 --raw-alberta-csv data/static/fp-historical-wildfire-data-2006-2025.csv
+```
+
+Build from a pre-normalized historical JSON instead of the raw Alberta CSV:
+
+```bash
+uv run python -m src.ingestion.static_dataset --fire-records path/to/fire_records.json --target-count 100
+```
+
+Override the raw Alberta CSV path if needed:
+
+```bash
+uv run python -m src.ingestion.static_dataset --raw-alberta-csv path/to/fp-historical-wildfire-data.csv --target-count 100
+```
+
+Then train from the cached split parameter files directly:
+
+```bash
+uv run python -m src.models.train_rl_agent --scenario-dataset data/static/scenario_parameter_records_train.json --val-dataset data/static/scenario_parameter_records_val.json --holdout-dataset data/static/scenario_parameter_records_holdout.json
+```
+
+---
+
+## 8) Practical Constraints
+
+- Alberta historical data is Alberta-only, so the canonical benchmark is currently province-scoped rather than Canada-wide.
+- CFFDRS annual station files may be sparse or unavailable for some years; the builder treats them as optional. 
+- FIRMS and CWFIS remain available in the repo but are no longer part of the canonical benchmark build path. +- The benchmark still does not use terrain rasters, perimeter replay, or a full operational spread model. + +That is acceptable for the current paper because the goal is a reproducible tactical RL benchmark, not an operational wildfire decision-support system. diff --git a/docs/envspec.md b/docs/envspec.md index a02a1f2..88257a8 100644 --- a/docs/envspec.md +++ b/docs/envspec.md @@ -1,8 +1,8 @@ -# Environment Spec: Wildfire Simulator + XGBoost Calibration +# Frozen Environment Spec: Wildfire Simulator -This document is a concrete implementation guide for how the wildfire RL environment should work and how XGBoost interfaces with it. +This document is the frozen canonical specification for the wildfire RL benchmark and its static scenario-parameter interface. -It is aligned with `impl-plan.md` and intended to remove ambiguity before coding and experiments. +It is aligned with `docs/planning/impl-plan.md` and is intended to remove ambiguity before coding, benchmarking, and reporting. --- @@ -53,6 +53,11 @@ At each step, the policy receives: 5. Severity one-hot: `[low, medium, high]` 6. Wind bias vector: `(wx, wy)` +Canonical observation rule: + +- The benchmark observation is the single encoded grid plus scalar features listed above. +- Multi-channel observation variants are allowed only as ablations or future work and must not replace the canonical benchmark interface in the main comparison. 
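For the severity one-hot and wind bias entries, a minimal encoding sketch follows. It assumes `wind_dir_deg` is converted to radians before the `cos`/`sin` mapping defined in section 7.2 (the spec writes the formulas in degrees, so the conversion is made explicit here):

```python
import math

SEVERITIES = ("low", "medium", "high")

def severity_one_hot(bucket):
    """Encode severity_bucket as the [low, medium, high] one-hot."""
    return [1.0 if bucket == s else 0.0 for s in SEVERITIES]

def wind_bias(wind_dir_deg, wind_strength):
    """Map wind direction/strength to the (wx, wy) bias vector."""
    theta = math.radians(wind_dir_deg)  # degrees -> radians
    return (wind_strength * math.cos(theta), wind_strength * math.sin(theta))
```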
+ This state lets the policy reason about: - where the fire is, @@ -64,6 +69,11 @@ This state lets the policy reason about: ## 4) Action Semantics (Hard Rules) +Action categories: + +- Mobility actions: `MOVE_N`, `MOVE_S`, `MOVE_E`, `MOVE_W` +- Intervention actions: `DEPLOY_HELICOPTER`, `DEPLOY_CREW` + Action IDs: - `0`: `MOVE_N` @@ -73,6 +83,11 @@ Action IDs: - `4`: `DEPLOY_HELICOPTER` - `5`: `DEPLOY_CREW` +Canonical action rule: + +- The canonical benchmark action set contains exactly these 6 actions. +- `WAIT` is not part of the frozen benchmark and may be introduced only in ablations. + Rules: - Movement changes position by one cell if in bounds; otherwise no movement. @@ -113,7 +128,12 @@ Spread probability is episode-parameterized: - baseline from `base_spread_prob` - adjusted by wind bias `(wx, wy)` relative to neighbor direction -- optional local modifiers if additional heterogeneity is enabled + +Canonical heterogeneity rule: + +- Wind bias is the only mandatory heterogeneity mechanism in canonical runs. +- Additional local modifiers such as flammability maps are ablations or future work and must be reported separately. +- Control-tick versus fire-tick cadence changes are also deferred to ablations or future work. Episode termination: @@ -149,55 +169,45 @@ Interpretation: --- -## 7) XGBoost Interface: What It Does and Does Not Do - -XGBoost is used to calibrate episode conditions from cached real-data snapshots. +## 7) Static Scenario Parameter Interface -- It does **not** choose actions. -- It does **not** replace the RL policy. -- It does **not** make real-time tactical decisions. +The benchmark uses cached scenario records with environment variables computed offline before training and evaluation. These variables are not inferred live during benchmark runs. -It outputs simulator parameters at episode reset. 
+## 7.1 Snapshot inputs used during preprocessing -## 7.1 Snapshot input features (for XGBoost) - -Required canonical features: +Required canonical fields available to the preprocessing pipeline: - weather: `wind_speed_km_h`, `wind_direction_deg`, `temperature_c`, `relative_humidity_pct`, `precipitation_mm` - danger: `fwi`, `isi`, `bui` - incident: `area_hectares`, `latitude`, `longitude`, `province` -Optional useful features (if ingestion/training pipeline is extended): +Optional retained metadata: -- `frp_mw` (FIRMS) +- `frp_mw` - `cffdrs_station_distance_km` - `dmc`, `dc`, `ffmc` -- temporal deltas from snapshot history -## 7.2 XGBoost output contract +## 7.2 Stored parameter record contract -For each snapshot record: +For each cached scenario record, store: -1. `spread_intensity` in `[0,1]` -2. `spread_rate_1h_m` (logging + interpretability) -3. `wind_dir_deg` (pass-through from snapshot) -4. `wind_strength` in `[0,1]` (normalized from wind speed) -5. `severity_bucket` from `spread_intensity` +1. `base_spread_prob` +2. `severity_bucket` +3. `wind_dir_deg` +4. `wind_strength` +5. optional logging fields such as `spread_rate_1h_m` Deterministic env mapping: -- `base_spread_prob = 0.04 + 0.18 * spread_intensity` -- severity: - - low: `<0.33` - - medium: `0.33-0.66` - - high: `>0.66` +- severity is encoded one-hot in the observation from `severity_bucket` - wind vector: - `wx = wind_strength * cos(wind_dir_deg)` - `wy = wind_strength * sin(wind_dir_deg)` +- `base_spread_prob` is consumed directly by the environment spread rule Episode rule: -- Sample one parameter record at reset. +- Sample one cached parameter record at reset. - Keep it fixed for the full episode in canonical runs. --- @@ -210,8 +220,8 @@ Required workflow: 1. Collect and normalize ingestion data. 2. Write versioned snapshot cache file(s). -3. Build XGBoost features from snapshots. -4. Produce env-parameter records. +3. Compute environment variables offline from snapshot records. +4. 
Produce cached env-parameter records. 5. Train/evaluate RL only from cached records + seeded RNG. Fail-fast rule: @@ -240,6 +250,11 @@ Fail-fast rule: ## 9.3 Scenario families +Asset layout definitions: + +- Layout `A`: one dense high-value asset cluster placed near moderate exposure to common ignition zones. +- Layout `B`: two smaller separated asset clusters with different distances from common ignition zones. + Train families: - ignition in `{center, edge, multi_cluster}` @@ -287,6 +302,13 @@ Secondary metrics: - resource efficiency - wasted deployment rate - held-out performance drop +- normalized burn ratio + +Normalized burn ratio definition: + +- `normalized_burn_ratio = final_burned_area_with_policy / final_burned_area_no_action_same_scenario` +- For each evaluation scenario, run a no-action baseline with the same initial scenario record and RNG seed. +- Report this as an evaluation-only metric; do not include it in the training reward. Interpretability checks: diff --git a/docs/planning/env-checklist.md b/docs/planning/env-checklist.md new file mode 100644 index 0000000..85b5fcf --- /dev/null +++ b/docs/planning/env-checklist.md @@ -0,0 +1,86 @@ +# Environment Checklist + +This checklist captures the remaining environment changes needed to align `src/models/fire_env.py` with the current benchmark design in `docs/envspec.md` and `docs/planning/impl-plan.md`. + +--- + +## 1) Replace remaining heuristic-only environment defaults + +- [ ] Remove the old assumption that severity alone determines spread through `SEVERITY_SPREAD_PROB` when a cached `base_spread_prob` record is available. +- [ ] Make the canonical path use static scenario parameter records first, with severity acting as reporting/observation metadata rather than the main spread heuristic. +- [ ] Keep severity heuristics only as a fallback dev mode, not as the benchmark-default path. +- [ ] Decide whether canonical benchmark mode should hard-fail if no cached parameter records are provided. 
+ +## 2) Make reset-time episode construction more dataset-driven + +- [ ] Decide what belongs in the cached scenario record versus what remains randomized inside the simulator. +- [ ] If desired, extend cached records to include reset-time metadata such as ignition family, asset layout, and optional size/dryness tags. +- [ ] Stop sampling scenario families and parameter records independently if that can create inconsistent pairings. +- [ ] Replace the current severity-only record matching with a stronger record-selection rule tied to the frozen train/held-out split. + +## 3) Freeze the canonical benchmark mode more strictly + +- [ ] Add an explicit benchmark mode flag to `FireEnv` so train/eval runs cannot silently fall back to ad hoc random scenario generation. +- [ ] In benchmark mode, fail fast on missing or malformed parameter records. +- [ ] Keep legacy `base_spread_rate_m_per_min` support only for backward compatibility or remove it entirely once the static dataset path is stable. +- [ ] Ensure benchmark mode never depends on runtime live ingestion. + +## 4) Align environment parameters with the static dataset builder + +- [ ] Confirm the cached parameter schema used by `FireEnv` matches the output of `src/ingestion/static_dataset.py`. +- [ ] Use `base_spread_prob`, `wind_dir_deg`, and `wind_strength` directly from cached records in the canonical path. +- [ ] Keep extra builder fields such as `spread_rate_1h_m`, `spread_score`, `dryness_score`, and `record_quality_flag` for logging/debugging only unless promoted into the canonical env contract. +- [ ] Decide whether `severity_bucket` should be fully precomputed offline rather than inferred from old hard-coded spread heuristics. + +## 5) Tighten the reward and transition accounting + +- [ ] Check whether `new_burned` is currently measuring the intended quantity; it now tracks the change in burning cells rather than newly burned cells strictly. 
+- [ ] Verify the reward matches the frozen coefficients and intended semantics in `docs/envspec.md`. +- [ ] Confirm wasted-action logic matches the benchmark wording for blocked and zero-effect deployments. +- [ ] Confirm asset-loss accounting is correct when assets transition into burning cells. + +## 6) Decide what remains randomized inside the simulator + +- [ ] Keep ignition coordinates randomized within a frozen family if that is the intended benchmark design. +- [ ] Keep asset coordinates randomized within layout `A` and `B` if that is the intended benchmark design. +- [ ] If more reproducibility is needed, precompute reset seeds or exact placements in the cached scenario dataset. +- [ ] Document clearly that the benchmark is a fixed environment family with randomized episode instances, not one single fixed map. + +## 7) Improve scenario-family integration + +- [ ] Ensure train families and held-out families are sampled exactly as frozen in `docs/planning/impl-plan.md`. +- [ ] Prevent held-out family leakage during training. +- [ ] Consider storing a `split` field or family tag directly in cached scenario records. +- [ ] Confirm layout `A` and `B` generation in code really matches the written definitions. + +## 8) Clean up legacy or transitional code paths + +- [ ] Audit whether `random_scenario()` should remain part of the canonical benchmark path or only support smoke tests and ablations. +- [ ] Remove or isolate older spread-rate override code once the static parameter dataset path is fully working. +- [ ] Clarify whether `ScenarioConfig` should remain the main reset object or become a thin wrapper over cached parameter records. +- [ ] Remove comments or naming that still imply the older XGBoost-centered flow. + +## 9) Add validation and tests + +- [ ] Add tests for loading cached scenario parameter records. +- [ ] Add tests that benchmark-mode reset uses cached parameters and does not fall back silently. 
+- [ ] Add tests that observation shape remains stable at `636` unless the observation contract intentionally changes. +- [ ] Add tests that severity one-hot and wind bias in the observation match the active cached parameter record. +- [ ] Add tests that train/held-out family filtering works as intended. + +## 10) Nice-to-have improvements after canonical alignment + +- [ ] Add richer info logging so each episode returns the active `record_id` and scenario family tags. +- [ ] Add optional per-record diagnostics for spread calibration sanity checks. +- [ ] Consider adding a cached `ignition_seed` or `layout_seed` field for exact episode replay. +- [ ] Add a dedicated benchmark env factory that always builds the environment from frozen cached records. + +--- + +## Suggested order + +1. Make cached parameter records the canonical reset path. +2. Remove silent fallback behavior in benchmark mode. +3. Align reward/transition accounting with the written spec. +4. Freeze family sampling and held-out split handling. +5. Add tests around cached-record loading and reset behavior. diff --git a/docs/planning/impl-plan.md b/docs/planning/impl-plan.md index a518281..d48d1a5 100644 --- a/docs/planning/impl-plan.md +++ b/docs/planning/impl-plan.md @@ -22,7 +22,8 @@ Given a spreading wildfire on a grid and limited suppression resources, what tac - Do not claim operational readiness. - Do not claim superiority over real emergency protocols. -- Treat real data ingestion and XGBoost as simulator calibration support only. +- Treat real data ingestion as scenario-construction support only. +- Do not claim empirical wildfire spread prediction. ### Canonical claim text @@ -50,8 +51,8 @@ Any deviation must be explicitly labeled as an ablation and reported separately. 
flowchart TD A[Feasible Data Sources] --> B[Ingestion and Normalization] B --> C[Mandatory Snapshot Cache] - C --> D[XGBoost Spread Calibration] - D --> E[Scenario Parameters] + C --> D[Static Parameter Preprocessing] + D --> E[Scenario Parameter Records] E --> F[Enhanced RL Environment] F --> G[DQN A2C PPO] F --> H[Greedy Random] @@ -86,8 +87,18 @@ Observation at step `t` contains: 7. severity bucket (`low`, `medium`, `high`) encoded one-hot 8. wind bias vector `(wx, wy)` if wind-bias mode enabled +Frozen observation rule: + +- The canonical benchmark uses the encoded `fire_grid` plus scalar features listed above. +- Multi-channel observation variants are allowed only as ablations or future work and must be reported separately. + ## 4.2 Action set and exact semantics +Action categories: + +- mobility: `MOVE_N`, `MOVE_S`, `MOVE_E`, `MOVE_W` +- intervention: `DEPLOY_HELICOPTER`, `DEPLOY_CREW` + Actions: - `0`: `MOVE_N` @@ -97,6 +108,11 @@ Actions: - `4`: `DEPLOY_HELICOPTER` - `5`: `DEPLOY_CREW` +Frozen action rule: + +- The canonical action space contains exactly these 6 actions. +- `WAIT` is out of scope for the frozen benchmark and may appear only in ablations. + Hard definitions: - Movement actions move the agent by one cell if in bounds; otherwise no movement. @@ -119,9 +135,10 @@ Per-episode budgets: ## 4.3 Fire dynamics - Fire spreads stochastically from burning cells to neighbors. -- Baseline spread probability is scenario-dependent from XGBoost calibration. +- Baseline spread probability is scenario-dependent from precomputed episode parameters. - Heterogeneity mode for canonical runs: **wind bias enabled**. - Wind bias increases ignition probability downwind and decreases upwind. +- Local flammability maps and control-tick versus fire-tick cadence are deferred to ablations or future work. --- @@ -162,7 +179,7 @@ If unstable, adjust only `asset` and `burn` coefficients once, then freeze. 
- Ignition layout: `center`, `edge`, `corner`, `multi_cluster` - Severity: `low`, `medium`, `high` -- Asset layout type: `A` (single critical cluster), `B` (two smaller critical clusters) +- Asset layout type: `A` (one dense high-value cluster near moderate exposure), `B` (two smaller separated clusters with different exposure distances) ## 6.2 Training families (frozen) @@ -189,14 +206,12 @@ Required methods: - **Greedy heuristic** (non-RL baseline) - **Random** (sanity floor) -Do not include recurrent baseline unless hidden regime shifts are explicitly added and tested. +Recurrent baselines are not included because we will not add and test hidden regime shifts. --- ## 8) Benchmark Harness and Logging (Required Infrastructure) -The harness is mandatory and must be built before full training runs. - Requirements: 1. Unified runner for all algorithms. @@ -219,81 +234,69 @@ Secondary metrics: - resource efficiency - wasted deployment rate - held-out performance drop +- normalized burn ratio + +Normalized burn ratio definition: + +- `final_burned_area_with_policy / final_burned_area_no_action_same_scenario` +- The denominator comes from a no-action baseline rollout using the same scenario record and RNG seed. +- This is an evaluation-only metric and does not modify the training reward. --- -## 9) XGBoost Interface to Environment (Refined) +## 9) Static Scenario Parameter Interface -XGBoost is a calibration layer between ingested wildfire/weather snapshots and simulator episode parameters. It is not a control policy. +The benchmark uses a static scenario-parameter dataset built offline from ingested wildfire, weather, and fire-danger records. These parameters are not predicted at runtime. -## 9.1 Input feature contract with availability status +## 9.1 Snapshot inputs used during preprocessing Canonical feature groups: -1. **Weather (supported now)** +1. 
**Weather** - `wind_speed_km_h` - `wind_direction_deg` - `temperature_c` - `relative_humidity_pct` - `precipitation_mm` -2. **Fire danger indices (supported now)** +2. **Fire danger indices** - `fwi`, `isi`, `bui` -> TODO: check caveats to ensure data quality for pipeline! - -1. **Incident context (supported now, with caveats)** - - `area_hectares` (often missing for FIRMS hotspots) +3. **Incident context** + - `area_hectares` - `latitude`, `longitude` - - `province` (categorical coarse location) - -4. **Useful optional features (partially supported or easy to add)** - - `frp_mw` from FIRMS hotspot intensity (partially available) - - `cffdrs_station_distance_km` (derive from nearest-station lookup) - - `dmc`, `dc`, `ffmc` (already available from CFFDRS response) - - simple temporal deltas from snapshots (e.g., 6h wind or RH change) - -5. **Do not include as canonical unless truly ingested** - - synthetic `slope_pct` - - synthetic `rh_trend_24h` - -### Required ingestion/training code updates for optional features + - `province` -- Extend snapshot schema to persist `frp_mw`, station distance, and additional CFFDRS indices. -- Update XGBoost feature builder to include encoded `province` and missingness flags (e.g., `has_area`, `has_frp`). -- Remove hidden defaults during benchmark feature generation; fail fast if required canonical features are missing. +4. **Optional retained metadata** + - `frp_mw` + - `cffdrs_station_distance_km` + - `dmc`, `dc`, `ffmc` -## 9.2 Output contract for simulator (refined) +Preprocessing rule: -For each snapshot record, produce: +- The pipeline computes environment variables offline before writing the static scenario dataset. +- Any variable used in canonical benchmarking must be present in the stored record; benchmark mode must fail fast on missing required fields. -1. `spread_intensity` in `[0, 1]` (primary scalar for fire-growth pressure) -2. `spread_rate_1h_m` (interpretable spread scale for logging and sanity checks) -3. 
`wind_dir_deg` (pass-through from weather input, not predicted) -4. `wind_strength` in `[0, 1]` (normalized from observed wind speed) -5. `severity_bucket` (`low`, `medium`, `high`) derived deterministically from `spread_intensity` +## 9.2 Stored parameter record for the simulator -Deterministic mapping to env parameters: +For each scenario record, store: -- `base_spread_prob = 0.04 + 0.18 * spread_intensity` -- severity bucket thresholds: - - low: `< 0.33` - - medium: `0.33-0.66` - - high: `> 0.66` -- wind bias vector: - - `wx = wind_strength * cos(wind_dir_deg)` - - `wy = wind_strength * sin(wind_dir_deg)` +1. `base_spread_prob` +2. `severity_bucket` in `{low, medium, high}` +3. `wind_dir_deg` +4. `wind_strength` in `[0, 1]` +5. optional logging fields such as `spread_rate_1h_m` if produced during preprocessing Episode sampling rule: -- At reset, sample one snapshot-derived parameter record for the episode. +- At reset, sample one cached parameter record for the episode. - Parameters remain fixed for the full episode in canonical runs. ## 9.3 Why this interface is chosen -- Keeps RL benchmark focused on tactical decision-making rather than end-to-end forecasting claims. -- Uses real ingested signals where available while preserving deterministic simulator reproducibility. -- Avoids modeling wind direction with XGBoost when it is already directly observed. +- Keeps the RL benchmark focused on tactical decision-making rather than learned spread prediction. +- Uses ingested data to define realistic variation in episode conditions while preserving deterministic reproducibility. +- Avoids runtime API dependence and avoids overclaiming forecasting capability. --- @@ -315,8 +318,7 @@ flowchart TD C[Open-Meteo] --> N D[CFFDRS] --> N N --> S[Versioned Snapshot Cache] - S --> X[XGBoost Feature Builder] - X --> P[Env Parameter Records] + S --> P[Offline Env Parameter Builder] P --> E[RL Scenario Generator] ``` @@ -352,12 +354,14 @@ flowchart TD 1. 
Freeze objective, protocol numbers, held-out split. 2. Implement assets, budgets, cooldown semantics. 3. Implement wind-bias heterogeneity. -4. Implement mandatory benchmark harness/log schema/eval mode. -5. Implement scenario generator with frozen train/test families. -6. Implement snapshot cache loader and XGBoost-to-env parameter mapping. -7. Run reward sanity pass and freeze coefficients. -8. Run full multi-seed benchmarks for DQN/A2C/PPO + greedy/random. -9. Aggregate plots/tables and write limitations. +4. Define asset layouts `A` and `B` explicitly in the generator and docs. +5. Implement mandatory benchmark harness/log schema/eval mode. +6. Implement scenario generator with frozen train/test families. +7. Implement snapshot cache loader and offline parameter-to-env mapping. +8. Add evaluation-only normalized burn ratio reporting. +9. Run reward sanity pass and freeze coefficients. +10. Run full multi-seed benchmarks for DQN/A2C/PPO + greedy/random. +11. Aggregate plots/tables and write limitations. --- diff --git a/docs/planning/proposal.md b/docs/planning/proposal.md index fe36ed0..d2b5512 100644 --- a/docs/planning/proposal.md +++ b/docs/planning/proposal.md @@ -33,7 +33,7 @@ The developed technique is not a new RL algorithm. It is an enhanced benchmark e 1. **Prioritization under risk**: critical assets can be lost if not protected. 2. **Planning under scarcity**: helicopter/crew actions are limited and costly. -3. **Spatial reasoning**: non-uniform spread field (flammability map or wind bias). +3. **Spatial reasoning**: non-uniform spread conditions from fixed episode parameters such as spread severity and wind bias. 4. **Robustness testing**: multiple scenario families and held-out test families. Core intuition: better benchmark structure and rigorous evaluation produce more defensible RL evidence than adding algorithmic novelty under time pressure. 
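The per-episode wind bias described above (a fixed `wind_dir_deg` and `wind_strength` collapsed into a bias vector that raises ignition probability downwind and lowers it upwind) can be sketched as follows. The function name and the linear modulation form are illustrative assumptions, not the frozen spec:

```python
import math


def wind_biased_ignition_prob(
    base_spread_prob: float,
    wind_dir_deg: float,
    wind_strength: float,  # normalized to [0, 1]
    dx: int,
    dy: int,
) -> float:
    """Illustrative ignition probability for one neighbor offset (dx, dy).

    Downwind neighbors see a higher probability, upwind neighbors a lower
    one; the exact modulation here is an assumption for illustration.
    """
    # Bias vector (wx, wy), as in the stored parameter interface.
    wx = wind_strength * math.cos(math.radians(wind_dir_deg))
    wy = wind_strength * math.sin(math.radians(wind_dir_deg))
    # Alignment in [-1, 1]: +1 directly downwind, -1 directly upwind.
    norm = math.hypot(dx, dy) or 1.0
    alignment = (dx * wx + dy * wy) / norm
    # Scale the baseline and clamp to a valid probability.
    return min(1.0, max(0.0, base_spread_prob * (1.0 + alignment)))
```

With calm wind (`wind_strength = 0`) every neighbor sees the unmodified `base_spread_prob`, so the wind-bias mode degrades cleanly to the uniform case.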
@@ -64,8 +64,8 @@ Optional only if hidden regime shifts are added and time permits: - limited helicopter drops and crew deployments - resource cost and cooldown penalties -3. **Heterogeneous spread field (choose one)** - - flammability map, or +3. **Heterogeneous spread conditions** + - precomputed spread severity from the static dataset - directional wind bias 4. **Clean benchmark harness** @@ -151,11 +151,12 @@ Data pipeline remains **supporting context**, not the central empirical claim. Based on audit findings, claims must stay realistic: - implemented ingestion: FIRMS, CWFIS active fires, Open-Meteo, CFFDRS +- benchmark use: one-time ingestion and preprocessing into static scenario records with precomputed environment variables - not fully implemented as production ETL: CIFFC, BC/AB ArcGIS full pipeline, ECCC Datamart orchestration, broad historical validated spread labels Paper wording will avoid operational overclaim and state: -"We benchmark RL methods in an enhanced custom wildfire simulator inspired by wildfire decision-support structure." +"We benchmark RL methods in an enhanced custom wildfire simulator using static snapshot-derived scenario records and fixed environment parameterization." --- @@ -166,7 +167,7 @@ Day 1: - Add critical assets + resource budgets to environment. Day 2: -- Add scenario generator and heterogeneous spread field. +- Add scenario generator and static parameter preprocessing for spread severity and wind bias. - Add eval mode without fallback contamination. Day 3: @@ -175,7 +176,7 @@ Day 3: Day 4: - Pilot runs and reward sanity checks. -- Fix instability and calibration issues. +- Fix instability and environment calibration issues. Day 5: - Full multi-seed train/eval runs. 
diff --git a/fp-historical-wildfire-data-dictionary-2006-2025.pdf b/fp-historical-wildfire-data-dictionary-2006-2025.pdf new file mode 100644 index 0000000..97f6e1c Binary files /dev/null and b/fp-historical-wildfire-data-dictionary-2006-2025.pdf differ diff --git a/lefthook.yml b/lefthook.yml index 36155ba..55111a5 100644 --- a/lefthook.yml +++ b/lefthook.yml @@ -2,7 +2,9 @@ pre-commit: commands: lint: glob: "*.py" - run: uv run ruff check {staged_files} - format-check: + run: uv run ruff check --fix --unsafe-fixes {staged_files} + stage_fixed: true + format: glob: "*.py" - run: uv run ruff format --check {staged_files} + run: uv run ruff format {staged_files} + stage_fixed: true diff --git a/pyproject.toml b/pyproject.toml index 9a10321..1251b7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,9 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.14" dependencies = [ + "python-dotenv>=1.1.1", "gymnasium>=1.2.3", + "httpx>=0.28.1", "matplotlib>=3.10.8", "numpy>=2.4.2", "torch>=2.10.0", diff --git a/src/ingestion/cffdrs.py b/src/ingestion/cffdrs.py index 662554b..d9c69f8 100644 --- a/src/ingestion/cffdrs.py +++ b/src/ingestion/cffdrs.py @@ -23,7 +23,7 @@ import io import logging import math -from datetime import UTC, datetime +from datetime import UTC, date, datetime import httpx @@ -43,8 +43,10 @@ def _haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float: R = 6371.0 dlat = math.radians(lat2 - lat1) dlon = math.radians(lon2 - lon1) - a = (math.sin(dlat / 2) ** 2 - + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon / 2) ** 2) + a = ( + math.sin(dlat / 2) ** 2 + + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon / 2) ** 2 + ) return R * 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) @@ -53,10 +55,23 @@ def _parse_float(val: str) -> float | None: try: f = float(val) return None if f < -900 else f # CWFIS uses -999 as missing value sentinel - 
except (ValueError, TypeError):
+    except (ValueError, TypeError):
         return None


+def _parse_station_date(val: str) -> date | None:
+    """Parse a station observation date from common CFFDRS formats."""
+    if not val:
+        return None
+    text = str(val).strip()
+    for fmt in ("%Y-%m-%d", "%Y/%m/%d", "%Y%m%d", "%d-%b-%Y"):
+        try:
+            return datetime.strptime(text, fmt).date()
+        except ValueError:
+            continue
+    return None
+
+
 def fetch_cffdrs_stations(year: int | None = None) -> list[dict]:
     """
     Download the full CWFIS annual FWI observation CSV and parse it.
@@ -110,22 +125,29 @@ def fetch_cffdrs_stations(year: int | None = None) -> list[dict]:
             "latitude": lat,
             "longitude": lon,
             "date": row.get("date", row.get("DATE", "")).strip(),
+            "observation_date": _parse_station_date(row.get("date", row.get("DATE", "")).strip()),
             # Core CFFDRS indices
-            "fwi": _parse_float(row.get("fwi", row.get("FWI", ""))),
-            "isi": _parse_float(row.get("isi", row.get("ISI", ""))),
-            "bui": _parse_float(row.get("bui", row.get("BUI", ""))),
-            "dc": _parse_float(row.get("dc", row.get("DC", ""))),
-            "dmc": _parse_float(row.get("dmc", row.get("DMC", ""))),
+            "fwi": _parse_float(row.get("fwi", row.get("FWI", ""))),
+            "isi": _parse_float(row.get("isi", row.get("ISI", ""))),
+            "bui": _parse_float(row.get("bui", row.get("BUI", ""))),
+            "dc": _parse_float(row.get("dc", row.get("DC", ""))),
+            "dmc": _parse_float(row.get("dmc", row.get("DMC", ""))),
             "ffmc": _parse_float(row.get("ffmc", row.get("FFMC", ""))),
             # Observed weather at station
-            "temp_c": _parse_float(row.get("temp", row.get("TEMP", ""))),
-            "rh_pct": _parse_float(row.get("rh", row.get("RH", ""))),
-            "ws_km_h": _parse_float(row.get("ws", row.get("WS", ""))),
-            "precip_mm": _parse_float(row.get("prec", row.get("PREC", ""))),
+            "temp_c": _parse_float(row.get("temp", row.get("TEMP", ""))),
+            "rh_pct": _parse_float(row.get("rh", row.get("RH", ""))),
+            "ws_km_h": _parse_float(row.get("ws", row.get("WS", ""))),
+            "precip_mm": _parse_float(row.get("prec",
row.get("PREC", ""))), } stations.append(station) logger.info(f"CFFDRS: loaded {len(stations)} BC/AB stations") + valid_fwi = sum(1 for stn in stations if stn.get("fwi") is not None) + if valid_fwi == 0 and stations: + logger.warning( + "CFFDRS station file loaded but contains no usable FWI values. " + "This often happens outside fire season or when the requested year has sparse observations." + ) return stations @@ -134,6 +156,8 @@ def get_cffdrs_for_location( longitude: float, stations: list[dict] | None = None, max_radius_km: float = 200.0, + target_date: date | None = None, + max_date_offset_days: int = 1, ) -> dict | None: """ Find the nearest CWFIS weather station and return its CFFDRS indices. @@ -154,15 +178,25 @@ def get_cffdrs_for_location( if not stations: return None - # Find nearest station with valid FWI data + # Find nearest station with valid FWI data, preferring date-aligned observations. best = None best_dist = float("inf") + best_date_offset = float("inf") for stn in stations: if stn.get("fwi") is None: continue # skip stations with missing data + obs_date = stn.get("observation_date") + date_offset = 0 + if target_date is not None: + if obs_date is None: + continue + date_offset = abs((obs_date - target_date).days) + if date_offset > max_date_offset_days: + continue dist = _haversine_km(latitude, longitude, stn["latitude"], stn["longitude"]) - if dist < best_dist: + if date_offset < best_date_offset or (date_offset == best_date_offset and dist < best_dist): + best_date_offset = date_offset best_dist = dist best = stn @@ -178,11 +212,15 @@ def get_cffdrs_for_location( "source_station_id": best["station_id"], "distance_km": round(best_dist, 1), "date": best["date"], - "fwi": best["fwi"], - "isi": best["isi"], - "bui": best["bui"], - "dc": best["dc"], - "dmc": best["dmc"], + "observation_date": best["observation_date"].isoformat() + if best.get("observation_date") + else None, + "date_offset_days": int(best_date_offset) if target_date is not None 
else 0, + "fwi": best["fwi"], + "isi": best["isi"], + "bui": best["bui"], + "dc": best["dc"], + "dmc": best["dmc"], "ffmc": best["ffmc"], } @@ -222,10 +260,30 @@ def get_cffdrs_for_fires(fires: list[dict]) -> dict[str, dict]: logging.basicConfig(level=logging.INFO) test_fires = [ - {"fire_id": "BC-2026-001", "name": "Okanagan Ridge Fire", "latitude": 49.9071, "longitude": -119.496}, - {"fire_id": "BC-2026-002", "name": "Kamloops Plateau Fire", "latitude": 50.6745, "longitude": -120.3273}, - {"fire_id": "BC-2026-003", "name": "Fraser Valley Approach","latitude": 49.3845, "longitude": -121.4483}, - {"fire_id": "AB-2026-001", "name": "Peace River Complex", "latitude": 56.2370, "longitude": -117.2900}, + { + "fire_id": "BC-2026-001", + "name": "Okanagan Ridge Fire", + "latitude": 49.9071, + "longitude": -119.496, + }, + { + "fire_id": "BC-2026-002", + "name": "Kamloops Plateau Fire", + "latitude": 50.6745, + "longitude": -120.3273, + }, + { + "fire_id": "BC-2026-003", + "name": "Fraser Valley Approach", + "latitude": 49.3845, + "longitude": -121.4483, + }, + { + "fire_id": "AB-2026-001", + "name": "Peace River Complex", + "latitude": 56.2370, + "longitude": -117.2900, + }, ] print("Fetching CFFDRS fire danger indices from CWFIS/NRCan...\n") diff --git a/src/ingestion/clean_historical.py b/src/ingestion/clean_historical.py new file mode 100644 index 0000000..bfd9ea0 --- /dev/null +++ b/src/ingestion/clean_historical.py @@ -0,0 +1,48 @@ +"""Utilities for lightweight cleaning of Alberta historical wildfire rows.""" + +from __future__ import annotations + +REQUIRED_RAW_FIELDS = ( + "YEAR", + "FIRE_NUMBER", + "LATITUDE", + "LONGITUDE", + "ASSESSMENT_DATETIME", + "FIRE_SPREAD_RATE", + "TEMPERATURE", + "RELATIVE_HUMIDITY", + "WIND_DIRECTION", + "WIND_SPEED", +) + + +def clean_raw_historical_row_with_reason(row: dict) -> tuple[dict | None, str | None]: + """Trim strings and drop rows missing required canonical fields. 
+ + This stays intentionally lightweight: strip blanks, normalize empty strings, + and reject rows that lack core assessment-time fields needed by the builder. + """ + cleaned: dict[str, object] = {} + for key, value in row.items(): + if isinstance(value, str): + stripped = value.strip() + cleaned[key] = stripped if stripped != "" else None + else: + cleaned[key] = value + + for field in REQUIRED_RAW_FIELDS: + if cleaned.get(field) in (None, ""): + return None, f"missing_{field.lower()}" + + area_fields_present = cleaned.get("ASSESSMENT_HECTARES") not in (None, "") or cleaned.get( + "CURRENT_SIZE" + ) not in (None, "") + if not area_fields_present: + return None, "missing_area_fields" + + return cleaned, None + + +def clean_raw_historical_row(row: dict) -> dict | None: + cleaned, _reason = clean_raw_historical_row_with_reason(row) + return cleaned diff --git a/src/ingestion/cwfis.py b/src/ingestion/cwfis.py deleted file mode 100644 index ab42513..0000000 --- a/src/ingestion/cwfis.py +++ /dev/null @@ -1,160 +0,0 @@ -""" -cwfis.py — Canadian Wildland Fire Information System (CWFIS) ingestion. - -Downloads the official active fire list from the NRCan open data portal. -No API key required — this is a public government dataset. - -Data source: https://cwfis.cfs.nrcan.gc.ca/downloads/activefires/ -Updated daily around 13:00 UTC by Natural Resources Canada. - -Run from backend/ to test: - uv run python -m src.ingestion.cwfis -""" - -import csv -import io -import logging -from datetime import UTC, datetime - -import httpx - -logger = logging.getLogger(__name__) - -# ── CWFIS Open Data URLs ────────────────────────────────────────────────────── -# This CSV is updated daily by NRCan. No auth required. 
-CWFIS_ACTIVEFIRES_URL = "https://cwfis.cfs.nrcan.gc.ca/downloads/activefires/activefires.csv" - -# Only ingest BC and AB for our scope -TARGET_PROVINCES = {"BC", "AB"} - - -def _severity_from_status(status: str) -> str: - """Map CWFIS fire status codes to FireGrid severity labels.""" - s = status.upper().strip() - if "OUT OF CONTROL" in s or s == "OC": - return "extreme" - elif "BEING HELD" in s or s == "BH": - return "high" - elif "UNDER CONTROL" in s or s == "UC": - return "moderate" - return "low" - - -def _normalize_cwfis_row(row: dict) -> dict | None: - """ - Normalize a single CWFIS CSV row into a FireGrid FireEvent dict. - - CWFIS CSV columns (as of 2024): - agency, firename, lat, lon, startdate, hectares, status, stage_of_control - Returns None if critical fields are missing. - """ - try: - # Province comes from the agency code (e.g. "BC", "AB", "ON") - agency = row.get("agency", "").strip().upper() - province = agency[:2] if len(agency) >= 2 else "OTHER" - - # Filter to just BC + AB - if province not in TARGET_PROVINCES: - return None - - lat = float(row.get("lat") or row.get("latitude", 0)) - lon = float(row.get("lon") or row.get("longitude", 0)) - - if lat == 0 and lon == 0: - return None - - fire_name = row.get("firename", "").strip() or f"CWFIS Fire ({province})" - fire_number = row.get("firenumber", "").strip() or row.get("firename", "UNK") - status = row.get("stage_of_control", row.get("status", "Unknown")).strip() - hectares_raw = row.get("hectares", row.get("area", "0")) or "0" - hectares = float(hectares_raw) if hectares_raw else 0.0 - start_date = row.get("startdate", row.get("discovered", "")).strip() - - # Build ISO timestamp from start date - try: - started_at = datetime.strptime(start_date, "%Y-%m-%d").replace(tzinfo=UTC).isoformat() - except (ValueError, TypeError): - started_at = datetime.now(UTC).isoformat() - - # Build a stable fire_id using province + fire number - safe_num = fire_number.replace(" ", "_").replace("/", "-") - fire_id = 
f"CWFIS-{province}-{safe_num}" - - return { - "fire_id": fire_id, - "province": province, - "name": fire_name, - "status": status.lower().replace(" ", "_"), - "severity": _severity_from_status(status), - "latitude": lat, - "longitude": lon, - "area_hectares": hectares, - "started_at": started_at, - "updated_at": datetime.now(UTC).isoformat(), - "source": "CWFIS_NRCAN", - } - except (ValueError, KeyError, TypeError) as e: - logger.warning(f"Skipping malformed CWFIS row: {e} | row={row}") - return None - - -def fetch_cwfis_activefires() -> list[dict]: - """ - Download and parse the CWFIS active fires CSV. - Filters to BC + AB only. - - Returns: - List of normalized FireEvent dicts. - - No rate limit — this is a static daily file. One HTTP GET. - """ - logger.info(f"Fetching CWFIS active fires from {CWFIS_ACTIVEFIRES_URL}") - - try: - with httpx.Client(timeout=20) as client: - resp = client.get(CWFIS_ACTIVEFIRES_URL) - resp.raise_for_status() - except httpx.TimeoutException: - logger.error("CWFIS download timed out.") - return [] - except httpx.HTTPStatusError as e: - logger.error(f"CWFIS HTTP error: {e.response.status_code}") - return [] - except httpx.RequestError as e: - logger.error(f"CWFIS request failed: {e}") - return [] - - # CWFIS CSV sometimes has a BOM or extra header lines — strip them - text = resp.text.lstrip("\ufeff") # strip BOM if present - reader = csv.DictReader(io.StringIO(text)) - - fires = [] - for row in reader: - normalized = _normalize_cwfis_row(row) - if normalized: - fires.append(normalized) - - logger.info(f"CWFIS: found {len(fires)} active fires in BC + AB") - return fires - - -def get_cwfis_fires() -> list[dict]: - """ - Public interface: fetch and return active fires from CWFIS over BC + AB. - Called by the API endpoint when real data is requested. 
- """ - return fetch_cwfis_activefires() - - -# ── Manual test ────────────────────────────────────────────────────────────── -if __name__ == "__main__": - import json - logging.basicConfig(level=logging.INFO) - print("Fetching CWFIS active fires (NRCan)...") - fires = get_cwfis_fires() - print(f"\nGot {len(fires)} fires in BC + AB.\n") - if fires: - print("Sample fire:") - print(json.dumps(fires[0], indent=2)) - else: - print("No active fires right now (or off-season). CSV was empty for BC/AB.") diff --git a/src/ingestion/dummy.py b/src/ingestion/dummy.py deleted file mode 100644 index e2a528f..0000000 --- a/src/ingestion/dummy.py +++ /dev/null @@ -1,234 +0,0 @@ -""" -dummy.py — Dummy data generators for all FireGrid data types. - -These helpers keep the API usable when live data sources are unavailable. -""" - -import random -from datetime import UTC, datetime, timedelta - -# ── Seed for reproducible dummy data ─────────────────────────────────────────── -random.seed(42) - - -# ── Helpers ───────────────────────────────────────────────────────────────────── - -def _rand_bc_coord() -> tuple[float, float]: - """Random coordinate inside British Columbia.""" - lat = random.uniform(49.0, 59.0) - lon = random.uniform(-139.0, -114.0) - return round(lat, 5), round(lon, 5) - - -def _rand_ab_coord() -> tuple[float, float]: - """Random coordinate inside Alberta.""" - lat = random.uniform(49.0, 60.0) - lon = random.uniform(-120.0, -110.0) - return round(lat, 5), round(lon, 5) - - -def _rand_coord() -> tuple[float, float]: - return random.choice([_rand_bc_coord, _rand_ab_coord])() - - -# ── Fire Incidents ─────────────────────────────────────────────────────────────── - -DUMMY_FIRE_INCIDENTS = [ - { - "fire_id": "BC-2026-001", - "province": "BC", - "name": "Okanagan Ridge Fire", - "status": "out_of_control", - "severity": "extreme", - "latitude": 49.9071, - "longitude": -119.4960, - "area_hectares": 4200.0, - "started_at": (datetime.now(UTC) - 
timedelta(hours=18)).isoformat(), - "updated_at": datetime.now(UTC).isoformat(), - "source": "dummy", - }, - { - "fire_id": "BC-2026-002", - "province": "BC", - "name": "Kamloops Plateau Fire", - "status": "being_held", - "severity": "high", - "latitude": 50.6745, - "longitude": -120.3273, - "area_hectares": 800.0, - "started_at": (datetime.now(UTC) - timedelta(hours=6)).isoformat(), - "updated_at": datetime.now(UTC).isoformat(), - "source": "dummy", - }, - { - "fire_id": "BC-2026-003", - "province": "BC", - "name": "Fraser Valley Approach", - "status": "being_held", - "severity": "high", - "latitude": 49.3845, - "longitude": -121.4483, - "area_hectares": 250.0, - "started_at": (datetime.now(UTC) - timedelta(hours=8)).isoformat(), - "updated_at": datetime.now(UTC).isoformat(), - "source": "dummy", - }, - { - "fire_id": "AB-2026-001", - "province": "AB", - "name": "Peace River Complex", - "status": "out_of_control", - "severity": "extreme", - "latitude": 56.2370, - "longitude": -117.2900, - "area_hectares": 12500.0, - "started_at": (datetime.now(UTC) - timedelta(hours=36)).isoformat(), - "updated_at": datetime.now(UTC).isoformat(), - "source": "dummy", - }, -] - - -def get_dummy_fires() -> list[dict]: - return DUMMY_FIRE_INCIDENTS - - -def get_dummy_fire_by_id(fire_id: str) -> dict | None: - return next((f for f in DUMMY_FIRE_INCIDENTS if f["fire_id"] == fire_id), None) - - -# ── Burn Probability Grid ──────────────────────────────────────────────────────── - -def get_dummy_burn_probability(fire_id: str) -> dict: - """ - Returns a 5x5 grid of burn probability cells around the fire origin. - In production this will be replaced by XGBoost model inference. 
- """ - fire = get_dummy_fire_by_id(fire_id) - if not fire: - return {} - - origin_lat = fire["latitude"] - origin_lon = fire["longitude"] - grid_step = 0.05 # ~5km cells - - cells = [] - for i in range(-2, 3): - for j in range(-2, 3): - # Probability peaks at origin and decays outward (simulates wind pushing east) - dist = abs(i) + abs(j - 1) # offset east to simulate wind direction - probability = max(0.0, round(0.95 - (dist * 0.18) + random.uniform(-0.05, 0.05), 3)) - cells.append({ - "latitude": round(origin_lat + i * grid_step, 5), - "longitude": round(origin_lon + j * grid_step, 5), - "burn_probability": probability, - "cell_size_km": 5.0, - }) - - return { - "fire_id": fire_id, - "model": "dummy_v0", - "horizon_hours": 24, - "generated_at": datetime.now(UTC).isoformat(), - "wind_speed_kmh": random.uniform(20, 60), - "wind_direction_deg": random.uniform(220, 280), # SW winds, pushing NE - "cells": cells, - } - - -# ── Asset Inventory ────────────────────────────────────────────────────────────── - -DUMMY_ASSETS = [ - # Ground crews - {"asset_id": "CREW-001", "type": "hotshot_crew", "size": 20, "status": "available", "latitude": 49.8880, "longitude": -119.4960, "province": "BC"}, - {"asset_id": "CREW-002", "type": "hotshot_crew", "size": 20, "status": "available", "latitude": 50.6745, "longitude": -120.1010, "province": "BC"}, - {"asset_id": "CREW-003", "type": "hotshot_crew", "size": 20, "status": "deployed", "latitude": 56.1200, "longitude": -117.3500, "province": "AB"}, - {"asset_id": "CREW-004", "type": "initial_attack_crew", "size": 6, "status": "available", "latitude": 49.9500, "longitude": -119.5000, "province": "BC"}, - # Heavy equipment - {"asset_id": "DOZER-001", "type": "bulldozer", "size": 1, "status": "available", "latitude": 49.9100, "longitude": -119.5200, "province": "BC"}, - {"asset_id": "DOZER-002", "type": "bulldozer", "size": 1, "status": "available", "latitude": 56.2000, "longitude": -117.2500, "province": "AB"}, - # Aircraft - 
{"asset_id": "AIR-001", "type": "water_bomber", "size": 1, "status": "available", "latitude": 49.4627, "longitude": -119.5720, "province": "BC"}, - {"asset_id": "AIR-002", "type": "water_bomber", "size": 1, "status": "deployed", "latitude": 56.2370, "longitude": -117.2800, "province": "AB"}, - {"asset_id": "AIR-003", "type": "helicopter", "size": 1, "status": "available", "latitude": 50.7000, "longitude": -120.3500, "province": "BC"}, -] - - -def get_dummy_assets(province: str | None = None) -> list[dict]: - if province: - return [a for a in DUMMY_ASSETS if a["province"] == province] - return DUMMY_ASSETS - - -def get_dummy_assets_summary() -> dict: - available = [a for a in DUMMY_ASSETS if a["status"] == "available"] - deployed = [a for a in DUMMY_ASSETS if a["status"] == "deployed"] - return { - "total": len(DUMMY_ASSETS), - "available": len(available), - "deployed": len(deployed), - "by_type": { - "hotshot_crew": len([a for a in DUMMY_ASSETS if a["type"] == "hotshot_crew"]), - "initial_attack_crew": len([a for a in DUMMY_ASSETS if a["type"] == "initial_attack_crew"]), - "bulldozer": len([a for a in DUMMY_ASSETS if a["type"] == "bulldozer"]), - "water_bomber": len([a for a in DUMMY_ASSETS if a["type"] == "water_bomber"]), - "helicopter": len([a for a in DUMMY_ASSETS if a["type"] == "helicopter"]), - }, - } - - -# ── Choke Point Recommendations (Greedy MVP) ──────────────────────────────────── - -def get_dummy_choke_points(fire_id: str) -> dict: - """ - MVP greedy heuristic: returns ranked deployment zones for the given fire. - Each node is scored by simulated burn_probability × accessibility. - In production this is replaced by the RL agent inference. 
- """ - fire = get_dummy_fire_by_id(fire_id) - if not fire: - return {} - - lat = fire["latitude"] - lon = fire["longitude"] - - recommendations = [ - { - "choke_point_id": f"{fire_id}-CP-001", - "latitude": round(lat + 0.12, 5), - "longitude": round(lon + 0.08, 5), - "priority_score": 0.94, - "recommended_action": "selective_backburn", - "recommended_assets": ["hotshot_crew", "bulldozer"], - "estimated_crew_size": 20, - "rationale": "Highest predicted burn probability in 24h window. Ridgeline break creates natural containment anchor.", - }, - { - "choke_point_id": f"{fire_id}-CP-002", - "latitude": round(lat + 0.07, 5), - "longitude": round(lon + 0.14, 5), - "priority_score": 0.78, - "recommended_action": "firebreak_construction", - "recommended_assets": ["bulldozer"], - "estimated_crew_size": 0, - "rationale": "Secondary threat corridor. Dozer line along logging road will cut off eastern flank.", - }, - { - "choke_point_id": f"{fire_id}-CP-003", - "latitude": round(lat - 0.05, 5), - "longitude": round(lon + 0.10, 5), - "priority_score": 0.61, - "recommended_action": "aerial_retardant_drop", - "recommended_assets": ["water_bomber"], - "estimated_crew_size": 0, - "rationale": "Dense fuel load in valley. Retardant drop will slow spread before ground crews can reach.", - }, - ] - - return { - "fire_id": fire_id, - "model": "greedy_heuristic_v0", - "generated_at": datetime.now(UTC).isoformat(), - "total_choke_points": len(recommendations), - "recommendations": recommendations, - } diff --git a/src/ingestion/firms.py b/src/ingestion/firms.py deleted file mode 100644 index cf1704f..0000000 --- a/src/ingestion/firms.py +++ /dev/null @@ -1,193 +0,0 @@ -""" -firms.py — NASA FIRMS (Fire Information for Resource Management System) ingestion. - -Pulls real VIIRS/NOAA satellite fire hotspot data over Western Canada (BC + AB) -and normalizes it into FireGrid FireEvent format. 
- -API docs: https://firms.modaps.eosdis.nasa.gov/api/area/ -Rate limit: 5000 transactions / 10 minutes — we use 1 call per request (safe). - -Run this from backend/ to test: - uv run python -m src.ingestion.firms -""" - -import csv -import io -import logging -from datetime import UTC, datetime - -import httpx - -from src.core.config import settings - -logger = logging.getLogger(__name__) - -# ── Canada Bounding Box (BC + AB focus) ───────────────────────────────────────── -# Format: W,S,E,N (longitude_min, latitude_min, longitude_max, latitude_max) -CANADA_WEST_BBOX = "-140,48,-110,62" - -# FIRMS endpoint for CSV area data -FIRMS_BASE_URL = "https://firms.modaps.eosdis.nasa.gov/api/area/csv" - -# Use VIIRS NOAA-20 (most recent, best resolution ~375m) -FIRMS_SOURCE = "VIIRS_NOAA20_NRT" - -# How many days back to fetch (1 = last 24h — minimizes data volume + rate hits) -DEFAULT_DAY_RANGE = 1 - - -def _assign_province(lat: float, lon: float) -> str: - """ - Rough bounding-box province assignment for BC and AB. - Anything else gets tagged as 'OTHER'. - """ - if -139.0 <= lon <= -114.0 and 48.3 <= lat <= 60.0: - return "BC" - elif -120.0 <= lon <= -110.0 and 49.0 <= lat <= 60.0: - return "AB" - return "OTHER" - - -def _frp_to_severity(frp: float) -> str: - """ - Convert Fire Radiative Power (MW) to a FireGrid severity label. - Thresholds based on FIRMS documentation and wildfire research. - """ - if frp >= 500: - return "extreme" - elif frp >= 100: - return "high" - elif frp >= 20: - return "moderate" - return "low" - - -def _normalize_hotspot(row: dict, idx: int) -> dict | None: - """ - Normalize a single FIRMS CSV row into a FireGrid FireEvent dict. - Returns None if the row is missing critical fields. - """ - try: - lat = float(row["latitude"]) - lon = float(row["longitude"]) - frp = float(row.get("frp", 0) or 0) - acq_date = row.get("acq_date", "") # e.g. "2026-03-21" - acq_time = row.get("acq_time", "0000") # e.g. 
"2315" - - # Build ISO timestamp from acquisition date + time - try: - dt_str = f"{acq_date} {acq_time.zfill(4)}" - detected_at = datetime.strptime(dt_str, "%Y-%m-%d %H%M").replace(tzinfo=UTC).isoformat() - except ValueError: - detected_at = datetime.now(UTC).isoformat() - - province = _assign_province(lat, lon) - severity = _frp_to_severity(frp) - - # Build a deterministic fire_id from date + grid position - # Round to 2 decimal places (~1km) to group nearby hotspots - grid_lat = round(lat, 2) - grid_lon = round(lon, 2) - fire_id = f"FIRMS-{acq_date.replace('-', '')}-{abs(int(grid_lat * 100))}-{abs(int(grid_lon * 100))}" - - return { - "fire_id": fire_id, - "province": province, - "name": f"Satellite Hotspot ({province}) #{idx + 1}", - "status": "out_of_control", # FIRMS detects active burning - "severity": severity, - "latitude": lat, - "longitude": lon, - "area_hectares": None, # FIRMS doesn't provide area - "frp_mw": frp, # Fire Radiative Power in megawatts - "confidence": row.get("confidence", "n"), - "satellite": row.get("satellite", "N20"), - "started_at": detected_at, - "updated_at": datetime.now(UTC).isoformat(), - "source": "NASA_FIRMS_VIIRS", - } - except (ValueError, KeyError) as e: - logger.warning(f"Skipping malformed FIRMS row: {e}") - return None - - -def fetch_firms_hotspots( - day_range: int = DEFAULT_DAY_RANGE, - bbox: str = CANADA_WEST_BBOX, - min_confidence: str = "n", # "l"=low, "n"=nominal, "h"=high — filter low quality -) -> list[dict]: - """ - Fetch active fire hotspots from NASA FIRMS VIIRS over Western Canada. - - Args: - day_range: Number of days to look back (1-10). Default 1 = last 24h. - bbox: Bounding box "W,S,E,N". Default covers BC + AB. - min_confidence: Minimum confidence level to include ('n' = nominal or higher). - - Returns: - List of normalized FireEvent dicts. - - Rate limit: 1 API call. FIRMS allows 5000 calls/10min — this is very safe. 
- """ - api_key = settings.NASA_FIRMS_API_KEY - if not api_key or api_key in ("dummy_key", ""): - logger.error("NASA_FIRMS_API_KEY is not set. Cannot fetch real fire data.") - return [] - - url = f"{FIRMS_BASE_URL}/{api_key}/{FIRMS_SOURCE}/{bbox}/{day_range}" - logger.info(f"Fetching FIRMS data: {url}") - - try: - # Single request — well within rate limits - with httpx.Client(timeout=15) as client: - resp = client.get(url) - resp.raise_for_status() - except httpx.TimeoutException: - logger.error("FIRMS API timed out.") - return [] - except httpx.HTTPStatusError as e: - logger.error(f"FIRMS API HTTP error: {e.response.status_code} — {e.response.text[:200]}") - return [] - except httpx.RequestError as e: - logger.error(f"FIRMS API request failed: {e}") - return [] - - # Parse CSV response - reader = csv.DictReader(io.StringIO(resp.text)) - hotspots = [] - for idx, row in enumerate(reader): - # Filter to nominal/high confidence only to reduce noise - confidence = row.get("confidence", "n").lower() - if min_confidence == "n" and confidence == "l": - continue - if min_confidence == "h" and confidence != "h": - continue - - normalized = _normalize_hotspot(row, idx) - if normalized: - hotspots.append(normalized) - - logger.info(f"FIRMS: fetched {len(hotspots)} hotspots over bbox {bbox}") - return hotspots - - -def get_firms_fires() -> list[dict]: - """ - Public interface: fetch and return all VIIRS hotspots over Western Canada. - Used by the fires API endpoint when USE_DUMMY_DATA=False. 
- """ - return fetch_firms_hotspots(day_range=DEFAULT_DAY_RANGE) - - -# ── Manual test ────────────────────────────────────────────────────────────────── -if __name__ == "__main__": - import json - logging.basicConfig(level=logging.INFO) - print("Fetching NASA FIRMS hotspots over Western Canada...") - fires = get_firms_fires() - print(f"\nGot {len(fires)} hotspots.\n") - if fires: - print("Sample hotspot:") - print(json.dumps(fires[0], indent=2)) - else: - print("No hotspots found — might be no active fires right now, or check your API key.") diff --git a/src/ingestion/static_dataset.py b/src/ingestion/static_dataset.py new file mode 100644 index 0000000..bcccb0e --- /dev/null +++ b/src/ingestion/static_dataset.py @@ -0,0 +1,754 @@ +""" +static_dataset.py - Build frozen benchmark datasets from historical wildfire data. + +Canonical path: +1. load Alberta historical wildfire incidents from `data/static/` +2. normalize them into snapshot records +3. optionally enrich with CFFDRS fire-danger fields +4. compute environment-variable records for FireEnv + +Run once, store the outputs, and train/evaluate only from the cached files. 
+ +Example: + uv run python -m src.ingestion.static_dataset --target-count 100 +""" + +from __future__ import annotations + +import argparse +import csv +import json +import logging +from collections import Counter +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path + +from src.ingestion.clean_historical import clean_raw_historical_row_with_reason + +try: + from tqdm import tqdm +except Exception: # pragma: no cover - optional dependency + + def tqdm(iterable, **_kwargs): + return iterable + + +logger = logging.getLogger(__name__) + +DEFAULT_OUTPUT_DIR = Path("data/static") +DEFAULT_ALBERTA_CSV = DEFAULT_OUTPUT_DIR / "fp-historical-wildfire-data-2006-2025.csv" + +WIND_DIR_TO_DEG = { + "N": 0.0, + "NNE": 22.5, + "NE": 45.0, + "ENE": 67.5, + "E": 90.0, + "ESE": 112.5, + "SE": 135.0, + "SSE": 157.5, + "S": 180.0, + "SSW": 202.5, + "SW": 225.0, + "WSW": 247.5, + "W": 270.0, + "WNW": 292.5, + "NW": 315.0, + "NNW": 337.5, +} + +FIRE_TYPE_FACTOR = { + "ground": 0.8, + "surface": 1.0, + "crown": 1.18, +} + + +@dataclass +class SnapshotBuildResult: + snapshots: list[dict] + parameter_records: list[dict] + output_dir: Path + + +def _clamp(value: float, low: float, high: float) -> float: + return max(low, min(high, value)) + + +def _norm(value: float, low: float, high: float) -> float: + if high <= low: + return 0.0 + return _clamp((value - low) / (high - low), 0.0, 1.0) + + +def _clean_str(value: object) -> str | None: + if value is None: + return None + text = str(value).strip() + return text or None + + +def _parse_float(value: object) -> float | None: + text = _clean_str(value) + if text is None: + return None + try: + return float(text) + except ValueError: + return None + + +def _parse_datetime(value: object) -> datetime | None: + text = _clean_str(value) + if text is None: + return None + try: + if "T" in text: + return datetime.fromisoformat(text.replace("Z", "+00:00")) + except ValueError: + pass + for fmt in ("%Y-%m-%d 
%H:%M", "%Y-%m-%d %H:%M:%S", "%Y-%m-%d"):
+        try:
+            return datetime.strptime(text, fmt).replace(tzinfo=UTC)
+        except ValueError:
+            continue
+    return None
+
+
+def _to_iso(value: datetime | None) -> str | None:
+    return value.isoformat() if value is not None else None
+
+
+def _parse_wind_direction(value: object) -> float | None:
+    text = _clean_str(value)
+    if text is None:
+        return None
+    try:
+        return float(text)
+    except ValueError:
+        return WIND_DIR_TO_DEG.get(text.upper())
+
+
+def _estimate_precipitation_mm(condition: str | None) -> float:
+    if condition is None:
+        return 0.0
+    normalized = condition.strip().lower()
+    if normalized == "rain showers":
+        return 2.0
+    if normalized == "cb wet":
+        return 1.0
+    return 0.0
+
+
+def _fuel_type_factor(fuel_type: str | None) -> float:
+    if not fuel_type:
+        return 1.0
+    fuel = fuel_type.strip().upper()
+    # Conifer (C-*) and slash (S-*) fuel codes get the fastest spread factor.
+    if fuel.startswith("C") or fuel.startswith("S"):
+        return 1.12
+    if fuel.startswith("M"):
+        return 1.06
+    # Match the more specific O-1B subtype before the generic O- prefix.
+    if fuel.startswith("O-1B"):
+        return 1.08
+    if fuel.startswith("O-"):
+        return 1.03
+    if fuel.startswith("D-"):
+        return 0.92
+    return 1.0
+
+
+def _canonical_record_id(fire: dict) -> str:
+    fire_id = str(fire.get("fire_id", "unknown"))
+    anchor = str(
+        fire.get("snapshot_date") or fire.get("updated_at") or fire.get("started_at") or "unknown"
+    )
+    safe_time = anchor.replace(":", "").replace("-", "").replace("+", "_")
+    return f"{fire_id}__{safe_time}"
+
+
+def split_for_year(year: int | None) -> str | None:
+    if year is None:
+        return None
+    if 2006 <= year <= 2022:
+        return "train"
+    if year == 2023:
+        return "val"
+    if 2024 <= year <= 2025:
+        return "holdout"
+    return None
+
+
+def _dedupe_fires(fires: list[dict]) -> list[dict]:
+    seen_ids: set[str] = set()
+    unique: list[dict] = []
+    for fire in fires:
+        fire_id = str(fire.get("fire_id", ""))
+        if not fire_id or fire_id in seen_ids:
+            continue
+        seen_ids.add(fire_id)
+        unique.append(fire)
+    return unique
+
+
+def _fire_priority(fire: dict) -> tuple[float, float, float, str]:
+    spread = float(fire.get("observed_spread_rate_m_min") or 0.0)
+    size = float(fire.get("assessment_hectares") or fire.get("area_hectares") or 0.0)
+    year = float(fire.get("year") or 0.0)
+    fire_id = str(fire.get("fire_id", ""))
+    return (spread, size, year, fire_id)
+
+
+def _load_fire_records(path: Path) -> list[dict]:
+    payload = json.loads(path.read_text())
+    records = payload.get("records", []) if isinstance(payload, dict) else payload
+    return [record for record in records if isinstance(record, dict)]
+
+
+def _normalize_alberta_row(row: dict) -> dict | None:
+    cleaned, _reason = clean_raw_historical_row_with_reason(row)
+    if cleaned is None:
+        return None
+
+    year = _clean_str(cleaned.get("YEAR"))
+    fire_number = _clean_str(cleaned.get("FIRE_NUMBER"))
+    lat = _parse_float(cleaned.get("LATITUDE"))
+    lon = _parse_float(cleaned.get("LONGITUDE"))
+    assessment_dt = _parse_datetime(cleaned.get("ASSESSMENT_DATETIME"))
+    assessment_hectares = _parse_float(cleaned.get("ASSESSMENT_HECTARES"))
+    current_size = _parse_float(cleaned.get("CURRENT_SIZE"))
+    spread_rate = _parse_float(cleaned.get("FIRE_SPREAD_RATE"))
+    temp_c = _parse_float(cleaned.get("TEMPERATURE"))
+    rh_pct = _parse_float(cleaned.get("RELATIVE_HUMIDITY"))
+    wind_dir_deg = _parse_wind_direction(cleaned.get("WIND_DIRECTION"))
+    wind_speed = _parse_float(cleaned.get("WIND_SPEED"))
+
+    if not all([year, fire_number]) or lat is None or lon is None or assessment_dt is None:
+        return None
+
+    area_hectares = assessment_hectares if assessment_hectares not in (None, 0.0) else current_size
+    if area_hectares is None or spread_rate is None or temp_c is None or rh_pct is None:
+        return None
+    if wind_dir_deg is None or wind_speed is None:
+        return None
+
+    started_at = _parse_datetime(cleaned.get("FIRE_START_DATE"))
+    discovered_at = _parse_datetime(cleaned.get("DISCOVERED_DATE"))
+    reported_at =
_parse_datetime(cleaned.get("REPORTED_DATE")) + dispatch_at = _parse_datetime(cleaned.get("DISPATCH_DATE")) + arrival_at = _parse_datetime(cleaned.get("IA_ARRIVAL_AT_FIRE_DATE")) + firefighting_start = _parse_datetime(cleaned.get("FIRE_FIGHTING_START_DATE")) + + fire_id = f"AB-{year}-{fire_number}" + fire_name = _clean_str(cleaned.get("FIRE_NAME")) or fire_id + fire_type = (_clean_str(cleaned.get("FIRE_TYPE")) or "Surface").strip() + fuel_type = _clean_str(cleaned.get("FUEL_TYPE")) + weather_over_fire = _clean_str(cleaned.get("WEATHER_CONDITIONS_OVER_FIRE")) + year_int = int(year) + split = split_for_year(year_int) + if split is None: + return None + + return { + "record_id": fire_id, + "fire_id": fire_id, + "year": year_int, + "split": split, + "province": "AB", + "name": fire_name, + "source": "AB_HISTORICAL_WILDFIRE", + "status": "historical", + "snapshot_date": assessment_dt.date().isoformat(), + "snapshot_datetime": _to_iso(assessment_dt), + "started_at": _to_iso(started_at), + "updated_at": _to_iso(assessment_dt), + "latitude": lat, + "longitude": lon, + "area_hectares": float(area_hectares), + "assessment_hectares": assessment_hectares, + "current_size": current_size, + "size_class": _clean_str(cleaned.get("SIZE_CLASS")), + "observed_spread_rate_m_min": spread_rate, + "temperature_c": temp_c, + "relative_humidity_pct": rh_pct, + "wind_direction_deg": wind_dir_deg, + "wind_speed_km_h": wind_speed, + "precipitation_mm": _estimate_precipitation_mm(weather_over_fire), + "fire_type": fire_type.lower(), + "fuel_type": fuel_type, + "weather_conditions_over_fire": weather_over_fire, + "fire_position_on_slope": _clean_str(cleaned.get("FIRE_POSITION_ON_SLOPE")), + "fire_origin": _clean_str(cleaned.get("FIRE_ORIGIN")), + "general_cause": _clean_str(cleaned.get("GENERAL_CAUSE")), + "activity_class": _clean_str(cleaned.get("ACTIVITY_CLASS")), + "true_cause": _clean_str(cleaned.get("TRUE_CAUSE")), + "discovered_date": _to_iso(discovered_at), + "reported_date": 
_to_iso(reported_at), + "dispatch_date": _to_iso(dispatch_at), + "ia_arrival_at_fire_date": _to_iso(arrival_at), + "fire_fighting_start_date": _to_iso(firefighting_start), + "discovered_size": _parse_float(cleaned.get("DISCOVERED_SIZE")), + "fire_fighting_start_size": _parse_float(cleaned.get("FIRE_FIGHTING_START_SIZE")), + "initial_action_by": _clean_str(cleaned.get("INITIAL_ACTION_BY")), + "ia_access": _clean_str(cleaned.get("IA_ACCESS")), + "bucketing_on_fire": _clean_str(cleaned.get("BUCKETING_ON_FIRE")), + "distance_from_water_source": _parse_float(cleaned.get("DISTANCE_FROM_WATER_SOURCE")), + } + + +def load_alberta_historical_fires(csv_path: Path) -> list[dict]: + if not csv_path.exists(): + msg = f"Alberta historical wildfire CSV not found: {csv_path}" + raise FileNotFoundError(msg) + + fires: list[dict] = [] + drop_reasons: Counter[str] = Counter() + yearly_total: Counter[int] = Counter() + yearly_kept: Counter[int] = Counter() + raw_rows = 0 + with csv_path.open(newline="", encoding="utf-8-sig") as handle: + reader = csv.DictReader(handle) + for row in tqdm(reader, desc="Cleaning historical rows", unit="row"): + raw_rows += 1 + year_raw = _clean_str(row.get("YEAR")) + if year_raw and year_raw.isdigit(): + yearly_total[int(year_raw)] += 1 + + cleaned, reason = clean_raw_historical_row_with_reason(row) + if cleaned is None: + drop_reasons[reason or "cleaning_failed"] += 1 + continue + + normalized = _normalize_alberta_row(cleaned) + if normalized is not None: + fires.append(normalized) + yearly_kept[int(normalized["year"])] += 1 + else: + drop_reasons["normalization_failed"] += 1 + logger.info("Loaded %s Alberta historical wildfire incidents", len(fires)) + logger.info( + "Historical input rows: %s | kept: %s | dropped: %s", + raw_rows, + len(fires), + raw_rows - len(fires), + ) + if drop_reasons: + for reason, count in drop_reasons.most_common(10): + logger.info("Dropped %s rows due to %s", count, reason) + for year in sorted(yearly_total): + logger.info( 
+ "Year %s: kept %s / %s", + year, + yearly_kept.get(year, 0), + yearly_total[year], + ) + return fires + + +def collect_candidate_fires( + fire_records_path: Path | None = None, + raw_alberta_csv: Path | None = None, +) -> list[dict]: + """Collect and prioritize historical fire records for snapshot export.""" + if fire_records_path is not None: + fires = _load_fire_records(fire_records_path) + else: + fires = load_alberta_historical_fires(raw_alberta_csv or DEFAULT_ALBERTA_CSV) + + unique = _dedupe_fires(fires) + unique.sort(key=_fire_priority, reverse=True) + return unique + + +def _hours_between(start: str | None, end: str | None) -> float | None: + start_dt = _parse_datetime(start) + end_dt = _parse_datetime(end) + if start_dt is None or end_dt is None: + return None + return round((end_dt - start_dt).total_seconds() / 3600.0, 2) + + +def build_snapshot_record(fire: dict, *, stations: list[dict] | None = None) -> dict | None: + """Convert one historical fire record into a snapshot record.""" + from src.ingestion.cffdrs import get_cffdrs_for_location + + lat = fire.get("latitude") + lon = fire.get("longitude") + if lat is None or lon is None: + return None + + snapshot_date = fire.get("snapshot_date") + snapshot_dt = _parse_datetime(fire.get("snapshot_datetime")) + snapshot_day = snapshot_dt.date() if snapshot_dt is not None else None + + cffdrs = None + if stations: + cffdrs = get_cffdrs_for_location( + float(lat), + float(lon), + stations=stations, + target_date=snapshot_day, + max_date_offset_days=1, + ) + + record = { + "record_id": _canonical_record_id(fire), + "fire_id": fire.get("fire_id"), + "source": fire.get("source"), + "province": fire.get("province", "AB"), + "year": fire.get("year"), + "split": fire.get("split"), + "name": fire.get("name"), + "status": fire.get("status"), + "snapshot_date": snapshot_date, + "snapshot_datetime": fire.get("snapshot_datetime"), + "latitude": float(lat), + "longitude": float(lon), + "area_hectares": 
float(fire["area_hectares"]), + "assessment_hectares": fire.get("assessment_hectares"), + "current_size": fire.get("current_size"), + "size_class": fire.get("size_class"), + "started_at": fire.get("started_at"), + "updated_at": fire.get("updated_at"), + "wind_speed_km_h": float(fire["wind_speed_km_h"]), + "wind_direction_deg": float(fire["wind_direction_deg"]), + "temperature_c": float(fire["temperature_c"]), + "relative_humidity_pct": float(fire["relative_humidity_pct"]), + "precipitation_mm": float(fire.get("precipitation_mm") or 0.0), + "observed_spread_rate_m_min": float(fire["observed_spread_rate_m_min"]), + "fire_type": fire.get("fire_type"), + "fuel_type": fire.get("fuel_type"), + "weather_conditions_over_fire": fire.get("weather_conditions_over_fire"), + "fire_position_on_slope": fire.get("fire_position_on_slope"), + "fire_origin": fire.get("fire_origin"), + "general_cause": fire.get("general_cause"), + "activity_class": fire.get("activity_class"), + "true_cause": fire.get("true_cause"), + "discovered_date": fire.get("discovered_date"), + "reported_date": fire.get("reported_date"), + "dispatch_date": fire.get("dispatch_date"), + "ia_arrival_at_fire_date": fire.get("ia_arrival_at_fire_date"), + "fire_fighting_start_date": fire.get("fire_fighting_start_date"), + "discovered_size": fire.get("discovered_size"), + "fire_fighting_start_size": fire.get("fire_fighting_start_size"), + "initial_action_by": fire.get("initial_action_by"), + "ia_access": fire.get("ia_access"), + "bucketing_on_fire": fire.get("bucketing_on_fire"), + "distance_from_water_source": fire.get("distance_from_water_source"), + "detection_delay_h": _hours_between(fire.get("started_at"), fire.get("discovered_date")), + "report_delay_h": _hours_between(fire.get("discovered_date"), fire.get("reported_date")), + "dispatch_delay_h": _hours_between(fire.get("reported_date"), fire.get("dispatch_date")), + "ia_travel_delay_h": _hours_between( + fire.get("dispatch_date"), 
fire.get("ia_arrival_at_fire_date") + ), + "record_quality_flag": "measured", + "snapshot_generated_at": datetime.now(UTC).isoformat(), + } + + if cffdrs is not None: + record.update( + { + "fwi": cffdrs.get("fwi"), + "isi": cffdrs.get("isi"), + "bui": cffdrs.get("bui"), + "dc": cffdrs.get("dc"), + "dmc": cffdrs.get("dmc"), + "ffmc": cffdrs.get("ffmc"), + "cffdrs_station_distance_km": cffdrs.get("distance_km"), + "cffdrs_station_id": cffdrs.get("source_station_id"), + "cffdrs_station_name": cffdrs.get("source_station"), + "cffdrs_observation_date": cffdrs.get("observation_date"), + "cffdrs_date_offset_days": cffdrs.get("date_offset_days"), + "temporal_alignment_status": "aligned" + if cffdrs.get("date_offset_days", 0) == 0 + else "near_aligned", + } + ) + else: + record.update( + { + "fwi": None, + "isi": None, + "bui": None, + "dc": None, + "dmc": None, + "ffmc": None, + "cffdrs_station_distance_km": None, + "cffdrs_station_id": None, + "cffdrs_station_name": None, + "cffdrs_observation_date": None, + "cffdrs_date_offset_days": None, + "temporal_alignment_status": "not_joined", + } + ) + + required_fields = ( + "wind_speed_km_h", + "wind_direction_deg", + "temperature_c", + "relative_humidity_pct", + "area_hectares", + "observed_spread_rate_m_min", + ) + if any(record.get(field) is None for field in required_fields): + return None + return record + + +def compute_environment_parameters(snapshot: dict) -> dict: + """Map one snapshot record into deterministic FireEnv parameter fields.""" + observed_spread = float(snapshot["observed_spread_rate_m_min"]) + wind_speed = float(snapshot["wind_speed_km_h"]) + wind_dir_deg = float(snapshot["wind_direction_deg"]) + temp_c = float(snapshot["temperature_c"]) + rh_pct = float(snapshot["relative_humidity_pct"]) + precip_mm = float(snapshot.get("precipitation_mm") or 0.0) + area_hectares = float(snapshot["area_hectares"]) + fire_type = str(snapshot.get("fire_type") or "surface").lower() + fuel_type = snapshot.get("fuel_type") + 
+    spread_norm = _norm(observed_spread, 0.0, 25.0)
+    wind_norm = _norm(wind_speed, 0.0, 40.0)
+    temp_norm = _norm(temp_c, 0.0, 35.0)
+    rh_norm = _norm(rh_pct, 10.0, 95.0)
+    rain_norm = _norm(precip_mm, 0.0, 5.0)
+    size_norm = _norm(area_hectares, 0.0, 2000.0)
+
+    cffdrs_terms = [
+        snapshot.get("isi"),
+        snapshot.get("fwi"),
+        snapshot.get("bui"),
+        snapshot.get("ffmc"),
+    ]
+    cffdrs_present = any(value is not None for value in cffdrs_terms)
+    cffdrs_dryness = 0.0
+    if cffdrs_present:
+        isi_norm = _norm(float(snapshot.get("isi") or 0.0), 0.0, 25.0)
+        fwi_norm = _norm(float(snapshot.get("fwi") or 0.0), 0.0, 40.0)
+        bui_norm = _norm(float(snapshot.get("bui") or 0.0), 0.0, 120.0)
+        ffmc_norm = _norm(float(snapshot.get("ffmc") or 85.0), 70.0, 96.0)
+        cffdrs_dryness = _clamp(
+            0.4 * isi_norm + 0.25 * fwi_norm + 0.15 * bui_norm + 0.2 * ffmc_norm,
+            0.0,
+            1.0,
+        )
+
+    weather_score = _clamp(
+        0.45 * wind_norm + 0.2 * temp_norm + 0.35 * (1.0 - rh_norm),
+        0.0,
+        1.0,
+    )
+    rain_factor = 1.0 - 0.5 * rain_norm
+    size_factor = 0.95 + 0.15 * size_norm
+    fire_type_factor = FIRE_TYPE_FACTOR.get(fire_type, 1.0)
+    fuel_factor = _fuel_type_factor(fuel_type)
+
+    spread_score = _clamp(
+        (0.6 * spread_norm + 0.2 * weather_score + 0.1 * size_norm + 0.1 * cffdrs_dryness)
+        * rain_factor
+        * fire_type_factor
+        * fuel_factor,
+        0.0,
+        1.0,
+    )
+
+    base_spread_prob = round(_clamp(0.04 + 0.18 * spread_score, 0.04, 0.22), 4)
+    wind_strength = round(_clamp(0.1 + 0.5 * wind_norm, 0.1, 0.6), 4)
+    spread_rate_1h_m = round(observed_spread * 60.0, 1)
+
+    if spread_score < 0.33:
+        severity_bucket = "low"
+    elif spread_score < 0.66:
+        severity_bucket = "medium"
+    else:
+        severity_bucket = "high"
+
+    return {
+        "record_id": snapshot["record_id"],
+        "fire_id": snapshot.get("fire_id"),
+        "source": snapshot.get("source"),
+        "province": snapshot.get("province"),
+        "year": snapshot.get("year"),
+        "split": snapshot.get("split"),
+        "base_spread_prob": base_spread_prob,
+        "severity_bucket":
severity_bucket, + "wind_dir_deg": round(wind_dir_deg, 2), + "wind_strength": wind_strength, + "spread_rate_1h_m": spread_rate_1h_m, + "spread_score": round(spread_score, 4), + "weather_score": round(weather_score, 4), + "cffdrs_dryness_score": round(cffdrs_dryness, 4), + "size_factor": round(size_factor, 4), + "fire_type_factor": round(fire_type_factor, 4), + "fuel_factor": round(fuel_factor, 4), + "rain_factor": round(rain_factor, 4), + "observed_spread_rate_m_min": observed_spread, + "assessment_hectares": snapshot.get("assessment_hectares"), + "fire_type": snapshot.get("fire_type"), + "fuel_type": snapshot.get("fuel_type"), + "record_quality_flag": snapshot.get("record_quality_flag", "measured"), + } + + +def build_static_datasets( + *, + target_count: int = 100, + output_dir: Path | None = None, + cffdrs_year: int | None = None, + fire_records_path: Path | None = None, + raw_alberta_csv: Path | None = None, +) -> SnapshotBuildResult: + """Run the one-time pipeline and write frozen benchmark artifacts.""" + output_dir = output_dir or DEFAULT_OUTPUT_DIR + output_dir.mkdir(parents=True, exist_ok=True) + + stations: list[dict] | None = None + if cffdrs_year is not None: + from src.ingestion.cffdrs import fetch_cffdrs_stations + + stations = fetch_cffdrs_stations(year=cffdrs_year) + if not stations: + logger.warning( + "CFFDRS station download failed for year %s; continuing without supplementary CFFDRS enrichment.", + cffdrs_year, + ) + + candidates = collect_candidate_fires( + fire_records_path=fire_records_path, + raw_alberta_csv=raw_alberta_csv, + ) + snapshots: list[dict] = [] + parameter_records: list[dict] = [] + split_counts = {"train": 0, "val": 0, "holdout": 0} + + for fire in tqdm(candidates, desc="Building snapshots", unit="record"): + split_name = fire.get("split") + if split_name not in split_counts: + continue + if split_counts[split_name] >= target_count: + continue + snapshot = build_snapshot_record(fire, stations=stations) + if snapshot is None: + 
continue + params = compute_environment_parameters(snapshot) + snapshots.append(snapshot) + parameter_records.append(params) + split_counts[split_name] += 1 + if all(count >= target_count for count in split_counts.values()): + break + + snapshot_payload = { + "schema_version": 2, + "generated_at": datetime.now(UTC).isoformat(), + "record_count": len(snapshots), + "records": snapshots, + } + params_payload = { + "schema_version": 2, + "generated_at": datetime.now(UTC).isoformat(), + "record_count": len(parameter_records), + "records": parameter_records, + } + + snapshot_path = output_dir / "snapshot_records.json" + params_path = output_dir / "scenario_parameter_records.json" + snapshot_path.write_text(json.dumps(snapshot_payload, indent=2)) + params_path.write_text(json.dumps(params_payload, indent=2)) + + split_names = ("train", "val", "holdout") + for split_name in split_names: + split_snapshots = [record for record in snapshots if record.get("split") == split_name] + split_params = [record for record in parameter_records if record.get("split") == split_name] + (output_dir / f"snapshot_records_{split_name}.json").write_text( + json.dumps( + { + "schema_version": 2, + "generated_at": datetime.now(UTC).isoformat(), + "split": split_name, + "record_count": len(split_snapshots), + "records": split_snapshots, + }, + indent=2, + ) + ) + (output_dir / f"scenario_parameter_records_{split_name}.json").write_text( + json.dumps( + { + "schema_version": 2, + "generated_at": datetime.now(UTC).isoformat(), + "split": split_name, + "record_count": len(split_params), + "records": split_params, + }, + indent=2, + ) + ) + + logger.info("Wrote %s snapshot records to %s", len(snapshots), snapshot_path) + logger.info("Wrote %s scenario parameter records to %s", len(parameter_records), params_path) + for split_name in split_names: + logger.info( + "Split %s: %s records", + split_name, + sum(1 for record in parameter_records if record.get("split") == split_name), + ) + if not 
parameter_records: + logger.warning( + "No scenario parameter records were built. Check whether the Alberta historical file has valid assessment fields or whether your optional CFFDRS join is too sparse." + ) + return SnapshotBuildResult( + snapshots=snapshots, parameter_records=parameter_records, output_dir=output_dir + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Build frozen wildfire benchmark datasets") + parser.add_argument( + "--target-count", type=int, default=100, help="Target number of records to export per split" + ) + parser.add_argument( + "--output-dir", + type=Path, + default=DEFAULT_OUTPUT_DIR, + help="Directory for snapshot and parameter JSON files", + ) + parser.add_argument( + "--cffdrs-year", + type=int, + default=None, + help="Optional CFFDRS observation year for supplementary danger-index enrichment", + ) + parser.add_argument( + "--fire-records", + type=Path, + default=None, + help="Optional JSON file of normalized fire records to use instead of the Alberta historical CSV", + ) + parser.add_argument( + "--raw-alberta-csv", + type=Path, + default=DEFAULT_ALBERTA_CSV, + help="Path to the raw Alberta historical wildfire CSV", + ) + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO) + result = build_static_datasets( + target_count=args.target_count, + output_dir=args.output_dir, + cffdrs_year=args.cffdrs_year, + fire_records_path=args.fire_records, + raw_alberta_csv=args.raw_alberta_csv, + ) + print( + f"Built {len(result.parameter_records)} scenario parameter records in {result.output_dir}" + ) + + +if __name__ == "__main__": + main() diff --git a/src/ingestion/weather.py b/src/ingestion/weather.py index f9474b9..aaaaa73 100644 --- a/src/ingestion/weather.py +++ b/src/ingestion/weather.py @@ -4,7 +4,7 @@ No API key required. Open-Meteo is a free, open-source weather API. 
Given a fire's (latitude, longitude), this module returns the weather -variables needed as features for the XGBoost spread model: +variables used to build frozen snapshot records and offline environment variables: - wind_speed_km_h - wind_direction_deg - temperature_c @@ -29,7 +29,7 @@ # Open-Meteo current-conditions endpoint (no key needed) OPEN_METEO_URL = "https://api.open-meteo.com/v1/forecast" -# Variables we need for the ML feature vector +# Variables needed for the snapshot builder WEATHER_VARIABLES = [ "temperature_2m", "relative_humidity_2m", @@ -149,10 +149,30 @@ def get_weather_for_fires(fires: list[dict]) -> dict[str, dict]: # Test fires test_fires = [ - {"fire_id": "BC-2026-001", "name": "Okanagan Ridge Fire", "latitude": 49.9071, "longitude": -119.496}, - {"fire_id": "BC-2026-002", "name": "Kamloops Plateau Fire", "latitude": 50.6745, "longitude": -120.3273}, - {"fire_id": "BC-2026-003", "name": "Fraser Valley Approach","latitude": 49.3845, "longitude": -121.4483}, - {"fire_id": "AB-2026-001", "name": "Peace River Complex", "latitude": 56.2370, "longitude": -117.2900}, + { + "fire_id": "BC-2026-001", + "name": "Okanagan Ridge Fire", + "latitude": 49.9071, + "longitude": -119.496, + }, + { + "fire_id": "BC-2026-002", + "name": "Kamloops Plateau Fire", + "latitude": 50.6745, + "longitude": -120.3273, + }, + { + "fire_id": "BC-2026-003", + "name": "Fraser Valley Approach", + "latitude": 49.3845, + "longitude": -121.4483, + }, + { + "fire_id": "AB-2026-001", + "name": "Peace River Complex", + "latitude": 56.2370, + "longitude": -117.2900, + }, ] print("Fetching fire weather from Open-Meteo...\n") diff --git a/src/models/evaluate_agents.py b/src/models/evaluate_agents.py new file mode 100644 index 0000000..b55076d --- /dev/null +++ b/src/models/evaluate_agents.py @@ -0,0 +1,255 @@ +"""General benchmark evaluation interface for RL agents on split datasets.""" + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + 
+import numpy as np + +from src.models.fire_env import ( + ASSET_BURNED, + BURNED, + BURNING, + DEPLOY_CREW, + DEPLOY_HELICOPTER, + MOVE_E, + MOVE_N, + MOVE_S, + MOVE_W, + WildfireEnv, + load_scenario_parameter_records, +) + +try: + from tqdm import tqdm +except Exception: # pragma: no cover - optional dependency + + def tqdm(iterable, **_kwargs): + return iterable + + +DEFAULT_TRAIN_DATASET = Path("data/static/scenario_parameter_records_train.json") +DEFAULT_VAL_DATASET = Path("data/static/scenario_parameter_records_val.json") +DEFAULT_HOLDOUT_DATASET = Path("data/static/scenario_parameter_records_holdout.json") +DEFAULT_PPO_MODEL = Path("src/models/tactical_ppo_agent.zip") + + +def _load_ppo_model(path: Path): + from stable_baselines3 import PPO + + if not path.exists(): + raise FileNotFoundError(f"PPO model not found at {path}") + return PPO.load(str(path)) + + +def _nearest_burning_cell(env: WildfireEnv) -> tuple[int, int] | None: + burning_positions = np.argwhere(env.grid == BURNING) + if burning_positions.size == 0: + return None + ar, ac = env.agent_pos + dists = np.abs(burning_positions[:, 0] - ar) + np.abs(burning_positions[:, 1] - ac) + idx = int(np.argmin(dists)) + return int(burning_positions[idx, 0]), int(burning_positions[idx, 1]) + + +def _greedy_action(env: WildfireEnv) -> int: + ar, ac = env.agent_pos + + if env.heli_left > 0 and env.heli_cd == 0: + for dr in range(-1, 2): + for dc in range(-1, 2): + rr, cc = ar + dr, ac + dc + if ( + 0 <= rr < env.grid_size + and 0 <= cc < env.grid_size + and env.grid[rr, cc] == BURNING + ): + return DEPLOY_HELICOPTER + + if env.crew_left > 0 and env.crew_cd == 0 and env.grid[ar, ac] == BURNING: + return DEPLOY_CREW + + target = _nearest_burning_cell(env) + if target is None: + return MOVE_N + + tr, tc = target + if tr < ar: + return MOVE_N + if tr > ar: + return MOVE_S + if tc > ac: + return MOVE_E + if tc < ac: + return MOVE_W + return DEPLOY_CREW if env.crew_left > 0 and env.crew_cd == 0 else MOVE_N + + +def 
_run_episode(env: WildfireEnv, agent_name: str, model, seed: int) -> dict: + obs, _info = env.reset(seed=seed) + episode_return = 0.0 + terminated = False + truncated = False + info = {} + + for _ in range(env.max_steps): + if agent_name == "random": + action = int(env.action_space.sample()) + elif agent_name == "greedy": + action = _greedy_action(env) + else: + action, _ = model.predict(obs, deterministic=True) + action = int(action) + + obs, reward, terminated, truncated, info = env.step(action) + episode_return += float(reward) + if terminated or truncated: + break + + final_burned_area = int( + np.sum((env.grid == BURNED) | (env.grid == BURNING) | (env.grid == ASSET_BURNED)) + ) + containment_success = 1 if terminated and not truncated else 0 + heli_used = env.heli_budget_init - info.get("heli_left", env.heli_left) + crew_used = env.crew_budget_init - info.get("crew_left", env.crew_left) + + return { + "return": episode_return, + "assets_lost": int(info.get("assets_lost", env.assets_lost)), + "containment_success": containment_success, + "final_burned_area": final_burned_area, + "time_to_containment": int(info.get("step", env.step_count)), + "heli_used": int(heli_used), + "crew_used": int(crew_used), + "resource_efficiency": float(final_burned_area / max(1, heli_used + crew_used)), + } + + +def _evaluate_agent_on_split( + *, + agent_name: str, + records: list[dict], + seeds: list[int], + episodes_per_seed: int, + model, + compute_normalized_burn_ratio: bool, +) -> dict: + episode_metrics = [] + + for seed in seeds: + env = WildfireEnv(scenario_parameter_records=records, randomize_scenario=True) + baseline_env = WildfireEnv(scenario_parameter_records=records, randomize_scenario=True) + iterator = tqdm(range(episodes_per_seed), desc=f"{agent_name} seed={seed}", unit="ep") + for ep in iterator: + eval_seed = seed * 10_000 + ep + metrics = _run_episode(env, agent_name, model, seed=eval_seed) + if compute_normalized_burn_ratio: + # Use MOVE_N-only as deterministic 
no-action surrogate baseline. + _obs, _ = baseline_env.reset(seed=eval_seed) + for _ in range(baseline_env.max_steps): + _obs, _reward, done, trunc, _base_info = baseline_env.step(MOVE_N) + if done or trunc: + break + baseline_burned = int( + np.sum( + (baseline_env.grid == BURNED) + | (baseline_env.grid == BURNING) + | (baseline_env.grid == ASSET_BURNED) + ) + ) + metrics["normalized_burn_ratio"] = float( + metrics["final_burned_area"] / max(1, baseline_burned) + ) + episode_metrics.append(metrics) + + arr = { + key: np.array([m[key] for m in episode_metrics], dtype=float) for key in episode_metrics[0] + } + summary = { + "episodes": len(episode_metrics), + "mean_return": float(arr["return"].mean()), + "std_return": float(arr["return"].std()), + "asset_survival_rate": float((arr["assets_lost"] == 0).mean()), + "containment_success_rate": float(arr["containment_success"].mean()), + "mean_final_burned_area": float(arr["final_burned_area"].mean()), + "mean_time_to_containment": float(arr["time_to_containment"].mean()), + "mean_resource_efficiency": float(arr["resource_efficiency"].mean()), + "variance_across_episodes": float(arr["return"].var()), + } + if "normalized_burn_ratio" in arr: + summary["mean_normalized_burn_ratio"] = float(arr["normalized_burn_ratio"].mean()) + return summary + + +def _load_split_records(path: Path | None) -> list[dict]: + if path is None or not path.exists(): + return [] + return load_scenario_parameter_records(path) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Evaluate benchmark agents on train/val/holdout splits" + ) + parser.add_argument("--agents", type=str, default="ppo,greedy,random") + parser.add_argument("--train-dataset", type=Path, default=DEFAULT_TRAIN_DATASET) + parser.add_argument("--val-dataset", type=Path, default=DEFAULT_VAL_DATASET) + parser.add_argument("--holdout-dataset", type=Path, default=DEFAULT_HOLDOUT_DATASET) + parser.add_argument("--ppo-model", type=Path, default=DEFAULT_PPO_MODEL) 
+ parser.add_argument("--episodes", type=int, default=20, help="Episodes per seed per split") + parser.add_argument("--seeds", type=str, default="42,43,44") + parser.add_argument("--no-normalized-burn", action="store_true") + parser.add_argument("--output", type=Path, default=None) + args = parser.parse_args() + + seeds = [int(s.strip()) for s in args.seeds.split(",") if s.strip()] + agents = [a.strip().lower() for a in args.agents.split(",") if a.strip()] + + split_records = { + "train": _load_split_records(args.train_dataset), + "val": _load_split_records(args.val_dataset), + "holdout": _load_split_records(args.holdout_dataset), + } + + results: dict[str, dict] = {} + ppo_model = None + if "ppo" in agents: + ppo_model = _load_ppo_model(args.ppo_model) + + for agent_name in agents: + results[agent_name] = {} + for split_name, records in split_records.items(): + if not records: + continue + model = ppo_model if agent_name == "ppo" else None + summary = _evaluate_agent_on_split( + agent_name=agent_name, + records=records, + seeds=seeds, + episodes_per_seed=args.episodes, + model=model, + compute_normalized_burn_ratio=not args.no_normalized_burn, + ) + results[agent_name][split_name] = summary + + print("\nBenchmark Summary") + print("=" * 72) + for agent_name, split_summaries in results.items(): + for split_name, summary in split_summaries.items(): + print( + f"{agent_name:>8} | {split_name:<7} | episodes={summary['episodes']:>4} " + f"| return={summary['mean_return']:.1f} " + f"| assets_survival={summary['asset_survival_rate']:.3f} " + f"| containment={summary['containment_success_rate']:.3f} " + f"| burned={summary['mean_final_burned_area']:.1f}" + ) + + if args.output: + args.output.write_text(json.dumps(results, indent=2)) + print(f"\nSaved results to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/src/models/fire_env.py b/src/models/fire_env.py index c93a27e..3c7d0ab 100644 --- a/src/models/fire_env.py +++ b/src/models/fire_env.py @@ -21,8 
+21,10 @@ from __future__ import annotations +import json import math from dataclasses import dataclass +from pathlib import Path import gymnasium as gym import numpy as np @@ -83,6 +85,8 @@ class ScenarioConfig: asset_layout: str = "A" wind_dir_deg: float = 0.0 # 0 = wind blowing north->south wind_strength: float = 0.3 # [0, 1] + base_spread_prob: float | None = None + record_id: str | None = None def __post_init__(self): assert self.ignition in IGNITION_TYPES, f"Unknown ignition: {self.ignition}" @@ -91,6 +95,8 @@ def __post_init__(self): @property def spread_prob(self) -> float: + if self.base_spread_prob is not None: + return float(self.base_spread_prob) return SEVERITY_SPREAD_PROB[self.severity] @property @@ -126,6 +132,38 @@ def random_scenario( ) + +def load_scenario_parameter_records(path: str | Path) -> list[dict]: + """Load cached scenario parameter records from a JSON file.""" + records_path = Path(path) + payload = json.loads(records_path.read_text()) + records = payload.get("records", []) if isinstance(payload, dict) else payload + if not isinstance(records, list): + msg = f"Invalid scenario parameter dataset: {records_path}" + raise ValueError(msg) + return [record for record in records if isinstance(record, dict)] + + +def scenario_from_parameter_record( + record: dict, + *, + ignition: str, + asset_layout: str, +) -> ScenarioConfig: + """Build a ScenarioConfig from a cached parameter record.""" + severity = str(record.get("severity_bucket", "medium")).lower() + return ScenarioConfig( + ignition=ignition, + severity=severity if severity in SEVERITY_LEVELS else "medium", + asset_layout=asset_layout, + wind_dir_deg=float(record["wind_dir_deg"]) if record.get("wind_dir_deg") is not None else 0.0, + wind_strength=float(record["wind_strength"]) if record.get("wind_strength") is not None else 0.3, + base_spread_prob=float(record.get("base_spread_prob")) + if record.get("base_spread_prob") is not None + else None, + record_id=str(record.get("record_id")) if record.get("record_id") is not None else None, + ) + + # ── Environment
────────────────────────────────────────────────────────────── @@ -155,6 +193,7 @@ def __init__( crew_cooldown: int = 2, randomize_scenario: bool = True, scenario_families: list[tuple[str, str, str]] | None = None, + scenario_parameter_records: list[dict] | None = None, # Legacy compat -- ignored if scenario is provided base_spread_rate_m_per_min: float | None = None, ): @@ -168,6 +207,8 @@ def __init__( self.crew_cooldown_duration = crew_cooldown self.randomize_scenario = randomize_scenario self.scenario_families = scenario_families + self.scenario_parameter_records = scenario_parameter_records or [] + self._active_parameter_record: dict | None = None # Scenario (may be overridden each reset if randomize_scenario=True) if scenario is not None: @@ -224,7 +265,25 @@ def reset(self, seed: int | None = None, options: dict | None = None): # Optionally sample a new scenario if self.randomize_scenario: - self._scenario = random_scenario(self.np_random, self.scenario_families) + families = self.scenario_families or TRAIN_FAMILIES + ign, sev, layout = families[int(self.np_random.integers(len(families)))] + if self.scenario_parameter_records: + matching_records = [ + record + for record in self.scenario_parameter_records + if str(record.get("severity_bucket", "")).lower() == sev + ] + source_records = matching_records or self.scenario_parameter_records + record = source_records[int(self.np_random.integers(len(source_records)))] + self._active_parameter_record = record + self._scenario = scenario_from_parameter_record( + record, + ignition=ign, + asset_layout=layout, + ) + else: + self._active_parameter_record = None + self._scenario = random_scenario(self.np_random, families) # Reset budgets and cooldowns self.heli_left = self.heli_budget_init @@ -242,7 +301,10 @@ def reset(self, seed: int | None = None, options: dict | None = None): self.agent_pos = [0, 0] self._prev_burning = int(np.sum(self.grid == BURNING)) - return self._get_obs(), {"scenario": self._scenario} + return 
self._get_obs(), { + "scenario": self._scenario, + "parameter_record": self._active_parameter_record, + } def step(self, action: int): self.step_count += 1 @@ -294,6 +356,7 @@ def step(self, action: int): "heli_left": self.heli_left, "crew_left": self.crew_left, "scenario": self._scenario, + "parameter_record": self._active_parameter_record, } return self._get_obs(), reward, terminated, truncated, info diff --git a/src/models/spread_model.py b/src/models/spread_model.py deleted file mode 100644 index f593c9c..0000000 --- a/src/models/spread_model.py +++ /dev/null @@ -1,380 +0,0 @@ -""" -spread_model.py — XGBoost wildfire spread prediction model. - -Predicts fire spread radius (metres) at +1h and +3h horizons. - -TRAINING STRATEGY (Hackathon POC): - Physics-informed synthetic data using a Rothermel-inspired formula. - No historical fire spread archives needed. - -FEATURES (11 inputs) — research-informed: - wind_speed_km_h — wind magnitude - wind_u — eastward wind vector (cos decomposition — fixes cyclical issue) - wind_v — northward wind vector (sin decomposition) - temperature_c — air temperature - relative_humidity_pct - fwi — Fire Weather Index (CFFDRS) - isi — Initial Spread Index - bui — Buildup Index - area_hectares — current fire size - slope_pct — terrain slope (negative=downhill, positive=uphill) - rh_trend_24h — change in RH over last 24h (temporal context) - -NOTE on wind_u/wind_v: Tree models cannot handle cyclical features (359° and 1° -are physically adjacent but numerically 358 apart). Projecting onto U/V Cartesian -vectors eliminates this problem — the research paper confirmed this is the correct fix. 
- -Run to train + test: - uv run python -m src.models.spread_model -""" - -from __future__ import annotations - -import logging -import math -from pathlib import Path - -import joblib -import numpy as np -import pandas as pd -from sklearn.metrics import mean_absolute_error, r2_score -from sklearn.model_selection import train_test_split -from xgboost import XGBRegressor - -logger = logging.getLogger(__name__) - -MODEL_DIR = Path(__file__).parent -MODEL_1H_PATH = MODEL_DIR / "spread_1h_model.joblib" -MODEL_3H_PATH = MODEL_DIR / "spread_3h_model.joblib" - -# Updated feature set — 11 features with wind U/V, slope, and RH trend -FEATURE_COLS = [ - "wind_speed_km_h", - "wind_u", # eastward component: speed × cos(dir_rad) - "wind_v", # northward component: speed × sin(dir_rad) - "temperature_c", - "relative_humidity_pct", - "fwi", - "isi", - "bui", - "area_hectares", - "slope_pct", # terrain slope (–20 to +45 %) - "rh_trend_24h", # RH change over last 24h (negative = drying out) -] - - -# ── Physics Formula ─────────────────────────────────────────────────────────── - -def _rothermel_spread_m_per_min( - wind_speed: float, - rh: float, - isi: float, - ffmc: float, - slope: float = 0.0, -) -> float: - """ - Rothermel-inspired fire spread rate (metres per minute). - - Improvements over original version: - - wind_speed is the magnitude (|U|, |V| already resolved) - - slope_factor: uphill fires accelerate due to convective preheating - """ - ffmc_factor = max(0.1, (101 - ffmc) / 100) - wind_factor = math.exp(0.05039 * wind_speed) - rh_damping = max(0.01, 1 - (rh / 120)) - - # Uphill acceleration (Rothermel): positive slope → faster spread - # Downhill: small dampening. 
Flat: neutral (1.0) - slope_factor = 1.0 + (max(0.0, slope) / 20.0) - - base = isi * ffmc_factor * wind_factor * rh_damping * slope_factor * 2.5 - return max(0.5, base) - - -# ── Synthetic Data Generator ────────────────────────────────────────────────── - -def generate_synthetic_dataset(n_samples: int = 6000, seed: int = 42) -> pd.DataFrame: - """ - Build a physics-informed synthetic training dataset. - Ranges: BC/AB wildfire season (May–September) observations. - - Key improvements applied: - 1. wind_direction → (wind_u, wind_v) via trigonometric decomposition - 2. slope_pct feature added (terrain topography) - 3. rh_trend_24h feature added (temporal drying context) - """ - rng = np.random.default_rng(seed) - - wind_speed = rng.uniform(0, 60, n_samples) # km/h - wind_dir_deg = rng.uniform(0, 360, n_samples) # degrees - temperature = rng.uniform(5, 42, n_samples) # °C - humidity = rng.uniform(8, 85, n_samples) # % - fwi = rng.uniform(0, 100, n_samples) - isi = rng.uniform(0, 40, n_samples) - bui = rng.uniform(0, 200, n_samples) - area_ha = rng.uniform(1, 25000, n_samples) - - # [FIX 1] Wind U/V decomposition — eliminates cyclic discontinuity - wind_dir_rad = np.radians(wind_dir_deg) - wind_u = wind_speed * np.cos(wind_dir_rad) # eastward - wind_v = wind_speed * np.sin(wind_dir_rad) # northward - - # [FIX 2] Slope topography — uphill fires spread much faster - slope_pct = rng.uniform(-20, 45, n_samples) # –20 (downhill) to +45% (steep uphill) - - # [FIX 3] Temporal RH trend — drying conditions amplify danger - # Biased negative (more often drying than wetting during fire season) - rh_trend_24h = rng.normal(-5, 15, n_samples) # % RH change per 24h - - # FFMC derived from humidity + temp - ffmc = np.clip(101 - humidity * 0.7 + temperature * 0.4, 0, 101) - - # RH trend amplification: fast-drying conditions increase effective spread - # (a fire in drop-40%-RH conditions is much more dangerous) - rh_drying_factor = np.clip(1 + (-rh_trend_24h / 80), 0.8, 1.5) - - # 
Labels from Rothermel formula × drying amplification + noise - spread_1h_m = np.array([ - _rothermel_spread_m_per_min( - wind_speed[i], humidity[i], isi[i], ffmc[i], slope_pct[i] - ) * 60 * rh_drying_factor[i] - + rng.normal(0, 50) - for i in range(n_samples) - ]) - spread_3h_m = np.array([ - _rothermel_spread_m_per_min( - wind_speed[i], humidity[i], isi[i], ffmc[i], slope_pct[i] - ) * 180 * rh_drying_factor[i] - + rng.normal(0, 200) - for i in range(n_samples) - ]) - - spread_1h_m = np.clip(spread_1h_m, 50, 15000) - spread_3h_m = np.clip(spread_3h_m, 100, 50000) - - return pd.DataFrame({ - "wind_speed_km_h": wind_speed, - "wind_u": wind_u, - "wind_v": wind_v, - "temperature_c": temperature, - "relative_humidity_pct": humidity, - "fwi": fwi, - "isi": isi, - "bui": bui, - "area_hectares": area_ha, - "slope_pct": slope_pct, - "rh_trend_24h": rh_trend_24h, - "spread_1h_m": spread_1h_m, - "spread_3h_m": spread_3h_m, - }) - - -# ── Training ────────────────────────────────────────────────────────────────── - -def train_spread_model(n_samples: int = 6000) -> tuple[XGBRegressor, XGBRegressor, dict]: - print(f"Generating {n_samples} synthetic fire spread samples...") - df = generate_synthetic_dataset(n_samples=n_samples) - - X = df[FEATURE_COLS] - y_1h = df["spread_1h_m"] - y_3h = df["spread_3h_m"] - - X_train, X_test, y1_train, y1_test, y3_train, y3_test = train_test_split( - X, y_1h, y_3h, test_size=0.2, random_state=42 - ) - - xgb_params = dict( - n_estimators=300, - max_depth=6, - learning_rate=0.08, - subsample=0.8, - colsample_bytree=0.8, - random_state=42, - n_jobs=-1, - ) - - print("Training 1-hour spread model...") - model_1h = XGBRegressor(**xgb_params) - model_1h.fit(X_train, y1_train) - - print("Training 3-hour spread model...") - model_3h = XGBRegressor(**xgb_params) - model_3h.fit(X_train, y3_train) - - pred_1h = model_1h.predict(X_test) - pred_3h = model_3h.predict(X_test) - - metrics = { - "1h_mae_m": round(mean_absolute_error(y1_test, pred_1h), 1), - 
"1h_r2": round(r2_score(y1_test, pred_1h), 3), - "3h_mae_m": round(mean_absolute_error(y3_test, pred_3h), 1), - "3h_r2": round(r2_score(y3_test, pred_3h), 3), - } - - joblib.dump(model_1h, MODEL_1H_PATH) - joblib.dump(model_3h, MODEL_3H_PATH) - print(f"Models saved -> {MODEL_DIR}") - - return model_1h, model_3h, metrics - - -# ── Lazy Model Loading ──────────────────────────────────────────────────────── - -_model_1h: XGBRegressor | None = None -_model_3h: XGBRegressor | None = None - - -def _load_models() -> tuple[XGBRegressor, XGBRegressor]: - global _model_1h, _model_3h - if _model_1h is None or _model_3h is None: - if MODEL_1H_PATH.exists() and MODEL_3H_PATH.exists(): - logger.info("Loading pre-trained spread models from disk...") - _model_1h = joblib.load(MODEL_1H_PATH) - _model_3h = joblib.load(MODEL_3H_PATH) - else: - logger.info("No saved models — training now...") - _model_1h, _model_3h, _ = train_spread_model() - return _model_1h, _model_3h - - -# ── Prediction API ──────────────────────────────────────────────────────────── - -def predict_spread_from_features(features: dict) -> dict: - """ - Build feature row, run both XGBoost models, return predictions. - Any missing feature is filled with a safe fire-season default. 
- """ - model_1h, model_3h = _load_models() - - # Compute U/V from raw wind direction if provided - wind_speed = features.get("wind_speed_km_h", 20.0) - wind_dir = features.get("wind_direction_deg", 180.0) - wind_dir_rad = math.radians(wind_dir) - - defaults = { - "wind_speed_km_h": wind_speed, - "wind_u": wind_speed * math.cos(wind_dir_rad), - "wind_v": wind_speed * math.sin(wind_dir_rad), - "temperature_c": 25.0, - "relative_humidity_pct": 35.0, - "fwi": 25.0, - "isi": 10.0, - "bui": 60.0, - "area_hectares": 500.0, - "slope_pct": 5.0, # mild uphill default - "rh_trend_24h": -8.0, # slight drying — typical fire season - } - - row = {col: features.get(col, defaults[col]) for col in FEATURE_COLS} - df_row = pd.DataFrame([row]) - - spread_1h = float(model_1h.predict(df_row)[0]) - spread_3h = float(model_3h.predict(df_row)[0]) - - return { - "spread_1h_m": round(max(50, spread_1h)), - "spread_3h_m": round(max(100, spread_3h)), - "features_used": row, - "model": "XGBoost-Rothermel-v2-WindUV-Slope-Trend", - } - - -def predict_spread(fire_id: str, fire_data: dict | None = None) -> dict: - """ - High-level call: fetch live weather → build features → predict. - Called by GET /api/v1/predictions/{fire_id}. 
- """ - from src.ingestion.weather import get_fire_weather - - lat = fire_data.get("latitude", 49.9071) if fire_data else 49.9071 - lon = fire_data.get("longitude", -119.496) if fire_data else -119.496 - area = fire_data.get("area_hectares", 500) if fire_data else 500 - - weather = get_fire_weather(lat, lon) - - if weather: - wind_speed = weather.get("wind_speed_km_h", 20.0) or 20.0 - wind_dir = weather.get("wind_direction_deg", 180.0) or 180.0 - wind_dir_rad = math.radians(wind_dir) - features = { - "wind_speed_km_h": wind_speed, - "wind_direction_deg": wind_dir, # kept for U/V calc in predict_spread_from_features - "wind_u": wind_speed * math.cos(wind_dir_rad), - "wind_v": wind_speed * math.sin(wind_dir_rad), - "temperature_c": weather.get("temperature_c", 25.0), - "relative_humidity_pct": weather.get("relative_humidity_pct", 35.0), - "fwi": 25.0, # CFFDRS fallback (overwritten below if available) - "isi": 10.0, - "bui": 60.0, - "area_hectares": float(area or 500), - "slope_pct": 5.0, # default mild uphill - "rh_trend_24h": -8.0, # typical fire-season drying trend - } - - # Enrich with real CFFDRS fire danger indices from nearest NRCan weather station - try: - from src.ingestion.cffdrs import get_cffdrs_for_location - cffdrs = get_cffdrs_for_location(lat, lon) - if cffdrs: - if cffdrs.get("fwi") is not None: - features["fwi"] = cffdrs["fwi"] - if cffdrs.get("isi") is not None: - features["isi"] = cffdrs["isi"] - if cffdrs.get("bui") is not None: - features["bui"] = cffdrs["bui"] - logger.info( - f"CFFDRS station '{cffdrs['source_station']}' " - f"({cffdrs['distance_km']} km away): " - f"FWI={cffdrs['fwi']}, ISI={cffdrs['isi']}, BUI={cffdrs['bui']}" - ) - except Exception as e: - logger.warning(f"CFFDRS lookup failed for {fire_id}: {e} — using fallback indices") - else: - logger.warning(f"No weather for {fire_id} — using defaults") - features = {} - - prediction = predict_spread_from_features(features) - prediction["fire_id"] = fire_id - return prediction - - -# 
── CLI: Train + Evaluate + Live Demo ─────────────────────────────────────── - -if __name__ == "__main__": - logging.basicConfig(level=logging.WARNING) # suppress httpx INFO noise - - print("=" * 65) - print(" FireGrid XGBoost Spread Model v2 — Wind U/V · Slope · Trend") - print("=" * 65) - - model_1h, model_3h, metrics = train_spread_model(n_samples=6000) - - print("\nEvaluation (20% held-out test set):") - print(f" 1h model: MAE = {metrics['1h_mae_m']} m R² = {metrics['1h_r2']}") - print(f" 3h model: MAE = {metrics['3h_mae_m']} m R² = {metrics['3h_r2']}") - - print("\nLive predictions (real weather from Open-Meteo):\n") - demo_fires = [ - {"fire_id": "BC-2026-001", "name": "Okanagan Ridge Fire", "latitude": 49.9071, "longitude": -119.4960, "area_hectares": 12450}, - {"fire_id": "BC-2026-003", "name": "Fraser Valley Approach", "latitude": 49.3845, "longitude": -121.4483, "area_hectares": 250}, - {"fire_id": "AB-2026-001", "name": "Peace River Complex", "latitude": 56.2370, "longitude": -117.2900, "area_hectares": 12500}, - ] - - for fire in demo_fires: - result = predict_spread(fire["fire_id"], fire) - wx = result.get("features_used", {}) - print(f"━━━ {fire['name']} ({fire['fire_id']}) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") - print(f" Wind: {wx.get('wind_speed_km_h', '?'):.1f} km/h " - f"-> U={wx.get('wind_u', 0):.1f}, V={wx.get('wind_v', 0):.1f}") - print(f" Temp/RH: {wx.get('temperature_c', '?')}°C / {wx.get('relative_humidity_pct', '?')}% RH") - print(f" Slope: {wx.get('slope_pct', 5.0):.0f}% RH trend: {wx.get('rh_trend_24h', -8.0):.0f}% per 24h") - print(f" Area: {fire['area_hectares']:,} ha") - print(f" +1h spread: {result['spread_1h_m']:,} m") - print(f" +3h spread: {result['spread_3h_m']:,} m") - print() - - print("Feature importances (1h model — sorted):") - fi = dict(zip(FEATURE_COLS, model_1h.feature_importances_, strict=True)) - for feat, imp in sorted(fi.items(), key=lambda x: -x[1]): - bar = "█" * int(imp * 50) - print(f" {feat:<28} {bar} ({imp:.3f})") diff 
--git a/src/models/train_rl_agent.py b/src/models/train_rl_agent.py index 505cead..d71adec 100644 --- a/src/models/train_rl_agent.py +++ b/src/models/train_rl_agent.py @@ -18,6 +18,42 @@ logger = logging.getLogger(__name__) MODEL_SAVE_PATH = Path(__file__).parent / "tactical_ppo_agent" +DEFAULT_SCENARIO_DATASET = Path("data/static/scenario_parameter_records_train.json") + + +def _resolve_dataset_path(path: str | None) -> str | None: + if path: + return path + if DEFAULT_SCENARIO_DATASET.exists(): + return str(DEFAULT_SCENARIO_DATASET) + return None + + +def _existing_path(path: str | None) -> str | None: + if path and Path(path).exists(): + return path + return None + + +def _evaluate_model(model, dataset_path: str, seed: int, episodes: int = 5) -> tuple[float, float]: + from src.models.fire_env import WildfireEnv, load_scenario_parameter_records + + records = load_scenario_parameter_records(dataset_path) + eval_env = WildfireEnv(scenario_parameter_records=records) + returns = [] + assets_lost_total = [] + for ep in range(episodes): + obs, _ = eval_env.reset(seed=seed + ep + 100) + ep_return = 0.0 + for _ in range(150): + action, _ = model.predict(obs, deterministic=True) + obs, reward, done, truncated, info = eval_env.step(int(action)) + ep_return += reward + if done or truncated: + break + returns.append(ep_return) + assets_lost_total.append(info["assets_lost"]) + return sum(returns) / len(returns), sum(assets_lost_total) / len(assets_lost_total) def train( @@ -25,6 +61,9 @@ def train( spread_rate_m_per_min: float = 15.0, n_envs: int = 4, seed: int = 42, + scenario_dataset_path: str | None = None, + val_dataset_path: str | None = None, + holdout_dataset_path: str | None = None, ) -> None: """ Train the PPO tactical agent. 
@@ -39,7 +78,7 @@ def train( from stable_baselines3 import PPO from stable_baselines3.common.env_util import make_vec_env - from src.models.fire_env import WildfireEnv + from src.models.fire_env import WildfireEnv, load_scenario_parameter_records except ImportError as e: print(f"Missing dependency: {e}") print(" Run: uv sync") @@ -55,7 +94,15 @@ def train( print(" Budgets: heli=8, crew=20") print() - env_kwargs = {"base_spread_rate_m_per_min": spread_rate_m_per_min} + scenario_dataset_path = _resolve_dataset_path(scenario_dataset_path) + + env_kwargs: dict = {} + if scenario_dataset_path: + records = load_scenario_parameter_records(scenario_dataset_path) + env_kwargs["scenario_parameter_records"] = records + print(f" Scenario records: {len(records)} from {scenario_dataset_path}") + else: + env_kwargs["base_spread_rate_m_per_min"] = spread_rate_m_per_min vec_env = make_vec_env( WildfireEnv, n_envs=n_envs, @@ -87,37 +134,51 @@ def train( # Quick evaluation print("\nRunning quick evaluation (5 episodes)...") - from src.models.fire_env import WildfireEnv as Env - eval_env = Env(base_spread_rate_m_per_min=spread_rate_m_per_min) - returns = [] - assets_lost_total = [] - for ep in range(5): - obs, _ = eval_env.reset(seed=seed + ep + 100) - ep_return = 0.0 - for _ in range(150): - action, _ = model.predict(obs, deterministic=True) - obs, reward, done, truncated, info = eval_env.step(int(action)) - ep_return += reward - if done or truncated: - break - returns.append(ep_return) - assets_lost_total.append(info["assets_lost"]) - - print(f" Mean return: {sum(returns)/len(returns):.1f}") - print(f" Mean assets lost: {sum(assets_lost_total)/len(assets_lost_total):.1f}") + eval_targets = [("train", scenario_dataset_path)] + if _existing_path(val_dataset_path): + eval_targets.append(("val", val_dataset_path)) + if _existing_path(holdout_dataset_path): + eval_targets.append(("holdout", holdout_dataset_path)) + + for split_name, dataset_path in eval_targets: + if not dataset_path: + 
            continue
+        mean_return, mean_assets_lost = _evaluate_model(model, dataset_path, seed=seed, episodes=5)
+        print(f"  [{split_name}] Mean return: {mean_return:.1f}")
+        print(f"  [{split_name}] Mean assets lost: {mean_assets_lost:.1f}")
 
     print(f"\nTraining complete. Model ready at {MODEL_SAVE_PATH}.zip")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Train PPO wildfire tactical agent")
-    parser.add_argument("--timesteps", type=int, default=200_000,
-                        help="Total training timesteps (default: 200000)")
-    parser.add_argument("--spread-rate", type=float, default=15.0,
-                        help="Fire spread rate in m/min (default: 15.0)")
-    parser.add_argument("--envs", type=int, default=4,
-                        help="Number of parallel environments (default: 4)")
-    parser.add_argument("--seed", type=int, default=42,
-                        help="Random seed (default: 42)")
+    parser.add_argument(
+        "--timesteps", type=int, default=200_000, help="Total training timesteps (default: 200000)"
+    )
+    parser.add_argument(
+        "--spread-rate", type=float, default=15.0, help="Fire spread rate in m/min (default: 15.0)"
+    )
+    parser.add_argument(
+        "--envs", type=int, default=4, help="Number of parallel environments (default: 4)"
+    )
+    parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
+    parser.add_argument(
+        "--scenario-dataset",
+        type=str,
+        default=None,
+        help="Path to cached training scenario parameter JSON dataset",
+    )
+    parser.add_argument(
+        "--val-dataset",
+        type=str,
+        default="data/static/scenario_parameter_records_val.json",
+        help="Path to cached validation scenario parameter JSON dataset",
+    )
+    parser.add_argument(
+        "--holdout-dataset",
+        type=str,
+        default="data/static/scenario_parameter_records_holdout.json",
+        help="Path to cached holdout scenario parameter JSON dataset",
+    )
     args = parser.parse_args()
 
     train(
@@ -125,4 +186,7 @@ def train(
         spread_rate_m_per_min=args.spread_rate,
         n_envs=args.envs,
         seed=args.seed,
+        scenario_dataset_path=args.scenario_dataset,
+        val_dataset_path=args.val_dataset,
+        holdout_dataset_path=args.holdout_dataset,
     )
diff --git a/uv.lock b/uv.lock
index 20671ca..80cff6e 100644
--- a/uv.lock
+++ b/uv.lock
@@ -14,9 +14,11 @@ version = "0.1.0"
 source = { virtual = "." }
 dependencies = [
     { name = "gymnasium" },
+    { name = "httpx" },
     { name = "matplotlib" },
     { name = "numpy" },
     { name = "pandas" },
+    { name = "python-dotenv" },
     { name = "scikit-learn" },
     { name = "stable-baselines3" },
     { name = "torch" },
@@ -33,9 +35,11 @@ dev = [
 [package.metadata]
 requires-dist = [
     { name = "gymnasium", specifier = ">=1.2.3" },
+    { name = "httpx", specifier = ">=0.28.1" },
     { name = "matplotlib", specifier = ">=3.10.8" },
     { name = "numpy", specifier = ">=2.4.2" },
     { name = "pandas", specifier = ">=2.3.0" },
+    { name = "python-dotenv", specifier = ">=1.1.1" },
     { name = "scikit-learn", specifier = ">=1.7.0" },
     { name = "stable-baselines3", specifier = ">=2.4.1" },
     { name = "torch", specifier = ">=2.10.0" },
@@ -58,6 +62,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" },
 ]
 
+[[package]]
+name = "anyio"
+version = "4.13.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "idna" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/19/14/2c5dd9f512b66549ae92767a9c7b330ae88e1932ca57876909410251fe13/anyio-4.13.0.tar.gz", hash = "sha256:334b70e641fd2221c1505b3890c69882fe4a2df910cba14d97019b90b24439dc", size = 231622, upload-time = "2026-03-24T12:59:09.671Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/da/42/e921fccf5015463e32a3cf6ee7f980a6ed0f395ceeaa45060b61d86486c2/anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708", size = 114353, upload-time = "2026-03-24T12:59:08.246Z" },
+]
+
 [[package]]
 name = "certifi"
 version = "2026.1.4"
@@ -275,6 +291,43 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/56/d3/ea5f088e3638dbab12e5c20d6559d5b3bdaeaa1f2af74e526e6815836285/gymnasium-1.2.3-py3-none-any.whl", hash = "sha256:e6314bba8f549c7fdcc8677f7cd786b64908af6e79b57ddaa5ce1825bffb5373", size = 952113, upload-time = "2025-12-18T16:51:08.445Z" },
 ]
 
+[[package]]
+name = "h11"
+version = "0.16.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" },
+]
+
+[[package]]
+name = "httpcore"
+version = "1.0.9"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "certifi" },
+    { name = "h11" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" },
+]
+
+[[package]]
+name = "httpx"
+version = "0.28.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "anyio" },
+    { name = "certifi" },
+    { name = "httpcore" },
+    { name = "idna" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
+]
+
 [[package]]
 name = "idna"
 version = "3.11"
@@ -797,6 +850,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
 ]
 
+[[package]]
+name = "python-dotenv"
+version = "1.2.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/82/ed/0301aeeac3e5353ef3d94b6ec08bbcabd04a72018415dcb29e588514bba8/python_dotenv-1.2.2.tar.gz", hash = "sha256:2c371a91fbd7ba082c2c1dc1f8bf89ca22564a087c2c287cd9b662adde799cf3", size = 50135, upload-time = "2026-03-01T16:00:26.196Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101, upload-time = "2026-03-01T16:00:25.09Z" },
+]
+
 [[package]]
 name = "pyyaml"
 version = "6.0.3"