Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# CI pipeline: lint with ruff, then smoke-test that the RL environment
# imports, resets, and steps without crashing.
name: CI

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: astral-sh/ruff-action@v3

  smoke-test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: astral-sh/setup-uv@v4
      - run: uv sync
      - name: Verify env imports and runs
        run: |
          uv run python -c "
          from src.models.fire_env import WildfireEnv
          env = WildfireEnv()
          obs, _ = env.reset(seed=42)
          assert obs.shape == (631,)
          for _ in range(10):
              obs, r, done, trunc, info = env.step(env.action_space.sample())
          print('smoke test passed')
          "
35 changes: 35 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# FireGrid

Empirical RL benchmark for wildfire tactical suppression. Compares DQN, A2C, PPO, and heuristic baselines on a 25x25 grid environment with critical assets and finite suppression budgets.

## Setup

```bash
uv sync
```

### Pre-commit hooks (optional)

Install [lefthook](https://github.com/evilmartians/lefthook) for local lint/format checks on commit:

```bash
# pick one
brew install lefthook
npm i -g lefthook

# then wire it up
lefthook install
```

## Usage

```bash
# Train PPO agent (200k steps)
uv run python -m src.models.train_rl_agent

# Quick test (10k steps)
uv run python -m src.models.train_rl_agent --timesteps 10000

# Train XGBoost spread model
uv run python -m src.models.spread_model
```
8 changes: 8 additions & 0 deletions lefthook.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Lefthook pre-commit hooks: lint and verify formatting on staged Python files.
pre-commit:
  commands:
    lint:
      glob: "*.py"
      run: uv run ruff check {staged_files}
    format-check:
      glob: "*.py"
      run: uv run ruff format --check {staged_files}
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,9 @@ dependencies = [
"scikit-learn>=1.7.0",
"pandas>=2.3.0",
]

# Dev-only tooling (PEP 735 dependency group); installed by `uv sync`
# but kept out of the package's runtime dependencies.
[dependency-groups]
dev = [
"ruff>=0.11.0",
"pytest>=8.0.0",
]
30 changes: 30 additions & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Ruff lint + format configuration for the FireGrid codebase.
target-version = "py314"
line-length = 100
extend-exclude = ["drd-archive"]

[lint]
select = [
    "E",   # pycodestyle errors
    "W",   # pycodestyle warnings
    "F",   # pyflakes
    "I",   # isort
    "UP",  # pyupgrade
    "B",   # flake8-bugbear
    "SIM", # flake8-simplify
    "RUF", # ruff-specific rules
    "N",   # pep8-naming
]
ignore = [
    "E501",   # line too long — handled by formatter
    "RUF001", # ambiguous unicode in strings
    "RUF002", # ambiguous unicode in docstrings
    "RUF003", # ambiguous unicode in comments
    "RUF012", # mutable class attribute (gymnasium pattern)
    "N806",   # uppercase variable in function (ML convention: X, X_train, etc.)
]

[lint.isort]
known-first-party = ["src"]

[format]
quote-style = "double"
16 changes: 7 additions & 9 deletions src/ingestion/cffdrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
import io
import logging
import math
from datetime import datetime, timezone
from typing import Optional
from datetime import UTC, datetime

import httpx

Expand All @@ -49,7 +48,7 @@ def _haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
return R * 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))


def _parse_float(val: str) -> Optional[float]:
def _parse_float(val: str) -> float | None:
"""Safely parse a float from a CSV string, returning None if invalid."""
try:
f = float(val)
Expand All @@ -58,15 +57,15 @@ def _parse_float(val: str) -> Optional[float]:
return None


def fetch_cffdrs_stations(year: Optional[int] = None) -> list[dict]:
def fetch_cffdrs_stations(year: int | None = None) -> list[dict]:
"""
Download the full CWFIS annual FWI observation CSV and parse it.

Returns a list of station dicts with lat, lon, and all CFFDRS indices.
Filters to BC + AB stations only.
"""
if year is None:
year = datetime.now(timezone.utc).year
year = datetime.now(UTC).year

