From 80e2a081981000d0ca09f9958d24f10ed7ea36cf Mon Sep 17 00:00:00 2001 From: Tadd Bindas Date: Fri, 13 Mar 2026 21:53:50 -0400 Subject: [PATCH 1/2] Revert "Phase 1: KAN bias correction foundation (#144) (#158)" This reverts commit 0a318b83c7652f655a42d66ba2e20061c4cfa9c9. --- src/ddr/nn/__init__.py | 3 +- src/ddr/nn/temporal_phi_kan.py | 125 -------------------------- src/ddr/validation/__init__.py | 8 +- src/ddr/validation/configs.py | 44 +--------- src/ddr/validation/enums.py | 9 -- src/ddr/validation/losses.py | 84 ------------------ tests/nn/test_temporal_phi_kan.py | 140 ------------------------------ tests/validation/test_configs.py | 80 +---------------- tests/validation/test_losses.py | 91 ------------------- 9 files changed, 4 insertions(+), 580 deletions(-) delete mode 100644 src/ddr/nn/temporal_phi_kan.py delete mode 100644 src/ddr/validation/losses.py delete mode 100644 tests/nn/test_temporal_phi_kan.py delete mode 100644 tests/validation/test_losses.py diff --git a/src/ddr/nn/__init__.py b/src/ddr/nn/__init__.py index 94f2f65..1622490 100644 --- a/src/ddr/nn/__init__.py +++ b/src/ddr/nn/__init__.py @@ -1,4 +1,3 @@ from .kan import kan -from .temporal_phi_kan import TemporalPhiKAN -__all__ = ["TemporalPhiKAN", "kan"] +__all__ = ["kan"] diff --git a/src/ddr/nn/temporal_phi_kan.py b/src/ddr/nn/temporal_phi_kan.py deleted file mode 100644 index ca74291..0000000 --- a/src/ddr/nn/temporal_phi_kan.py +++ /dev/null @@ -1,125 +0,0 @@ -"""Temporal φ-KAN for bias correction of lateral inflows. - -A small KAN that corrects Q' using flow magnitude and seasonal context. -By Kolmogorov-Arnold theory, the correction decomposes as: - - φ(Q', sin_m, cos_m) = Σ_q Φ_q( ψ_{q,Q'}(Q') + ψ_{q,sin}(sin_m) + ψ_{q,cos}(cos_m) ) - -Each ψ is a plottable 1D B-spline curve — fully interpretable. -""" - -import logging -import math - -import torch -import torch.nn as nn -from kan import KAN - -from ddr.validation.configs import BiasCorrection -from ddr.validation.enums import PhiInputs - -log = logging.getLogger(__name__) - -_INPUT_DIM = { - PhiInputs.STATIC: 1, - PhiInputs.MONTHLY: 3, - PhiInputs.FORCING: 2, - PhiInputs.RANDOM: 3, -} - - -class TemporalPhiKAN(nn.Module): - """Small KAN that corrects Q' using flow magnitude and seasonal context. - - Parameters - ---------- - cfg : BiasCorrection - Bias correction configuration specifying phi_inputs mode and KAN hyperparameters. - seed : int - Random seed for reproducibility (from top-level Config.seed). - device : int | str - Compute device (from top-level Config.device). - """ - - def __init__(self, cfg: BiasCorrection, seed: int, device: int | str = "cpu"): - super().__init__() - self.phi_inputs = cfg.phi_inputs - self.input_dim = _INPUT_DIM[cfg.phi_inputs] - - self.phi_kan = KAN( - [self.input_dim, cfg.phi_hidden_size, 1], - grid=cfg.phi_grid, - k=cfg.phi_k, - seed=seed, - device=device, - ) - - def forward( - self, - q_prime: torch.Tensor, - month: torch.Tensor | None = None, - forcing: torch.Tensor | None = None, - grid_bounds: torch.Tensor | None = None, - ) -> torch.Tensor: - """Bias-correct lateral inflows. - - Parameters - ---------- - q_prime : (T, N) - Raw lateral inflow from dHBV2. - month : (T,), optional - Month of year as float [1, 12]. Required for MONTHLY mode. - forcing : (T, N), optional - Forcing variable values. Required for FORCING mode. - grid_bounds : (N, 2), optional - Per-node [min, max] from Spatial KAN for normalization. - - Returns - ------- - q_corrected : (T, N) - Bias-corrected lateral inflow. - """ - T, N = q_prime.shape - - # Normalize Q' per-node using Spatial KAN's grid bounds - if grid_bounds is not None: - grid_min = grid_bounds[:, 0] # (N,) - grid_max = grid_bounds[:, 1] # (N,) - q_norm = (q_prime - grid_min) / (grid_max - grid_min + 1e-8) # (T, N) - else: - q_norm = q_prime - - # Build input tensor based on mode - if self.phi_inputs == PhiInputs.STATIC: - phi_input = q_norm.unsqueeze(-1) # (T, N, 1) - - elif self.phi_inputs == PhiInputs.MONTHLY: - assert month is not None, "month tensor required for MONTHLY mode" - two_pi_month = 2.0 * math.pi * month / 12.0 # (T,) - sin_month = torch.sin(two_pi_month).unsqueeze(1).expand(T, N) # (T, N) - cos_month = torch.cos(two_pi_month).unsqueeze(1).expand(T, N) # (T, N) - phi_input = torch.stack([q_norm, sin_month, cos_month], dim=-1) # (T, N, 3) - - elif self.phi_inputs == PhiInputs.FORCING: - assert forcing is not None, "forcing tensor required for FORCING mode" - phi_input = torch.stack([q_norm, forcing], dim=-1) # (T, N, 2) - - elif self.phi_inputs == PhiInputs.RANDOM: - rand1 = torch.rand(T, N, device=q_prime.device) - rand2 = torch.rand(T, N, device=q_prime.device) - phi_input = torch.stack([q_norm, rand1, rand2], dim=-1) # (T, N, 3) - - else: - raise ValueError(f"Unknown phi_inputs mode: {self.phi_inputs}") - - # Flatten (T, N, input_dim) → (T*N, input_dim), run KAN, reshape back - phi_input_flat = phi_input.reshape(T * N, self.input_dim) - q_corrected_norm = self.phi_kan(phi_input_flat).reshape(T, N) # (T, N) - - # Denormalize back to physical units - if grid_bounds is not None: - q_corrected = q_corrected_norm * (grid_max - grid_min) + grid_min - else: - q_corrected = q_corrected_norm - - return torch.clamp(q_corrected, min=1e-6) diff --git a/src/ddr/validation/__init__.py b/src/ddr/validation/__init__.py index d650f75..787a5ca 100644 --- a/src/ddr/validation/__init__.py +++ b/src/ddr/validation/__init__.py @@ -1,19 +1,13 @@ from . import utils -from .configs import BiasCorrection, Config, GeoDataset, Mode, validate_config -from .enums import PhiInputs -from .losses import kge_loss, mass_balance_loss +from .configs import Config, GeoDataset, Mode, validate_config from .metrics import Metrics from .plots import plot_box_fig, plot_cdf, plot_drainage_area_boxplots, plot_gauge_map, plot_time_series __all__ = [ - "BiasCorrection", "Config", "Metrics", "Mode", "GeoDataset", - "PhiInputs", - "kge_loss", - "mass_balance_loss", "plot_time_series", "plot_box_fig", "plot_cdf", diff --git a/src/ddr/validation/configs.py b/src/ddr/validation/configs.py index 6e3b443..61c773c 100644 --- a/src/ddr/validation/configs.py +++ b/src/ddr/validation/configs.py @@ -9,7 +9,7 @@ from omegaconf import DictConfig, OmegaConf from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator -from ddr.validation.enums import GeoDataset, Mode, PhiInputs +from ddr.validation.enums import GeoDataset, Mode log = logging.getLogger(__name__) @@ -188,44 +188,6 @@ def validate_checkpoint(cls, v: str | Path | None) -> Path | None: return check_path(str(v)) -class BiasCorrection(BaseModel): - """Configuration for KAN-based bias correction of lateral inflows""" - - model_config = ConfigDict(extra="forbid") - - enabled: bool = Field(default=False, description="Whether to enable bias correction") - phi_inputs: PhiInputs = Field( - default=PhiInputs.MONTHLY, - description="Input type for the temporal phi-KAN", - ) - forcing_var: str | None = Field( - default=None, - description="Forcing variable name (required when phi_inputs='forcing')", - ) - phi_hidden_size: int = Field(default=5, description="Hidden layer size for the phi-KAN") - phi_grid: int = Field(default=8, description="Grid size for phi-KAN B-spline basis") - phi_k: int = Field(default=3, description="B-spline order for phi-KAN layers") - lambda_mass: float = Field( - default=0.5, - description="Weight for mass balance loss relative to KGE loss", - ) - lambda_anneal: bool = Field( - default=False, - description="Whether to anneal lambda_mass over training epochs", - ) - use_kge_loss: bool = Field( - default=True, - description="Whether to use KGE loss (True) or MSE loss (False) when bias is enabled", - ) - - @model_validator(mode="after") - def validate_forcing_var(self) -> "BiasCorrection": - """Validate that forcing_var is set when phi_inputs is 'forcing'""" - if self.phi_inputs == PhiInputs.FORCING and self.forcing_var is None: - raise ValueError("forcing_var must be set when phi_inputs='forcing'") - return self - - class Config(BaseModel): """The base level configuration for the dMC (differentiable Muskingum-Cunge) model""" @@ -243,10 +205,6 @@ class Config(BaseModel): mode: Mode = Field(description="Operating mode: training, testing, or routing") params: Params = Field(description="Physical and numerical parameters for the routing model") kan: Kan = Field(description="Architecture and configuration settings for the Kolmogorov-Arnold Network") - bias: BiasCorrection = Field( - default_factory=BiasCorrection, - description="Configuration for KAN-based bias correction of lateral inflows", - ) np_seed: int = Field(default=1, description="Random seed for NumPy operations to ensure reproducibility") seed: int = Field(default=0, description="Random seed for PyTorch operations to ensure reproducibility") device: int | str = Field( diff --git a/src/ddr/validation/enums.py b/src/ddr/validation/enums.py index aea6330..80b9013 100644 --- a/src/ddr/validation/enums.py +++ b/src/ddr/validation/enums.py @@ -14,15 +14,6 @@ class Mode(StrEnum): ROUTING = "routing" -class PhiInputs(StrEnum): - """The type of inputs for the temporal phi-KAN bias correction""" - - STATIC = "static" - MONTHLY = "monthly" - FORCING = "forcing" - RANDOM = "random" - - class GeoDataset(StrEnum): """The geospatial dataset used for predictions and routing""" diff --git a/src/ddr/validation/losses.py b/src/ddr/validation/losses.py deleted file mode 100644 index a640d45..0000000 --- a/src/ddr/validation/losses.py +++ /dev/null @@ -1,84 +0,0 @@ -"""Differentiable loss functions for KAN bias correction training. - -mass_balance_loss: gives φ-KAN direct gradients (bypasses MC routing). -kge_loss: Kling-Gupta Efficiency loss that flows through MC routing. -""" - -import torch - - -def mass_balance_loss( - q_corrected: torch.Tensor, - target: torch.Tensor, - eps: float = 1e-6, -) -> torch.Tensor: - """Mass balance (ρ) loss — gives φ direct gradients bypassing MC. - - MC conserves mass, so total volume at gauge equals total injected volume - upstream. ρ_g = AUC_pred_g / AUC_obs_g. Loss = mean((ρ - 1)²). - - Parameters - ---------- - q_corrected : (G, T) - Bias-corrected discharge at gauge locations. - target : (G, T) - Observed discharge at gauge locations. - eps : float - Small constant to prevent division by zero. - - Returns - ------- - loss : scalar tensor - """ - auc_pred = q_corrected.sum(dim=1) # (G,) - auc_obs = target.sum(dim=1) # (G,) - rho = auc_pred / (auc_obs + eps) - return ((rho - 1.0) ** 2).mean() - - -def kge_loss( - pred: torch.Tensor, - target: torch.Tensor, - eps: float = 1e-6, -) -> torch.Tensor: - """Differentiable KGE loss. - - KGE = 1 - sqrt((r-1)² + (α-1)² + (β-1)²) - Loss = mean(sqrt((r-1)² + (α-1)² + (β-1)² + eps)) - - The eps inside the sqrt prevents NaN gradients at the ideal point. - The eps in denominators prevents division by zero for constant series. - - Parameters - ---------- - pred : (G, T) - Predicted discharge at gauge locations. - target : (G, T) - Observed discharge at gauge locations. - eps : float - Stabilization constant. - - Returns - ------- - loss : scalar tensor - """ - mu_pred = pred.mean(dim=1) # (G,) - mu_obs = target.mean(dim=1) # (G,) - sigma_pred = pred.std(dim=1, correction=0) # (G,) - sigma_obs = target.std(dim=1, correction=0) # (G,) - - # Correlation - pred_anom = pred - mu_pred.unsqueeze(1) - obs_anom = target - mu_obs.unsqueeze(1) - cov = (pred_anom * obs_anom).mean(dim=1) # (G,) - r = cov / (sigma_pred * sigma_obs + eps) # (G,) - - # Variability ratio - alpha = sigma_pred / (sigma_obs + eps) # (G,) - - # Bias ratio - beta = mu_pred / (mu_obs + eps) # (G,) - - # Euclidean distance from ideal (1, 1, 1) - ed = torch.sqrt((r - 1) ** 2 + (alpha - 1) ** 2 + (beta - 1) ** 2 + eps) # (G,) - return ed.mean() diff --git a/tests/nn/test_temporal_phi_kan.py b/tests/nn/test_temporal_phi_kan.py deleted file mode 100644 index 0b19bd8..0000000 --- a/tests/nn/test_temporal_phi_kan.py +++ /dev/null @@ -1,140 +0,0 @@ -"""Tests for ddr.nn.temporal_phi_kan — temporal φ-KAN bias correction.""" - -import math - -import torch - -from ddr.nn.temporal_phi_kan import TemporalPhiKAN -from ddr.validation.configs import BiasCorrection -from ddr.validation.enums import PhiInputs - - -class TestTemporalPhiKAN: - """Tests for the TemporalPhiKAN module.""" - - def _make_phi_kan(self, phi_inputs: PhiInputs = PhiInputs.MONTHLY) -> TemporalPhiKAN: - cfg = BiasCorrection(phi_inputs=phi_inputs) - return TemporalPhiKAN(cfg=cfg, seed=42, device="cpu") - - def test_monthly_output_shape(self) -> None: - model = self._make_phi_kan(PhiInputs.MONTHLY) - T, N = 24, 10 - q_prime = torch.rand(T, N) - month = torch.full((T,), 6.0) - output = model(q_prime, month=month) - - assert output.shape == (T, N) - - def test_static_output_shape(self) -> None: - model = self._make_phi_kan(PhiInputs.STATIC) - T, N = 24, 10 - q_prime = torch.rand(T, N) - output = model(q_prime) - - assert output.shape == (T, N) - - def test_forcing_output_shape(self) -> None: - cfg = BiasCorrection(phi_inputs=PhiInputs.FORCING, forcing_var="precip") - model = TemporalPhiKAN(cfg=cfg, seed=42, device="cpu") - T, N = 24, 10 - q_prime = torch.rand(T, N) - forcing = torch.rand(T, N) - output = model(q_prime, forcing=forcing) - - assert output.shape == (T, N) - - def test_random_output_shape(self) -> None: - model = self._make_phi_kan(PhiInputs.RANDOM) - T, N = 24, 10 - q_prime = torch.rand(T, N) - output = model(q_prime) - - assert output.shape == (T, N) - - def test_non_negative_output(self) -> None: - model = self._make_phi_kan(PhiInputs.MONTHLY) - T, N = 48, 5 - q_prime = torch.rand(T, N) * 0.01 # very small values - month = torch.full((T,), 1.0) - output = model(q_prime, month=month) - - assert (output >= 1e-6).all(), "Output has values below clamp threshold" - - def test_gradient_flow(self) -> None: - model = self._make_phi_kan(PhiInputs.MONTHLY) - T, N = 12, 5 - q_prime = torch.rand(T, N, requires_grad=True) - month = torch.full((T,), 3.0) - output = model(q_prime, month=month) - - loss = output.sum() - loss.backward() - - has_grad = False - for p in model.parameters(): - if p.grad is not None: - has_grad = True - break - assert has_grad, "No parameter received gradients" - - def test_deterministic_eval(self) -> None: - model = self._make_phi_kan(PhiInputs.MONTHLY) - model.eval() - T, N = 12, 5 - q_prime = torch.rand(T, N) - month = torch.full((T,), 7.0) - - with torch.no_grad(): - out1 = model(q_prime, month=month) - out2 = model(q_prime, month=month) - - assert torch.equal(out1, out2) - - def test_with_grid_bounds(self) -> None: - model = self._make_phi_kan(PhiInputs.MONTHLY) - T, N = 12, 5 - q_prime = torch.rand(T, N) * 100 - month = torch.full((T,), 6.0) - grid_bounds = torch.tensor([[0.0, 100.0]] * N) - - output = model(q_prime, month=month, grid_bounds=grid_bounds) - assert output.shape == (T, N) - - def test_without_grid_bounds(self) -> None: - model = self._make_phi_kan(PhiInputs.STATIC) - T, N = 12, 5 - q_prime = torch.rand(T, N) - - output = model(q_prime, grid_bounds=None) - assert output.shape == (T, N) - - def test_monotonicity(self) -> None: - """Output should increase when Q' increases (other inputs fixed).""" - model = self._make_phi_kan(PhiInputs.STATIC) - model.eval() - N = 10 - q_low = torch.full((1, N), 0.2) - q_high = torch.full((1, N), 0.8) - with torch.no_grad(): - out_low = model(q_low) - out_high = model(q_high) - assert (out_high >= out_low).all(), ( - f"Expected monotonic increase: out_low={out_low}, out_high={out_high}" - ) - - def test_sin_cos_encoding(self) -> None: - """Verify sin/cos month encoding produces expected values.""" - # January: month=1 → sin(2π/12) ≈ 0.5, cos(2π/12) ≈ 0.866 - month = torch.tensor([1.0]) - two_pi_month = 2.0 * math.pi * month / 12.0 - sin_val = torch.sin(two_pi_month) - cos_val = torch.cos(two_pi_month) - - assert abs(sin_val.item() - 0.5) < 0.01 - assert abs(cos_val.item() - 0.866) < 0.01 - - # April: month=4 → sin(2π*4/12) ≈ 0.866, cos ≈ -0.5 - month = torch.tensor([4.0]) - two_pi_month = 2.0 * math.pi * month / 12.0 - assert abs(torch.sin(two_pi_month).item() - 0.866) < 0.01 - assert abs(torch.cos(two_pi_month).item() - (-0.5)) < 0.01 diff --git a/tests/validation/test_configs.py b/tests/validation/test_configs.py index a66b295..9610452 100644 --- a/tests/validation/test_configs.py +++ b/tests/validation/test_configs.py @@ -5,7 +5,7 @@ from omegaconf import OmegaConf from pydantic import ValidationError -from ddr.validation.configs import BiasCorrection, Config, Params, validate_config +from ddr.validation.configs import Config, Params, validate_config def _minimal_config_dict(**overrides): @@ -220,81 +220,3 @@ def test_validate_config_from_dictconfig(self, tmp_path) -> None: config = validate_config(dc, save_config=False) assert isinstance(config, Config) assert config.name == "test_run" - - -class TestBiasCorrection: - """Test BiasCorrection config.""" - - def test_default_bias_disabled(self) -> None: - """BiasCorrection() defaults to enabled=False.""" - bc = BiasCorrection() - assert bc.enabled is False - assert bc.phi_inputs == "monthly" - assert bc.lambda_mass == 0.5 - - def test_config_without_bias_gets_default(self, tmp_path) -> None: - """Config without bias key gets BiasCorrection(enabled=False).""" - gpkg = tmp_path / "test.gpkg" - gpkg.touch() - adj = tmp_path / "adj" - adj.mkdir() - - d = _minimal_config_dict() - d["data_sources"]["geospatial_fabric_gpkg"] = str(gpkg) - d["data_sources"]["conus_adjacency"] = str(adj) - - config = Config(**d) - assert config.bias.enabled is False - - def test_config_with_bias_enabled(self, tmp_path) -> None: - """Config with bias.enabled=True is accepted.""" - gpkg = tmp_path / "test.gpkg" - gpkg.touch() - adj = tmp_path / "adj" - adj.mkdir() - - d = _minimal_config_dict() - d["data_sources"]["geospatial_fabric_gpkg"] = str(gpkg) - d["data_sources"]["conus_adjacency"] = str(adj) - d["bias"] = {"enabled": True} - - config = Config(**d) - assert config.bias.enabled is True - assert config.bias.phi_inputs == "monthly" - - def test_bias_extra_fields_rejected(self) -> None: - """extra='forbid' rejects unknown fields in BiasCorrection.""" - with pytest.raises(ValidationError): - BiasCorrection(enabled=True, unknown_field="bad") - - def test_bias_invalid_phi_inputs_rejected(self) -> None: - """Invalid phi_inputs string raises ValidationError.""" - with pytest.raises(ValidationError): - BiasCorrection(phi_inputs="invalid_mode") - - def test_forcing_requires_forcing_var(self) -> None: - """phi_inputs='forcing' without forcing_var raises.""" - with pytest.raises(ValidationError, match="forcing_var"): - BiasCorrection(phi_inputs="forcing", forcing_var=None) - - def test_forcing_with_var_valid(self) -> None: - """phi_inputs='forcing' with forcing_var is accepted.""" - bc = BiasCorrection(phi_inputs="forcing", forcing_var="precip") - assert bc.forcing_var == "precip" - - def test_validate_config_with_bias(self, tmp_path) -> None: - """validate_config round-trip with bias section.""" - gpkg = tmp_path / "test.gpkg" - gpkg.touch() - adj = tmp_path / "adj" - adj.mkdir() - - d = _minimal_config_dict(device="cpu") - d["data_sources"]["geospatial_fabric_gpkg"] = str(gpkg) - d["data_sources"]["conus_adjacency"] = str(adj) - d["bias"] = {"enabled": True, "phi_inputs": "static"} - - dc = OmegaConf.create(d) - config = validate_config(dc, save_config=False) - assert config.bias.enabled is True - assert config.bias.phi_inputs == "static" diff --git a/tests/validation/test_losses.py b/tests/validation/test_losses.py deleted file mode 100644 index 3696fd8..0000000 --- a/tests/validation/test_losses.py +++ /dev/null @@ -1,91 +0,0 @@ -"""Tests for ddr.validation.losses — differentiable loss functions.""" - -import torch - -from ddr.validation.losses import kge_loss, mass_balance_loss - - -class TestMassBalanceLoss: - """Test mass balance (ρ) loss.""" - - def test_perfect_volume_match(self) -> None: - """ρ=1 for identical series → loss=0.""" - pred = torch.tensor([[1.0, 2.0, 3.0, 4.0]]) - target = torch.tensor([[1.0, 2.0, 3.0, 4.0]]) - loss = mass_balance_loss(pred, target) - assert loss.item() < 1e-10 - - def test_volume_mismatch(self) -> None: - """ρ≠1 → loss>0.""" - pred = torch.tensor([[2.0, 4.0, 6.0, 8.0]]) # 2x target volume - target = torch.tensor([[1.0, 2.0, 3.0, 4.0]]) - loss = mass_balance_loss(pred, target) - assert loss.item() > 0 - - def test_batch_dimension(self) -> None: - """Loss averages over multiple gauges.""" - pred = torch.tensor([[1.0, 2.0], [3.0, 6.0]]) # G=2 - target = torch.tensor([[1.0, 2.0], [3.0, 6.0]]) - loss = mass_balance_loss(pred, target) - assert loss.item() < 1e-10 - - def test_gradient_is_finite(self) -> None: - pred = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True) - target = torch.tensor([[2.0, 3.0, 4.0]]) - loss = mass_balance_loss(pred, target) - loss.backward() - assert pred.grad is not None - assert torch.isfinite(pred.grad).all() - - def test_gradient_does_not_flow_through_mc(self) -> None: - """mass_balance_loss on pre-MC q_corrected must not grad MC params.""" - mc_layer = torch.nn.Linear(4, 4, bias=False) - q_corrected = torch.tensor([[1.0, 2.0, 3.0, 4.0]], requires_grad=True) - target = torch.tensor([[2.0, 3.0, 4.0, 5.0]]) - _mc_output = mc_layer(q_corrected) # MC processes but loss ignores - loss = mass_balance_loss(q_corrected, target) - loss.backward() - assert mc_layer.weight.grad is None, "MC should not receive gradients from mass_balance_loss" - assert q_corrected.grad is not None, "q_corrected should receive gradients" - - -class TestKgeLoss: - """Test KGE loss.""" - - def test_perfect_prediction(self) -> None: - """Identical pred and target → loss ≈ eps (near zero).""" - pred = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) - target = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) - loss = kge_loss(pred, target) - assert loss.item() < 0.01 - - def test_poor_prediction(self) -> None: - """Uncorrelated series → loss > 0.""" - pred = torch.tensor([[5.0, 4.0, 3.0, 2.0, 1.0]]) - target = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) - loss = kge_loss(pred, target) - assert loss.item() > 0.1 - - def test_gradient_is_finite(self) -> None: - pred = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]], requires_grad=True) - target = torch.tensor([[2.0, 3.0, 4.0, 5.0, 6.0]]) - loss = kge_loss(pred, target) - loss.backward() - assert pred.grad is not None - assert torch.isfinite(pred.grad).all() - - def test_constant_prediction_no_nan(self) -> None: - """Constant prediction (σ=0) should not produce NaN.""" - pred = torch.tensor([[3.0, 3.0, 3.0, 3.0, 3.0]], requires_grad=True) - target = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) - loss = kge_loss(pred, target) - assert torch.isfinite(loss) - loss.backward() - assert torch.isfinite(pred.grad).all() - - def test_batch_dimension(self) -> None: - """Loss averages over multiple gauges.""" - pred = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) # G=2 - target = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - loss = kge_loss(pred, target) - assert loss.item() < 0.01 From 25ad05f566058a69f87101c88591a7b5cf60678f Mon Sep 17 00:00:00 2001 From: Tadd Bindas Date: Sat, 14 Mar 2026 10:26:39 -0400 Subject: [PATCH 2/2] reaffirmed benchmark for merit --- scripts/train.py | 13 +++- src/ddr/routing/mmc.py | 122 ++++++++++++++++++++----------- src/ddr/routing/torch_mc.py | 13 ++-- src/ddr/validation/configs.py | 6 +- tests/routing/test_mmc.py | 10 +-- tests/routing/test_utils.py | 2 +- tests/validation/test_configs.py | 5 +- 7 files changed, 104 insertions(+), 67 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index fcbe9d0..99a53ef 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -8,7 +8,6 @@ import torch from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig -from torch.nn.functional import mse_loss from torch.utils.data import DataLoader, RandomSampler from ddr import ddr_functions, dmc, kan, streamflow @@ -92,7 +91,7 @@ def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None: filtered_predictions = daily_runoff[~np_nan_mask] - loss = mse_loss( + loss = torch.nn.functional.l1_loss( input=filtered_predictions.transpose(0, 1)[cfg.experiment.warmup :].unsqueeze(2), target=filtered_observations.transpose(0, 1)[cfg.experiment.warmup :].unsqueeze(2), ) @@ -120,8 +119,7 @@ def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None: param_map = { "n": routing_model.n, "q_spatial": routing_model.q_spatial, - "top_width": routing_model.top_width, - "side_slope": routing_model.side_slope, + "p_spatial": routing_model.p_spatial, } for param_name in cfg.kan.learnable_parameters: param_tensor = param_map.get(param_name) @@ -155,6 +153,13 @@ def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None: saved_model_path=cfg.params.save_path / "saved_models", ) + # Free autograd graph and routing engine tensors to prevent OOM + del dmc_output, daily_runoff, streamflow_predictions, loss + del filtered_predictions, filtered_observations + routing_model.routing_engine.q_prime = None + routing_model.routing_engine._discharge_t = None + torch.cuda.empty_cache() + @hydra.main( version_base="1.3", diff --git a/src/ddr/routing/mmc.py b/src/ddr/routing/mmc.py index 7802839..b51fd66 100644 --- a/src/ddr/routing/mmc.py +++ b/src/ddr/routing/mmc.py @@ -71,36 +71,68 @@ def _log_base_q(x: torch.Tensor, q: float) -> torch.Tensor: return torch.log(x) / torch.log(torch.tensor(q, dtype=x.dtype)) +def _apply_data_override(derived: torch.Tensor, data: torch.Tensor | None) -> torch.Tensor: + """Override derived values with observed data where available. + + Three cases: + 1. data is None or empty -> return derived (MERIT without observed geometry) + 2. data has no NaN -> return data (Lynker full coverage) + 3. data has NaN -> blend: data where valid, derived where NaN (partial coverage) + + Parameters + ---------- + derived : torch.Tensor + Values derived from the power law + data : torch.Tensor or None + Observed data, possibly containing NaN + + Returns + ------- + torch.Tensor + Blended result + """ + if data is None or data.numel() == 0: + return derived + nan_mask = torch.isnan(data) + if not nan_mask.any(): + return data + return torch.where(~nan_mask, data, derived) + + def _get_trapezoid_velocity( q_t: torch.Tensor, _n: torch.Tensor, - top_width: torch.Tensor, - side_slope: torch.Tensor, _s0: torch.Tensor, p_spatial: torch.Tensor, _q_spatial: torch.Tensor, + data_top_width: torch.Tensor | None, + data_side_slope: torch.Tensor | None, velocity_lb: torch.Tensor, depth_lb: torch.Tensor, _btm_width_lb: torch.Tensor, -) -> torch.Tensor: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Calculate flow velocity using Manning's equation for trapezoidal channels. + Derives top_width and side_slope per-timestep from the Leopold & Maddock + power law (top_width = p * depth^q), then overrides with observed data + where available. + Parameters ---------- q_t : torch.Tensor Discharge at time t _n : torch.Tensor Manning's roughness coefficient - top_width : torch.Tensor - Top width of channel - side_slope : torch.Tensor - Side slope of channel (z:1, z horizontal : 1 vertical) _s0 : torch.Tensor Channel slope p_spatial : torch.Tensor - Spatial parameter p + Leopold & Maddock width coefficient _q_spatial : torch.Tensor - Spatial parameter q + Leopold & Maddock width-depth exponent + data_top_width : torch.Tensor or None + Observed top width data for override (Lynker/SWOT), or None/empty + data_side_slope : torch.Tensor or None + Observed side slope data for override (Lynker/SWOT), or None/empty velocity_lb : torch.Tensor Lower bound for velocity depth_lb : torch.Tensor @@ -110,19 +142,27 @@ def _get_trapezoid_velocity( Returns ------- - torch.Tensor - Flow velocity + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + (celerity, top_width, side_slope) """ - numerator = q_t * _n * (_q_spatial + 1) + q_eps = _q_spatial + 1e-6 + numerator = q_t * _n * (q_eps + 1) denominator = p_spatial * torch.pow(_s0, 0.5) depth = torch.clamp( torch.pow( torch.div(numerator, denominator + 1e-8), - torch.div(3.0, 5.0 + 3.0 * _q_spatial), + torch.div(3.0, 5.0 + 3.0 * q_eps), ), min=depth_lb, ) + # Derive top_width and side_slope from Leopold & Maddock power law + top_width = p_spatial * torch.pow(depth, q_eps) + top_width = _apply_data_override(top_width, data_top_width) + + side_slope = torch.clamp(top_width * q_eps / (2 * depth), min=0.5, max=50.0) + side_slope = _apply_data_override(side_slope, data_side_slope) + # For z:1 side slopes (z horizontal : 1 vertical) _bottom_width = top_width - (2 * side_slope * depth) bottom_width = torch.clamp(_bottom_width, min=_btm_width_lb) @@ -140,7 +180,7 @@ def _get_trapezoid_velocity( v = torch.div(1, _n) * torch.pow(R, (2 / 3)) * torch.pow(_s0, (1 / 2)) c_ = torch.clamp(v, min=velocity_lb, max=torch.tensor(15.0, device=v.device)) c = c_ * 5 / 3 - return c + return c, top_width, side_slope class MuskingumCunge: @@ -186,8 +226,10 @@ def __init__(self, cfg: Config, device: str | torch.device = "cpu") -> None: self.routing_dataclass: Any = None self.length: torch.Tensor | None = None self.slope: torch.Tensor | None = None - self.top_width: torch.Tensor | None = None - self.side_slope: torch.Tensor | None = None + self.top_width: torch.Tensor | None = None # Derived per-timestep in route_timestep() + self.side_slope: torch.Tensor | None = None # Derived per-timestep in route_timestep() + self._data_top_width: torch.Tensor | None = None # Observed data for override + self._data_side_slope: torch.Tensor | None = None # Observed data for override self.x_storage: torch.Tensor | None = None self.observations: Any = None self.output_indices: list[Any] | None = None @@ -261,6 +303,16 @@ def _set_network_context(self, routing_dataclass: Any, streamflow: torch.Tensor) ) self.x_storage = routing_dataclass.x.to(self.device).to(torch.float32) + # Store observed geometry for data override in _get_trapezoid_velocity + if routing_dataclass.top_width is not None and routing_dataclass.top_width.numel() > 0: + self._data_top_width = routing_dataclass.top_width.to(self.device).to(torch.float32) + else: + self._data_top_width = None + if routing_dataclass.side_slope is not None and routing_dataclass.side_slope.numel() > 0: + self._data_side_slope = routing_dataclass.side_slope.to(self.device).to(torch.float32) + else: + self._data_side_slope = None + self.q_prime = streamflow.to(self.device) if routing_dataclass.flow_scale is not None: @@ -282,23 +334,13 @@ def _denormalize_spatial_parameters(self, spatial_parameters: dict[str, torch.Te log_space="q_spatial" in log_space_params, ) - routing_dataclass = self.routing_dataclass - if routing_dataclass.top_width.numel() == 0: - self.top_width = denormalize( - value=spatial_parameters["top_width"], - bounds=self.parameter_bounds["top_width"], - log_space="top_width" in log_space_params, + # p_spatial: use learned value if in spatial_parameters, otherwise use default + if "p_spatial" in spatial_parameters and "p_spatial" in self.parameter_bounds: + self.p_spatial = denormalize( + value=spatial_parameters["p_spatial"], + bounds=self.parameter_bounds["p_spatial"], + log_space="p_spatial" in log_space_params, ) - else: - self.top_width = routing_dataclass.top_width.to(self.device).to(torch.float32) - if routing_dataclass.side_slope.numel() == 0: - self.side_slope = denormalize( - value=spatial_parameters["side_slope"], - bounds=self.parameter_bounds["side_slope"], - log_space="side_slope" in log_space_params, - ) - else: - self.side_slope = routing_dataclass.side_slope.to(self.device).to(torch.float32) def _init_discharge_state(self, carry_state: bool) -> None: """Cold-start via topological accumulation, or carry from previous batch.""" @@ -373,15 +415,15 @@ def forward(self) -> torch.Tensor: dtype=torch.float32, ) - # Vectorized initial values + # Vectorized initial values (avoid double in-place write for autograd safety) gathered = self._discharge_t[self._flat_indices] - output[:, 0] = torch.scatter_add( + initial = torch.scatter_add( input=self._scatter_input, dim=0, index=self._group_ids, src=gathered, ) - output[:, 0] = torch.clamp(output[:, 0], min=self.discharge_lb) + output[:, 0] = torch.clamp(initial, min=self.discharge_lb) # Route through time series for timestep in tqdm( @@ -478,8 +520,6 @@ def route_timestep( if ( self._discharge_t is None or self.n is None - or self.top_width is None - or self.side_slope is None or self.slope is None or self.q_spatial is None or self.length is None @@ -488,15 +528,15 @@ def route_timestep( ): raise ValueError("Required attributes not set. Call setup_inputs() first.") - # Calculate velocity using internal routing_dataclass data - velocity = _get_trapezoid_velocity( + # Calculate velocity and derive top_width/side_slope from Leopold & Maddock power law + velocity, self.top_width, self.side_slope = _get_trapezoid_velocity( q_t=self._discharge_t, _n=self.n, - top_width=self.top_width, - side_slope=self.side_slope, _s0=self.slope, p_spatial=self.p_spatial, _q_spatial=self.q_spatial, + data_top_width=self._data_top_width, + data_side_slope=self._data_side_slope, velocity_lb=self.velocity_lb, depth_lb=self.depth_lb, _btm_width_lb=self.bottom_width_lb, diff --git a/src/ddr/routing/torch_mc.py b/src/ddr/routing/torch_mc.py index ac89068..b6034fd 100644 --- a/src/ddr/routing/torch_mc.py +++ b/src/ddr/routing/torch_mc.py @@ -180,13 +180,16 @@ def forward(self, **kwargs: Any) -> dict[str, torch.Tensor]: self.network = self.routing_engine.network self.n = self.routing_engine.n self.q_spatial = self.routing_engine.q_spatial - self.top_width = self.routing_engine.top_width - self.side_slope = self.routing_engine.side_slope + self.p_spatial = self.routing_engine.p_spatial self._discharge_t = self.routing_engine._discharge_t # Perform routing output = self.routing_engine.forward() + # Read back per-timestep derived geometry from the engine + self.top_width = self.routing_engine.top_width + self.side_slope = self.routing_engine.side_slope + # Update discharge state for compatibility self._discharge_t = self.routing_engine._discharge_t @@ -195,6 +198,8 @@ def forward(self, **kwargs: Any) -> dict[str, torch.Tensor]: self.n.retain_grad() if self.q_spatial is not None: self.q_spatial.retain_grad() + if self.p_spatial is not None and self.p_spatial.requires_grad: + self.p_spatial.retain_grad() if self._discharge_t is not None: self._discharge_t.retain_grad() @@ -205,10 +210,6 @@ def forward(self, **kwargs: Any) -> dict[str, torch.Tensor]: spatial_params["n"].retain_grad() if "q_spatial" in spatial_params: spatial_params["q_spatial"].retain_grad() - if "top_width" in spatial_params: - spatial_params["top_width"].retain_grad() - if "side_slope" in spatial_params: - spatial_params["side_slope"].retain_grad() if "p_spatial" in spatial_params: spatial_params["p_spatial"].retain_grad() diff --git a/src/ddr/validation/configs.py b/src/ddr/validation/configs.py index 61c773c..301ad12 100644 --- a/src/ddr/validation/configs.py +++ b/src/ddr/validation/configs.py @@ -92,15 +92,13 @@ class Params(BaseModel): default_factory=lambda: { "n": [0.015, 0.25], # (m⁻¹/³s) "q_spatial": [0.0, 1.0], # 0 = rectangular, 1 = triangular - "top_width": [1.0, 5000.0], # Log-space (m) - "side_slope": [0.5, 50.0], # H:V ratio Log-space (-) + "p_spatial": [1.0, 200.0], # Leopold & Maddock width coefficient, Log-space (m) }, description="The parameter space bounds [min, max] to project learned physical values to", ) log_space_parameters: list[str] = Field( default_factory=lambda: [ - "top_width", - "side_slope", + "p_spatial", ], description="Parameters to denormalize in log-space for right-skewed distributions", ) diff --git a/tests/routing/test_mmc.py b/tests/routing/test_mmc.py index 2dd8e5d..4bb1d3c 100644 --- a/tests/routing/test_mmc.py +++ b/tests/routing/test_mmc.py @@ -60,14 +60,11 @@ def test_setup_inputs_basic(self) -> None: # Check spatial attributes assert mc.length is not None assert mc.slope is not None - assert mc.top_width is not None - assert mc.side_slope is not None assert mc.x_storage is not None assert_tensor_properties(mc.length, (10,)) assert_tensor_properties(mc.slope, (10,)) - assert_tensor_properties(mc.top_width, (10,)) - assert_tensor_properties(mc.side_slope, (10,)) assert_tensor_properties(mc.x_storage, (10,)) + # top_width and side_slope are now derived per-timestep in route_timestep() # Check parameter denormalization assert mc.n is not None @@ -115,16 +112,13 @@ def test_setup_inputs_device_conversion(self) -> None: # Check that all tensors are on correct device assert mc.length is not None assert mc.slope is not None - assert mc.top_width is not None - assert mc.side_slope is not None assert mc.x_storage is not None assert mc.q_prime is not None assert mc.length.device.type == "cpu" assert mc.slope.device.type == "cpu" - assert mc.top_width.device.type == "cpu" - assert mc.side_slope.device.type == "cpu" assert mc.x_storage.device.type == "cpu" assert mc.q_prime.device.type == "cpu" + # top_width and side_slope are derived per-timestep in route_timestep() class TestMuskingumCungeSparseOperations: diff --git a/tests/routing/test_utils.py b/tests/routing/test_utils.py index 9c4ba17..53a0489 100644 --- a/tests/routing/test_utils.py +++ b/tests/routing/test_utils.py @@ -42,7 +42,7 @@ def create_mock_config() -> Config: "gages": "mock.csv", }, "params": { - "parameter_ranges": {"n": [0.01, 0.1], "q_spatial": [0.1, 0.9]}, + "parameter_ranges": {"n": [0.01, 0.1], "q_spatial": [0.1, 0.9], "p_spatial": [1.0, 200.0]}, "defaults": {"p_spatial": 1.0}, "attribute_minimums": { "velocity": 0.1, diff --git a/tests/validation/test_configs.py b/tests/validation/test_configs.py index 9610452..2b235e3 100644 --- a/tests/validation/test_configs.py +++ b/tests/validation/test_configs.py @@ -136,12 +136,11 @@ def test_parameter_ranges_defaults(self) -> None: p = Params() assert "n" in p.parameter_ranges assert "q_spatial" in p.parameter_ranges - assert "top_width" in p.parameter_ranges - assert "side_slope" in p.parameter_ranges + assert "p_spatial" in p.parameter_ranges def test_log_space_parameters_default(self) -> None: p = Params() - assert p.log_space_parameters == ["top_width", "side_slope"] + assert p.log_space_parameters == ["p_spatial"] class TestSetSeed: