diff --git a/scripts/train.py b/scripts/train.py index fcbe9d0..99a53ef 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -8,7 +8,6 @@ import torch from hydra.core.hydra_config import HydraConfig from omegaconf import DictConfig -from torch.nn.functional import mse_loss from torch.utils.data import DataLoader, RandomSampler from ddr import ddr_functions, dmc, kan, streamflow @@ -92,7 +91,7 @@ def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None: filtered_predictions = daily_runoff[~np_nan_mask] - loss = mse_loss( + loss = torch.nn.functional.l1_loss( input=filtered_predictions.transpose(0, 1)[cfg.experiment.warmup :].unsqueeze(2), target=filtered_observations.transpose(0, 1)[cfg.experiment.warmup :].unsqueeze(2), ) @@ -120,8 +119,7 @@ def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None: param_map = { "n": routing_model.n, "q_spatial": routing_model.q_spatial, - "top_width": routing_model.top_width, - "side_slope": routing_model.side_slope, + "p_spatial": routing_model.p_spatial, } for param_name in cfg.kan.learnable_parameters: param_tensor = param_map.get(param_name) @@ -155,6 +153,13 @@ def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None: saved_model_path=cfg.params.save_path / "saved_models", ) + # Free autograd graph and routing engine tensors to prevent OOM + del dmc_output, daily_runoff, streamflow_predictions, loss + del filtered_predictions, filtered_observations + routing_model.routing_engine.q_prime = None + routing_model.routing_engine._discharge_t = None + torch.cuda.empty_cache() + @hydra.main( version_base="1.3", diff --git a/src/ddr/routing/mmc.py b/src/ddr/routing/mmc.py index 7802839..b51fd66 100644 --- a/src/ddr/routing/mmc.py +++ b/src/ddr/routing/mmc.py @@ -71,36 +71,68 @@ def _log_base_q(x: torch.Tensor, q: float) -> torch.Tensor: return torch.log(x) / torch.log(torch.tensor(q, dtype=x.dtype)) +def _apply_data_override(derived: torch.Tensor, data: torch.Tensor | None) -> torch.Tensor: + """Override derived values with observed data where available. + + Three cases: + 1. data is None or empty -> return derived (MERIT without observed geometry) + 2. data has no NaN -> return data (Lynker full coverage) + 3. data has NaN -> blend: data where valid, derived where NaN (partial coverage) + + Parameters + ---------- + derived : torch.Tensor + Values derived from the power law + data : torch.Tensor or None + Observed data, possibly containing NaN + + Returns + ------- + torch.Tensor + Blended result + """ + if data is None or data.numel() == 0: + return derived + nan_mask = torch.isnan(data) + if not nan_mask.any(): + return data + return torch.where(~nan_mask, data, derived) + + def _get_trapezoid_velocity( q_t: torch.Tensor, _n: torch.Tensor, - top_width: torch.Tensor, - side_slope: torch.Tensor, _s0: torch.Tensor, p_spatial: torch.Tensor, _q_spatial: torch.Tensor, + data_top_width: torch.Tensor | None, + data_side_slope: torch.Tensor | None, velocity_lb: torch.Tensor, depth_lb: torch.Tensor, _btm_width_lb: torch.Tensor, -) -> torch.Tensor: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Calculate flow velocity using Manning's equation for trapezoidal channels. + Derives top_width and side_slope per-timestep from the Leopold & Maddock + power law (top_width = p * depth^q), then overrides with observed data + where available. + Parameters ---------- q_t : torch.Tensor Discharge at time t _n : torch.Tensor Manning's roughness coefficient - top_width : torch.Tensor - Top width of channel - side_slope : torch.Tensor - Side slope of channel (z:1, z horizontal : 1 vertical) _s0 : torch.Tensor Channel slope p_spatial : torch.Tensor - Spatial parameter p + Leopold & Maddock width coefficient _q_spatial : torch.Tensor - Spatial parameter q + Leopold & Maddock width-depth exponent + data_top_width : torch.Tensor or None + Observed top width data for override (Lynker/SWOT), or None/empty + data_side_slope : torch.Tensor or None + Observed side slope data for override (Lynker/SWOT), or None/empty velocity_lb : torch.Tensor Lower bound for velocity depth_lb : torch.Tensor @@ -110,19 +142,27 @@ def _get_trapezoid_velocity( Returns ------- - torch.Tensor - Flow velocity + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + (celerity, top_width, side_slope) """ - numerator = q_t * _n * (_q_spatial + 1) + q_eps = _q_spatial + 1e-6 + numerator = q_t * _n * (q_eps + 1) denominator = p_spatial * torch.pow(_s0, 0.5) depth = torch.clamp( torch.pow( torch.div(numerator, denominator + 1e-8), - torch.div(3.0, 5.0 + 3.0 * _q_spatial), + torch.div(3.0, 5.0 + 3.0 * q_eps), ), min=depth_lb, ) + # Derive top_width and side_slope from Leopold & Maddock power law + top_width = p_spatial * torch.pow(depth, q_eps) + top_width = _apply_data_override(top_width, data_top_width) + + side_slope = torch.clamp(top_width * q_eps / (2 * depth), min=0.5, max=50.0) + side_slope = _apply_data_override(side_slope, data_side_slope) + # For z:1 side slopes (z horizontal : 1 vertical) _bottom_width = top_width - (2 * side_slope * depth) bottom_width = torch.clamp(_bottom_width, min=_btm_width_lb) @@ -140,7 +180,7 @@ def _get_trapezoid_velocity( v = torch.div(1, _n) * torch.pow(R, (2 / 3)) * torch.pow(_s0, (1 / 2)) c_ = torch.clamp(v, min=velocity_lb, max=torch.tensor(15.0, device=v.device)) c = c_ * 5 / 3 - return c + return c, top_width, side_slope class MuskingumCunge: @@ -186,8 +226,10 @@ def __init__(self, cfg: Config, device: str | torch.device = "cpu") -> None: self.routing_dataclass: Any = None self.length: torch.Tensor | None = None self.slope: torch.Tensor | None = None - self.top_width: torch.Tensor | None = None - self.side_slope: torch.Tensor | None = None + self.top_width: torch.Tensor | None = None # Derived per-timestep in route_timestep() + self.side_slope: torch.Tensor | None = None # Derived per-timestep in route_timestep() + self._data_top_width: torch.Tensor | None = None # Observed data for override + self._data_side_slope: torch.Tensor | None = None # Observed data for override self.x_storage: torch.Tensor | None = None self.observations: Any = None self.output_indices: list[Any] | None = None @@ -261,6 +303,16 @@ def _set_network_context(self, routing_dataclass: Any, streamflow: torch.Tensor) ) self.x_storage = routing_dataclass.x.to(self.device).to(torch.float32) + # Store observed geometry for data override in _get_trapezoid_velocity + if routing_dataclass.top_width is not None and routing_dataclass.top_width.numel() > 0: + self._data_top_width = routing_dataclass.top_width.to(self.device).to(torch.float32) + else: + self._data_top_width = None + if routing_dataclass.side_slope is not None and routing_dataclass.side_slope.numel() > 0: + self._data_side_slope = routing_dataclass.side_slope.to(self.device).to(torch.float32) + else: + self._data_side_slope = None + self.q_prime = streamflow.to(self.device) if routing_dataclass.flow_scale is not None: @@ -282,23 +334,13 @@ def _denormalize_spatial_parameters(self, spatial_parameters: dict[str, torch.Te log_space="q_spatial" in log_space_params, ) - routing_dataclass = self.routing_dataclass - if routing_dataclass.top_width.numel() == 0: - self.top_width = denormalize( - value=spatial_parameters["top_width"], - bounds=self.parameter_bounds["top_width"], - log_space="top_width" in log_space_params, + # p_spatial: use learned value if in spatial_parameters, otherwise use default + if "p_spatial" in spatial_parameters and "p_spatial" in self.parameter_bounds: + self.p_spatial = denormalize( + value=spatial_parameters["p_spatial"], + bounds=self.parameter_bounds["p_spatial"], + log_space="p_spatial" in log_space_params, ) - else: - self.top_width = routing_dataclass.top_width.to(self.device).to(torch.float32) - if routing_dataclass.side_slope.numel() == 0: - self.side_slope = denormalize( - value=spatial_parameters["side_slope"], - bounds=self.parameter_bounds["side_slope"], - log_space="side_slope" in log_space_params, - ) - else: - self.side_slope = routing_dataclass.side_slope.to(self.device).to(torch.float32) def _init_discharge_state(self, carry_state: bool) -> None: """Cold-start via topological accumulation, or carry from previous batch.""" @@ -373,15 +415,15 @@ def forward(self) -> torch.Tensor: dtype=torch.float32, ) - # Vectorized initial values + # Vectorized initial values (avoid double in-place write for autograd safety) gathered = self._discharge_t[self._flat_indices] - output[:, 0] = torch.scatter_add( + initial = torch.scatter_add( input=self._scatter_input, dim=0, index=self._group_ids, src=gathered, ) - output[:, 0] = torch.clamp(output[:, 0], min=self.discharge_lb) + output[:, 0] = torch.clamp(initial, min=self.discharge_lb) # Route through time series for timestep in tqdm( @@ -478,8 +520,6 @@ def route_timestep( if ( self._discharge_t is None or self.n is None - or self.top_width is None - or self.side_slope is None or self.slope is None or self.q_spatial is None or self.length is None @@ -488,15 +528,15 @@ def route_timestep( ): raise ValueError("Required attributes not set. Call setup_inputs() first.") - # Calculate velocity using internal routing_dataclass data - velocity = _get_trapezoid_velocity( + # Calculate velocity and derive top_width/side_slope from Leopold & Maddock power law + velocity, self.top_width, self.side_slope = _get_trapezoid_velocity( q_t=self._discharge_t, _n=self.n, - top_width=self.top_width, - side_slope=self.side_slope, _s0=self.slope, p_spatial=self.p_spatial, _q_spatial=self.q_spatial, + data_top_width=self._data_top_width, + data_side_slope=self._data_side_slope, velocity_lb=self.velocity_lb, depth_lb=self.depth_lb, _btm_width_lb=self.bottom_width_lb, diff --git a/src/ddr/routing/torch_mc.py b/src/ddr/routing/torch_mc.py index ac89068..b6034fd 100644 --- a/src/ddr/routing/torch_mc.py +++ b/src/ddr/routing/torch_mc.py @@ -180,13 +180,16 @@ def forward(self, **kwargs: Any) -> dict[str, torch.Tensor]: self.network = self.routing_engine.network self.n = self.routing_engine.n self.q_spatial = self.routing_engine.q_spatial - self.top_width = self.routing_engine.top_width - self.side_slope = self.routing_engine.side_slope + self.p_spatial = self.routing_engine.p_spatial self._discharge_t = self.routing_engine._discharge_t # Perform routing output = self.routing_engine.forward() + # Read back per-timestep derived geometry from the engine + self.top_width = self.routing_engine.top_width + self.side_slope = self.routing_engine.side_slope + # Update discharge state for compatibility self._discharge_t = self.routing_engine._discharge_t @@ -195,6 +198,8 @@ def forward(self, **kwargs: Any) -> dict[str, torch.Tensor]: self.n.retain_grad() if self.q_spatial is not None: self.q_spatial.retain_grad() + if self.p_spatial is not None and self.p_spatial.requires_grad: + self.p_spatial.retain_grad() if self._discharge_t is not None: self._discharge_t.retain_grad() @@ -205,10 +210,6 @@ def forward(self, **kwargs: Any) -> dict[str, torch.Tensor]: spatial_params["n"].retain_grad() if "q_spatial" in spatial_params: spatial_params["q_spatial"].retain_grad() - if "top_width" in spatial_params: - spatial_params["top_width"].retain_grad() - if "side_slope" in spatial_params: - spatial_params["side_slope"].retain_grad() if "p_spatial" in spatial_params: spatial_params["p_spatial"].retain_grad() diff --git a/src/ddr/validation/configs.py b/src/ddr/validation/configs.py index 61c773c..301ad12 100644 --- a/src/ddr/validation/configs.py +++ b/src/ddr/validation/configs.py @@ -92,15 +92,13 @@ class Params(BaseModel): default_factory=lambda: { "n": [0.015, 0.25], # (m⁻¹/³s) "q_spatial": [0.0, 1.0], # 0 = rectangular, 1 = triangular - "top_width": [1.0, 5000.0], # Log-space (m) - "side_slope": [0.5, 50.0], # H:V ratio Log-space (-) + "p_spatial": [1.0, 200.0], # Leopold & Maddock width coefficient, Log-space (m) }, description="The parameter space bounds [min, max] to project learned physical values to", ) log_space_parameters: list[str] = Field( default_factory=lambda: [ - "top_width", - "side_slope", + "p_spatial", ], description="Parameters to denormalize in log-space for right-skewed distributions", ) diff --git a/tests/routing/test_mmc.py b/tests/routing/test_mmc.py index 2dd8e5d..4bb1d3c 100644 --- a/tests/routing/test_mmc.py +++ b/tests/routing/test_mmc.py @@ -60,14 +60,11 @@ def test_setup_inputs_basic(self) -> None: # Check spatial attributes assert mc.length is not None assert mc.slope is not None - assert mc.top_width is not None - assert mc.side_slope is not None assert mc.x_storage is not None assert_tensor_properties(mc.length, (10,)) assert_tensor_properties(mc.slope, (10,)) - assert_tensor_properties(mc.top_width, (10,)) - assert_tensor_properties(mc.side_slope, (10,)) assert_tensor_properties(mc.x_storage, (10,)) + # top_width and side_slope are now derived per-timestep in route_timestep() # Check parameter denormalization assert mc.n is not None @@ -115,16 +112,13 @@ def test_setup_inputs_device_conversion(self) -> None: # Check that all tensors are on correct device assert mc.length is not None assert mc.slope is not None - assert mc.top_width is not None - assert mc.side_slope is not None assert mc.x_storage is not None assert mc.q_prime is not None assert mc.length.device.type == "cpu" assert mc.slope.device.type == "cpu" - assert mc.top_width.device.type == "cpu" - assert mc.side_slope.device.type == "cpu" assert mc.x_storage.device.type == "cpu" assert mc.q_prime.device.type == "cpu" + # top_width and side_slope are derived per-timestep in route_timestep() class TestMuskingumCungeSparseOperations: diff --git a/tests/routing/test_utils.py b/tests/routing/test_utils.py index 9c4ba17..53a0489 100644 --- a/tests/routing/test_utils.py +++ b/tests/routing/test_utils.py @@ -42,7 +42,7 @@ def create_mock_config() -> Config: "gages": "mock.csv", }, "params": { - "parameter_ranges": {"n": [0.01, 0.1], "q_spatial": [0.1, 0.9]}, + "parameter_ranges": {"n": [0.01, 0.1], "q_spatial": [0.1, 0.9], "p_spatial": [1.0, 200.0]}, "defaults": {"p_spatial": 1.0}, "attribute_minimums": { "velocity": 0.1, diff --git a/tests/validation/test_configs.py b/tests/validation/test_configs.py index 9610452..2b235e3 100644 --- a/tests/validation/test_configs.py +++ b/tests/validation/test_configs.py @@ -136,12 +136,11 @@ def test_parameter_ranges_defaults(self) -> None: p = Params() assert "n" in p.parameter_ranges assert "q_spatial" in p.parameter_ranges - assert "top_width" in p.parameter_ranges - assert "side_slope" in p.parameter_ranges + assert "p_spatial" in p.parameter_ranges def test_log_space_parameters_default(self) -> None: p = Params() - assert p.log_space_parameters == ["top_width", "side_slope"] + assert p.log_space_parameters == ["p_spatial"] class TestSetSeed: