Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import torch
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from torch.nn.functional import mse_loss
from torch.utils.data import DataLoader, RandomSampler

from ddr import ddr_functions, dmc, kan, streamflow
Expand Down Expand Up @@ -92,7 +91,7 @@ def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None:

filtered_predictions = daily_runoff[~np_nan_mask]

loss = mse_loss(
loss = torch.nn.functional.l1_loss(
input=filtered_predictions.transpose(0, 1)[cfg.experiment.warmup :].unsqueeze(2),
target=filtered_observations.transpose(0, 1)[cfg.experiment.warmup :].unsqueeze(2),
)
Expand Down Expand Up @@ -120,8 +119,7 @@ def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None:
param_map = {
"n": routing_model.n,
"q_spatial": routing_model.q_spatial,
"top_width": routing_model.top_width,
"side_slope": routing_model.side_slope,
"p_spatial": routing_model.p_spatial,
}
for param_name in cfg.kan.learnable_parameters:
param_tensor = param_map.get(param_name)
Expand Down Expand Up @@ -155,6 +153,13 @@ def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None:
saved_model_path=cfg.params.save_path / "saved_models",
)

# Free autograd graph and routing engine tensors to prevent OOM
del dmc_output, daily_runoff, streamflow_predictions, loss
del filtered_predictions, filtered_observations
routing_model.routing_engine.q_prime = None
routing_model.routing_engine._discharge_t = None
torch.cuda.empty_cache()