url = CFFDRS_BASE_URL.format(year=year)
logger.info(f"Fetching CFFDRS station data from {url}")
Expand All @@ -81,7 +80,7 @@ def fetch_cffdrs_stations(year: Optional[int] = None) -> list[dict]:
except httpx.HTTPStatusError as e:
logger.error(f"CFFDRS HTTP {e.response.status_code}")
# Try prior year as fallback (may not have current year yet)
if e.response.status_code == 404 and year == datetime.now(timezone.utc).year:
if e.response.status_code == 404 and year == datetime.now(UTC).year:
logger.info("Trying prior year as fallback...")
return fetch_cffdrs_stations(year - 1)
return []
Expand Down Expand Up @@ -133,9 +132,9 @@ def fetch_cffdrs_stations(year: Optional[int] = None) -> list[dict]:
def get_cffdrs_for_location(
latitude: float,
longitude: float,
stations: Optional[list[dict]] = None,
stations: list[dict] | None = None,
max_radius_km: float = 200.0,
) -> Optional[dict]:
) -> dict | None:
"""
Find the nearest CWFIS weather station and return its CFFDRS indices.

Expand Down Expand Up @@ -220,7 +219,6 @@ def get_cffdrs_for_fires(fires: list[dict]) -> dict[str, dict]:

# ── Manual test ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
import json
logging.basicConfig(level=logging.INFO)

test_fires = [
Expand Down
13 changes: 5 additions & 8 deletions src/ingestion/cwfis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@
import csv
import io
import logging
from datetime import datetime, timezone
from typing import Optional
from datetime import UTC, datetime

import httpx

from src.core.config import settings

logger = logging.getLogger(__name__)

# ── CWFIS Open Data URLs ──────────────────────────────────────────────────────
Expand All @@ -43,7 +40,7 @@ def _severity_from_status(status: str) -> str:
return "low"


def _normalize_cwfis_row(row: dict) -> Optional[dict]:
def _normalize_cwfis_row(row: dict) -> dict | None:
"""
Normalize a single CWFIS CSV row into a FireGrid FireEvent dict.

Expand Down Expand Up @@ -75,9 +72,9 @@ def _normalize_cwfis_row(row: dict) -> Optional[dict]:

# Build ISO timestamp from start date
try:
started_at = datetime.strptime(start_date, "%Y-%m-%d").replace(tzinfo=timezone.utc).isoformat()
started_at = datetime.strptime(start_date, "%Y-%m-%d").replace(tzinfo=UTC).isoformat()
except (ValueError, TypeError):
started_at = datetime.now(timezone.utc).isoformat()
started_at = datetime.now(UTC).isoformat()

# Build a stable fire_id using province + fire number
safe_num = fire_number.replace(" ", "_").replace("/", "-")
Expand All @@ -93,7 +90,7 @@ def _normalize_cwfis_row(row: dict) -> Optional[dict]:
"longitude": lon,
"area_hectares": hectares,
"started_at": started_at,
"updated_at": datetime.now(timezone.utc).isoformat(),
"updated_at": datetime.now(UTC).isoformat(),
"source": "CWFIS_NRCAN",
}
except (ValueError, KeyError, TypeError) as e:
Expand Down
23 changes: 11 additions & 12 deletions src/ingestion/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
"""

import random
from datetime import datetime, timedelta, timezone

from datetime import UTC, datetime, timedelta

# ── Seed for reproducible dummy data ───────────────────────────────────────────
random.seed(42)
Expand Down Expand Up @@ -44,8 +43,8 @@ def _rand_coord() -> tuple[float, float]:
"latitude": 49.9071,
"longitude": -119.4960,
"area_hectares": 4200.0,
"started_at": (datetime.now(timezone.utc) - timedelta(hours=18)).isoformat(),
"updated_at": datetime.now(timezone.utc).isoformat(),
"started_at": (datetime.now(UTC) - timedelta(hours=18)).isoformat(),
"updated_at": datetime.now(UTC).isoformat(),
"source": "dummy",
},
{
Expand All @@ -57,8 +56,8 @@ def _rand_coord() -> tuple[float, float]:
"latitude": 50.6745,
"longitude": -120.3273,
"area_hectares": 800.0,
"started_at": (datetime.now(timezone.utc) - timedelta(hours=6)).isoformat(),
"updated_at": datetime.now(timezone.utc).isoformat(),
"started_at": (datetime.now(UTC) - timedelta(hours=6)).isoformat(),
"updated_at": datetime.now(UTC).isoformat(),
"source": "dummy",
},
{
Expand All @@ -70,8 +69,8 @@ def _rand_coord() -> tuple[float, float]:
"latitude": 49.3845,
"longitude": -121.4483,
"area_hectares": 250.0,
"started_at": (datetime.now(timezone.utc) - timedelta(hours=8)).isoformat(),
"updated_at": datetime.now(timezone.utc).isoformat(),
"started_at": (datetime.now(UTC) - timedelta(hours=8)).isoformat(),
"updated_at": datetime.now(UTC).isoformat(),
"source": "dummy",
},
{
Expand All @@ -83,8 +82,8 @@ def _rand_coord() -> tuple[float, float]:
"latitude": 56.2370,
"longitude": -117.2900,
"area_hectares": 12500.0,
"started_at": (datetime.now(timezone.utc) - timedelta(hours=36)).isoformat(),
"updated_at": datetime.now(timezone.utc).isoformat(),
"started_at": (datetime.now(UTC) - timedelta(hours=36)).isoformat(),
"updated_at": datetime.now(UTC).isoformat(),
"source": "dummy",
},
]
Expand Down Expand Up @@ -130,7 +129,7 @@ def get_dummy_burn_probability(fire_id: str) -> dict:
"fire_id": fire_id,
"model": "dummy_v0",
"horizon_hours": 24,
"generated_at": datetime.now(timezone.utc).isoformat(),
"generated_at": datetime.now(UTC).isoformat(),
"wind_speed_kmh": random.uniform(20, 60),
"wind_direction_deg": random.uniform(220, 280), # SW winds, pushing NE
"cells": cells,
Expand Down Expand Up @@ -229,7 +228,7 @@ def get_dummy_choke_points(fire_id: str) -> dict:
return {
"fire_id": fire_id,
"model": "greedy_heuristic_v0",
"generated_at": datetime.now(timezone.utc).isoformat(),
"generated_at": datetime.now(UTC).isoformat(),
"total_choke_points": len(recommendations),
"recommendations": recommendations,
}
11 changes: 5 additions & 6 deletions src/ingestion/firms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
import csv
import io
import logging
from datetime import datetime, timezone
from typing import Optional
from datetime import UTC, datetime

import httpx

Expand Down Expand Up @@ -63,7 +62,7 @@ def _frp_to_severity(frp: float) -> str:
return "low"


def _normalize_hotspot(row: dict, idx: int) -> Optional[dict]:
def _normalize_hotspot(row: dict, idx: int) -> dict | None:
"""
Normalize a single FIRMS CSV row into a FireGrid FireEvent dict.
Returns None if the row is missing critical fields.
Expand All @@ -78,9 +77,9 @@ def _normalize_hotspot(row: dict, idx: int) -> Optional[dict]:
# Build ISO timestamp from acquisition date + time
try:
dt_str = f"{acq_date} {acq_time.zfill(4)}"
detected_at = datetime.strptime(dt_str, "%Y-%m-%d %H%M").replace(tzinfo=timezone.utc).isoformat()
detected_at = datetime.strptime(dt_str, "%Y-%m-%d %H%M").replace(tzinfo=UTC).isoformat()
except ValueError:
detected_at = datetime.now(timezone.utc).isoformat()
detected_at = datetime.now(UTC).isoformat()

province = _assign_province(lat, lon)
severity = _frp_to_severity(frp)
Expand All @@ -104,7 +103,7 @@ def _normalize_hotspot(row: dict, idx: int) -> Optional[dict]:
"confidence": row.get("confidence", "n"),
"satellite": row.get("satellite", "N20"),
"started_at": detected_at,
"updated_at": datetime.now(timezone.utc).isoformat(),
"updated_at": datetime.now(UTC).isoformat(),
"source": "NASA_FIRMS_VIIRS",
}
except (ValueError, KeyError) as e:
Expand Down
8 changes: 3 additions & 5 deletions src/ingestion/weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
"""

import logging
from datetime import datetime, timezone
from typing import Optional
from datetime import UTC, datetime

import httpx

Expand All @@ -47,7 +46,7 @@ def get_fire_weather(
longitude: float,
*,
timeout: int = 10,
) -> Optional[dict]:
) -> dict | None:
"""
Fetch current weather conditions at a fire's coordinates.

Expand Down Expand Up @@ -112,7 +111,7 @@ def get_fire_weather(
"precipitation_mm": current.get("precipitation", 0.0),
"surface_pressure_hpa": current.get("surface_pressure"),
"dew_point_c": current.get("dew_point_2m"),
"fetched_at": datetime.now(timezone.utc).isoformat(),
"fetched_at": datetime.now(UTC).isoformat(),
}


Expand Down Expand Up @@ -146,7 +145,6 @@ def get_weather_for_fires(fires: list[dict]) -> dict[str, dict]:

# ── Manual test ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
import json
logging.basicConfig(level=logging.INFO)

# Test fires
Expand Down
Loading
Loading