@hydra.main(
version_base="1.3",
Expand Down
122 changes: 81 additions & 41 deletions src/ddr/routing/mmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,36 +71,68 @@ def _log_base_q(x: torch.Tensor, q: float) -> torch.Tensor:
return torch.log(x) / torch.log(torch.tensor(q, dtype=x.dtype))


def _apply_data_override(derived: torch.Tensor, data: torch.Tensor | None) -> torch.Tensor:
"""Override derived values with observed data where available.

Three cases:
1. data is None or empty -> return derived (MERIT without observed geometry)
2. data has no NaN -> return data (Lynker full coverage)
3. data has NaN -> blend: data where valid, derived where NaN (partial coverage)

Parameters
----------
derived : torch.Tensor
Values derived from the power law
data : torch.Tensor or None
Observed data, possibly containing NaN

Returns
-------
torch.Tensor
Blended result
"""
if data is None or data.numel() == 0:
return derived
nan_mask = torch.isnan(data)
if not nan_mask.any():
return data
return torch.where(~nan_mask, data, derived)


def _get_trapezoid_velocity(
q_t: torch.Tensor,
_n: torch.Tensor,
top_width: torch.Tensor,
side_slope: torch.Tensor,
_s0: torch.Tensor,
p_spatial: torch.Tensor,
_q_spatial: torch.Tensor,
data_top_width: torch.Tensor | None,
data_side_slope: torch.Tensor | None,
velocity_lb: torch.Tensor,
depth_lb: torch.Tensor,
_btm_width_lb: torch.Tensor,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Calculate flow velocity using Manning's equation for trapezoidal channels.

Derives top_width and side_slope per-timestep from the Leopold & Maddock
power law (top_width = p * depth^q), then overrides with observed data
where available.

Parameters
----------
q_t : torch.Tensor
Discharge at time t
_n : torch.Tensor
Manning's roughness coefficient
top_width : torch.Tensor
Top width of channel
side_slope : torch.Tensor
Side slope of channel (z:1, z horizontal : 1 vertical)
_s0 : torch.Tensor
Channel slope
p_spatial : torch.Tensor
Spatial parameter p
Leopold & Maddock width coefficient
_q_spatial : torch.Tensor
Spatial parameter q
Leopold & Maddock width-depth exponent
data_top_width : torch.Tensor or None
Observed top width data for override (Lynker/SWOT), or None/empty
data_side_slope : torch.Tensor or None
Observed side slope data for override (Lynker/SWOT), or None/empty
velocity_lb : torch.Tensor
Lower bound for velocity
depth_lb : torch.Tensor
Expand All @@ -110,19 +142,27 @@ def _get_trapezoid_velocity(

Returns
-------
torch.Tensor
Flow velocity
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
(celerity, top_width, side_slope)
"""
numerator = q_t * _n * (_q_spatial + 1)
q_eps = _q_spatial + 1e-6
numerator = q_t * _n * (q_eps + 1)
denominator = p_spatial * torch.pow(_s0, 0.5)
depth = torch.clamp(
torch.pow(
torch.div(numerator, denominator + 1e-8),
torch.div(3.0, 5.0 + 3.0 * _q_spatial),
torch.div(3.0, 5.0 + 3.0 * q_eps),
),
min=depth_lb,
)

# Derive top_width and side_slope from Leopold & Maddock power law
top_width = p_spatial * torch.pow(depth, q_eps)
top_width = _apply_data_override(top_width, data_top_width)

side_slope = torch.clamp(top_width * q_eps / (2 * depth), min=0.5, max=50.0)
side_slope = _apply_data_override(side_slope, data_side_slope)

# For z:1 side slopes (z horizontal : 1 vertical)
_bottom_width = top_width - (2 * side_slope * depth)
bottom_width = torch.clamp(_bottom_width, min=_btm_width_lb)
Expand All @@ -140,7 +180,7 @@ def _get_trapezoid_velocity(
v = torch.div(1, _n) * torch.pow(R, (2 / 3)) * torch.pow(_s0, (1 / 2))
c_ = torch.clamp(v, min=velocity_lb, max=torch.tensor(15.0, device=v.device))
c = c_ * 5 / 3
return c
return c, top_width, side_slope


class MuskingumCunge:
Expand Down Expand Up @@ -186,8 +226,10 @@ def __init__(self, cfg: Config, device: str | torch.device = "cpu") -> None:
self.routing_dataclass: Any = None
self.length: torch.Tensor | None = None
self.slope: torch.Tensor | None = None
self.top_width: torch.Tensor | None = None
self.side_slope: torch.Tensor | None = None
self.top_width: torch.Tensor | None = None # Derived per-timestep in route_timestep()
self.side_slope: torch.Tensor | None = None # Derived per-timestep in route_timestep()
self._data_top_width: torch.Tensor | None = None # Observed data for override
self._data_side_slope: torch.Tensor | None = None # Observed data for override
self.x_storage: torch.Tensor | None = None
self.observations: Any = None
self.output_indices: list[Any] | None = None
Expand Down Expand Up @@ -261,6 +303,16 @@ def _set_network_context(self, routing_dataclass: Any, streamflow: torch.Tensor)
)
self.x_storage = routing_dataclass.x.to(self.device).to(torch.float32)

# Store observed geometry for data override in _get_trapezoid_velocity
if routing_dataclass.top_width is not None and routing_dataclass.top_width.numel() > 0:
self._data_top_width = routing_dataclass.top_width.to(self.device).to(torch.float32)
else:
self._data_top_width = None
if routing_dataclass.side_slope is not None and routing_dataclass.side_slope.numel() > 0:
self._data_side_slope = routing_dataclass.side_slope.to(self.device).to(torch.float32)
else:
self._data_side_slope = None

self.q_prime = streamflow.to(self.device)

if routing_dataclass.flow_scale is not None:
Expand All @@ -282,23 +334,13 @@ def _denormalize_spatial_parameters(self, spatial_parameters: dict[str, torch.Te
log_space="q_spatial" in log_space_params,
)

routing_dataclass = self.routing_dataclass
if routing_dataclass.top_width.numel() == 0:
self.top_width = denormalize(
value=spatial_parameters["top_width"],
bounds=self.parameter_bounds["top_width"],
log_space="top_width" in log_space_params,
# p_spatial: use learned value if in spatial_parameters, otherwise use default
if "p_spatial" in spatial_parameters and "p_spatial" in self.parameter_bounds:
self.p_spatial = denormalize(
value=spatial_parameters["p_spatial"],
bounds=self.parameter_bounds["p_spatial"],
log_space="p_spatial" in log_space_params,
)
else:
self.top_width = routing_dataclass.top_width.to(self.device).to(torch.float32)
if routing_dataclass.side_slope.numel() == 0:
self.side_slope = denormalize(
value=spatial_parameters["side_slope"],
bounds=self.parameter_bounds["side_slope"],
log_space="side_slope" in log_space_params,
)
else:
self.side_slope = routing_dataclass.side_slope.to(self.device).to(torch.float32)

def _init_discharge_state(self, carry_state: bool) -> None:
"""Cold-start via topological accumulation, or carry from previous batch."""
Expand Down Expand Up @@ -373,15 +415,15 @@ def forward(self) -> torch.Tensor:
dtype=torch.float32,
)

# Vectorized initial values
# Vectorized initial values (avoid double in-place write for autograd safety)
gathered = self._discharge_t[self._flat_indices]
output[:, 0] = torch.scatter_add(
initial = torch.scatter_add(
input=self._scatter_input,
dim=0,
index=self._group_ids,
src=gathered,
)
output[:, 0] = torch.clamp(output[:, 0], min=self.discharge_lb)
output[:, 0] = torch.clamp(initial, min=self.discharge_lb)

# Route through time series
for timestep in tqdm(
Expand Down Expand Up @@ -478,8 +520,6 @@ def route_timestep(
if (
self._discharge_t is None
or self.n is None
or self.top_width is None
or self.side_slope is None
or self.slope is None
or self.q_spatial is None
or self.length is None
Expand All @@ -488,15 +528,15 @@ def route_timestep(
):
raise ValueError("Required attributes not set. Call setup_inputs() first.")

# Calculate velocity using internal routing_dataclass data
velocity = _get_trapezoid_velocity(
# Calculate velocity and derive top_width/side_slope from Leopold & Maddock power law
velocity, self.top_width, self.side_slope = _get_trapezoid_velocity(
q_t=self._discharge_t,
_n=self.n,
top_width=self.top_width,
side_slope=self.side_slope,
_s0=self.slope,
p_spatial=self.p_spatial,
_q_spatial=self.q_spatial,
data_top_width=self._data_top_width,
data_side_slope=self._data_side_slope,
velocity_lb=self.velocity_lb,
depth_lb=self.depth_lb,
_btm_width_lb=self.bottom_width_lb,
Expand Down
13 changes: 7 additions & 6 deletions src/ddr/routing/torch_mc.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,16 @@ def forward(self, **kwargs: Any) -> dict[str, torch.Tensor]:
self.network = self.routing_engine.network
self.n = self.routing_engine.n
self.q_spatial = self.routing_engine.q_spatial
self.top_width = self.routing_engine.top_width
self.side_slope = self.routing_engine.side_slope
self.p_spatial = self.routing_engine.p_spatial
self._discharge_t = self.routing_engine._discharge_t

# Perform routing
output = self.routing_engine.forward()

# Read back per-timestep derived geometry from the engine
self.top_width = self.routing_engine.top_width
self.side_slope = self.routing_engine.side_slope

# Update discharge state for compatibility
self._discharge_t = self.routing_engine._discharge_t

Expand All @@ -195,6 +198,8 @@ def forward(self, **kwargs: Any) -> dict[str, torch.Tensor]:
self.n.retain_grad()
if self.q_spatial is not None:
self.q_spatial.retain_grad()
if self.p_spatial is not None and self.p_spatial.requires_grad:
self.p_spatial.retain_grad()
if self._discharge_t is not None:
self._discharge_t.retain_grad()

Expand All @@ -205,10 +210,6 @@ def forward(self, **kwargs: Any) -> dict[str, torch.Tensor]:
spatial_params["n"].retain_grad()
if "q_spatial" in spatial_params:
spatial_params["q_spatial"].retain_grad()
if "top_width" in spatial_params:
spatial_params["top_width"].retain_grad()
if "side_slope" in spatial_params:
spatial_params["side_slope"].retain_grad()
if "p_spatial" in spatial_params:
spatial_params["p_spatial"].retain_grad()

Expand Down
6 changes: 2 additions & 4 deletions src/ddr/validation/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,13 @@ class Params(BaseModel):
default_factory=lambda: {
"n": [0.015, 0.25], # (m⁻¹/³s)
"q_spatial": [0.0, 1.0], # 0 = rectangular, 1 = triangular
"top_width": [1.0, 5000.0], # Log-space (m)
"side_slope": [0.5, 50.0], # H:V ratio Log-space (-)
"p_spatial": [1.0, 200.0], # Leopold & Maddock width coefficient, Log-space (m)
},
description="The parameter space bounds [min, max] to project learned physical values to",
)
log_space_parameters: list[str] = Field(
default_factory=lambda: [
"top_width",
"side_slope",
"p_spatial",
],
description="Parameters to denormalize in log-space for right-skewed distributions",
)
Expand Down
10 changes: 2 additions & 8 deletions tests/routing/test_mmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,11 @@ def test_setup_inputs_basic(self) -> None:
# Check spatial attributes
assert mc.length is not None
assert mc.slope is not None
assert mc.top_width is not None
assert mc.side_slope is not None
assert mc.x_storage is not None
assert_tensor_properties(mc.length, (10,))
assert_tensor_properties(mc.slope, (10,))
assert_tensor_properties(mc.top_width, (10,))
assert_tensor_properties(mc.side_slope, (10,))
assert_tensor_properties(mc.x_storage, (10,))
# top_width and side_slope are now derived per-timestep in route_timestep()

# Check parameter denormalization
assert mc.n is not None
Expand Down Expand Up @@ -115,16 +112,13 @@ def test_setup_inputs_device_conversion(self) -> None:
# Check that all tensors are on correct device
assert mc.length is not None
assert mc.slope is not None
assert mc.top_width is not None
assert mc.side_slope is not None
assert mc.x_storage is not None
assert mc.q_prime is not None
assert mc.length.device.type == "cpu"
assert mc.slope.device.type == "cpu"
assert mc.top_width.device.type == "cpu"
assert mc.side_slope.device.type == "cpu"
assert mc.x_storage.device.type == "cpu"
assert mc.q_prime.device.type == "cpu"
# top_width and side_slope are derived per-timestep in route_timestep()


class TestMuskingumCungeSparseOperations:
Expand Down
2 changes: 1 addition & 1 deletion tests/routing/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def create_mock_config() -> Config:
"gages": "mock.csv",
},
"params": {
"parameter_ranges": {"n": [0.01, 0.1], "q_spatial": [0.1, 0.9]},
"parameter_ranges": {"n": [0.01, 0.1], "q_spatial": [0.1, 0.9], "p_spatial": [1.0, 200.0]},
"defaults": {"p_spatial": 1.0},
"attribute_minimums": {
"velocity": 0.1,
Expand Down
5 changes: 2 additions & 3 deletions tests/validation/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,11 @@ def test_parameter_ranges_defaults(self) -> None:
p = Params()
assert "n" in p.parameter_ranges
assert "q_spatial" in p.parameter_ranges
assert "top_width" in p.parameter_ranges
assert "side_slope" in p.parameter_ranges
assert "p_spatial" in p.parameter_ranges

def test_log_space_parameters_default(self) -> None:
p = Params()
assert p.log_space_parameters == ["top_width", "side_slope"]
assert p.log_space_parameters == ["p_spatial"]


class TestSetSeed:
Expand Down
Loading