141 changes: 126 additions & 15 deletions scripts/train.py
@@ -13,20 +13,32 @@

from ddr import ddr_functions, dmc, kan, streamflow
from ddr._version import __version__
from ddr.io.readers import ForcingsReader
from ddr.nn import TemporalPhiKAN
from ddr.scripts_utils import load_checkpoint, resolve_learning_rate
from ddr.validation import Config, Metrics, plot_time_series, utils, validate_config
from ddr.validation.enums import BiasLossFn
from ddr.validation.losses import huber_loss, kge_loss, mass_balance_loss

log = logging.getLogger(__name__)


-def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None:
def train(
cfg: Config,
flow: streamflow,
routing_model: dmc,
nn: kan,
phi_kan: TemporalPhiKAN | None = None,
q_prime_stats: dict[str, dict[str, float]] | None = None,
forcings_reader: ForcingsReader | None = None,
) -> None:
"""Do model training."""
data_generator = torch.Generator()
data_generator.manual_seed(cfg.seed)
dataset = cfg.geodataset.get_dataset_class(cfg=cfg)

if cfg.experiment.checkpoint:
-state = load_checkpoint(nn, cfg.experiment.checkpoint, torch.device(cfg.device))
state = load_checkpoint(nn, cfg.experiment.checkpoint, torch.device(cfg.device), phi_kan=phi_kan)
start_epoch = state["epoch"]
start_mini_batch = (
0 if state["mini_batch"] == 0 else state["mini_batch"] + 1
@@ -38,7 +50,10 @@ def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None:
start_mini_batch = 0
lr = cfg.experiment.learning_rate[start_epoch]

-optimizer = torch.optim.Adam(params=nn.parameters(), lr=lr)
params_to_optimize = list(nn.parameters())
if phi_kan is not None:
params_to_optimize += list(phi_kan.parameters())
optimizer = torch.optim.Adam(params=params_to_optimize, lr=lr)
sampler = RandomSampler(
data_source=dataset,
generator=data_generator,
@@ -65,10 +80,50 @@ def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None:
start_mini_batch = 0
routing_model.set_progress_info(epoch=epoch, mini_batch=i)

-streamflow_predictions = flow(
-routing_dataclass=routing_dataclass, device=cfg.device, dtype=torch.float32
-)
spatial_params = nn(inputs=routing_dataclass.normalized_spatial_attributes.to(cfg.device))

if phi_kan is not None:
assert q_prime_stats is not None
# Get daily Q' for phi-KAN (24x less memory than hourly)
q_prime_daily = flow(
routing_dataclass=routing_dataclass,
device=cfg.device,
dtype=torch.float32,
use_hourly=False,
)
divide_ids = routing_dataclass.divide_ids
q_mean = torch.tensor(
[q_prime_stats.get(str(did), {}).get("mean", 1e-6) for did in divide_ids],
device=cfg.device,
dtype=torch.float32,
)
q_std = torch.tensor(
[q_prime_stats.get(str(did), {}).get("std", 1e-8) for did in divide_ids],
device=cfg.device,
dtype=torch.float32,
)
forcing_tensor = None
if forcings_reader is not None:
forcing_tensor = forcings_reader(
routing_dataclass=routing_dataclass, device=cfg.device, dtype=torch.float32
)
# Bias-correct at daily resolution
month = dataset.dates.batch_month_tensor_daily.to(cfg.device)
q_prime_corrected = phi_kan(
q_prime_daily,
month=month,
forcing=forcing_tensor,
q_prime_mean=q_mean,
q_prime_std=q_std,
)
# Upsample corrected daily → hourly (each day's value repeated 24x) for MC routing
T_hourly = len(routing_dataclass.dates.batch_hourly_time_range)
streamflow_predictions = q_prime_corrected.repeat_interleave(24, dim=0)[:T_hourly]
else:
streamflow_predictions = flow(
routing_dataclass=routing_dataclass, device=cfg.device, dtype=torch.float32
)

dmc_kwargs = {
"routing_dataclass": routing_dataclass,
"spatial_parameters": spatial_params,
@@ -92,15 +147,27 @@ def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None:

filtered_predictions = daily_runoff[~np_nan_mask]

-loss = mse_loss(
-input=filtered_predictions.transpose(0, 1)[cfg.experiment.warmup :].unsqueeze(2),
-target=filtered_observations.transpose(0, 1)[cfg.experiment.warmup :].unsqueeze(2),
-)
if phi_kan is not None:
pred_gt = filtered_predictions.transpose(0, 1)[cfg.experiment.warmup :]
obs_gt = filtered_observations.transpose(0, 1)[cfg.experiment.warmup :]
mb_loss = mass_balance_loss(pred_gt, obs_gt)
if cfg.bias.loss_fn == BiasLossFn.HUBER:
routing_loss = huber_loss(pred_gt, obs_gt)
elif cfg.bias.loss_fn == BiasLossFn.KGE:
routing_loss = kge_loss(pred_gt, obs_gt)
else:
routing_loss = mse_loss(pred_gt, obs_gt)
loss = cfg.bias.lambda_mass * mb_loss + (1 - cfg.bias.lambda_mass) * routing_loss
else:
loss = huber_loss(
filtered_predictions.transpose(0, 1)[cfg.experiment.warmup :],
filtered_observations.transpose(0, 1)[cfg.experiment.warmup :],
)

log.info("Running backpropagation")

loss.backward()
-torch.nn.utils.clip_grad_norm_(nn.parameters(), max_norm=1.0)
torch.nn.utils.clip_grad_norm_(params_to_optimize, max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
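
The training objective above is a convex blend: `cfg.bias.lambda_mass` weights a mass-balance term against the configured routing loss (Huber, KGE, or MSE fallback). A minimal sketch of the blend, assuming a squared relative volume error for the mass-balance term; the repo's `mass_balance_loss` may be defined differently:

```python
import torch
import torch.nn.functional as F

def blended_loss(pred: torch.Tensor, obs: torch.Tensor, lambda_mass: float) -> torch.Tensor:
    """lambda_mass * mass-balance + (1 - lambda_mass) * routing loss."""
    # Assumed mass-balance term: squared relative volume error over the window.
    mb = ((pred.sum() - obs.sum()) / (obs.sum() + 1e-8)) ** 2
    routing = F.huber_loss(pred, obs)
    return lambda_mass * mb + (1 - lambda_mass) * routing
```
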

@@ -112,9 +179,15 @@ def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None:
_nse = metrics.nse
nse = _nse[~np.isinf(_nse) & ~np.isnan(_nse)]
rmse = metrics.rmse
-kge = metrics.kge
-utils.log_metrics(nse, rmse, kge, epoch=epoch, mini_batch=i)
-log.info(f"Loss: {loss.item()}")
kge_metric = metrics.kge
utils.log_metrics(nse, rmse, kge_metric, epoch=epoch, mini_batch=i)
if phi_kan is not None:
log.info(
f"Loss: {loss.item():.6f} (mass_balance: {mb_loss.item():.6f}, "
f"{cfg.bias.loss_fn.value}: {routing_loss.item():.6f})"
)
else:
log.info(f"Loss: {loss.item()}")

# Log parameter ranges for all learnable routing parameters
param_map = {
@@ -153,8 +226,17 @@ def train(cfg: Config, flow: streamflow, routing_model: dmc, nn: kan) -> None:
optimizer=optimizer,
name=cfg.name,
saved_model_path=cfg.params.save_path / "saved_models",
phi_kan=phi_kan,
)

# Free autograd graph from this mini-batch so the next
# iteration doesn't OOM during the phi_kan forward pass.
del dmc_output, daily_runoff, streamflow_predictions, loss
del filtered_predictions, filtered_observations
routing_model.routing_engine.q_prime = None
routing_model.routing_engine._discharge_t = None
torch.cuda.empty_cache()
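
The cleanup above works because large tensors (and any module attributes that still point at them) keep GPU memory reserved until their last Python reference is dropped. A generic sketch of the same release pattern; the tensors here are stand-ins:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
big = torch.randn(4096, 4096, device=device)
out = big * 2.0

# Dropping the last Python references frees both tensors' storage; emptying the
# CUDA caching allocator then returns the freed blocks to the driver, giving the
# next mini-batch's forward pass headroom.
del big, out
if device == "cuda":
    torch.cuda.empty_cache()
```
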


@hydra.main(
version_base="1.3",
@@ -178,9 +260,38 @@ def main(cfg: DictConfig) -> None:
seed=config.seed,
device=config.device,
)

phi_kan = None
q_prime_stats = None
if config.bias.enabled:
phi_kan = TemporalPhiKAN(
cfg=config.bias,
seed=config.seed,
device=config.device,
)

routing_model = dmc(cfg=config, device=cfg.device)
flow = streamflow(config)
-train(cfg=config, flow=flow, routing_model=routing_model, nn=nn)

forcings_reader = None
if config.bias.enabled:
from ddr.io.statistics import set_streamflow_statistics
from ddr.validation.enums import PhiInputs

q_prime_stats = set_streamflow_statistics(config, flow.ds)
if config.bias.phi_inputs == PhiInputs.FORCING:
assert config.bias.forcing_var is not None
forcings_reader = ForcingsReader(config, forcing_var_names=[config.bias.forcing_var])

train(
cfg=config,
flow=flow,
routing_model=routing_model,
nn=nn,
phi_kan=phi_kan,
q_prime_stats=q_prime_stats,
forcings_reader=forcings_reader,
)

except KeyboardInterrupt:
log.info("Keyboard interrupt received")
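One pattern worth noting from this file: `main` builds `q_prime_stats` once via `set_streamflow_statistics`, and `train` turns it into per-divide tensors, falling back to tiny constants (1e-6, 1e-8) for divides without statistics so the later normalization never divides by zero. A sketch of the implied dict layout and the defensive lookup; the divide IDs are illustrative:

```python
import torch

# Implied layout of q_prime_stats: {divide_id: {"mean": float, "std": float}}
q_prime_stats = {
    "cat-101": {"mean": 3.2, "std": 1.4},
    "cat-102": {"mean": 0.7, "std": 0.3},
}
divide_ids = ["cat-101", "cat-102", "cat-999"]  # last one has no stats

q_mean = torch.tensor(
    [q_prime_stats.get(str(d), {}).get("mean", 1e-6) for d in divide_ids]
)
q_std = torch.tensor(
    [q_prime_stats.get(str(d), {}).get("std", 1e-8) for d in divide_ids]
)
# q_std stays strictly positive, so (q - q_mean) / q_std is always defined.
```
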
104 changes: 98 additions & 6 deletions scripts/train_and_test.py
@@ -25,6 +25,8 @@

from ddr import dmc, kan, streamflow
from ddr._version import __version__
from ddr.io.readers import ForcingsReader
from ddr.nn import TemporalPhiKAN
from ddr.scripts_utils import compute_daily_runoff, load_checkpoint
from ddr.validation import Config, Metrics, utils, validate_config
from scripts.train import train
@@ -37,16 +39,21 @@ def _test(
flow: streamflow,
routing_model: dmc,
nn: kan,
phi_kan: TemporalPhiKAN | None = None,
q_prime_stats: dict[str, dict[str, float]] | None = None,
forcings_reader: ForcingsReader | None = None,
) -> None:
"""Do model evaluation and get performance metrics."""
dataset = cfg.geodataset.get_dataset_class(cfg=cfg)

if cfg.experiment.checkpoint:
-load_checkpoint(nn, cfg.experiment.checkpoint, torch.device(cfg.device))
load_checkpoint(nn, cfg.experiment.checkpoint, torch.device(cfg.device), phi_kan=phi_kan)
else:
log.warning("Creating new spatial model for evaluation.")

nn = nn.eval()
if phi_kan is not None:
phi_kan = phi_kan.eval()
sampler = SequentialSampler(
data_source=dataset,
)
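
Unlike training, evaluation iterates the dataset in order: `SequentialSampler` keeps the mini-batch to gauge mapping deterministic across runs. A small sketch contrasting the two sampler choices on a stand-in dataset:

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(8).float())

train_loader = DataLoader(
    dataset, batch_size=4,
    sampler=RandomSampler(dataset, generator=torch.Generator().manual_seed(0)),
)
eval_loader = DataLoader(dataset, batch_size=4, sampler=SequentialSampler(dataset))

# Sequential order is reproducible without seeding; random order needs the generator.
print([b[0].tolist() for b in eval_loader])  # [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]
```
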
@@ -74,10 +81,50 @@ def _test(
for i, routing_dataclass in enumerate(dataloader, start=0):
routing_model.set_progress_info(epoch=0, mini_batch=i)

-streamflow_predictions = flow(
-routing_dataclass=routing_dataclass, device=cfg.device, dtype=torch.float32
-)
spatial_params = nn(inputs=routing_dataclass.normalized_spatial_attributes.to(cfg.device))

if phi_kan is not None:
assert q_prime_stats is not None
# Get daily Q' for phi-KAN (24x less memory than hourly)
q_prime_daily = flow(
routing_dataclass=routing_dataclass,
device=cfg.device,
dtype=torch.float32,
use_hourly=False,
)
divide_ids = routing_dataclass.divide_ids
q_mean = torch.tensor(
[q_prime_stats.get(str(did), {}).get("mean", 1e-6) for did in divide_ids],
device=cfg.device,
dtype=torch.float32,
)
q_std = torch.tensor(
[q_prime_stats.get(str(did), {}).get("std", 1e-8) for did in divide_ids],
device=cfg.device,
dtype=torch.float32,
)
forcing_tensor = None
if forcings_reader is not None:
forcing_tensor = forcings_reader(
routing_dataclass=routing_dataclass, device=cfg.device, dtype=torch.float32
)
# Bias-correct at daily resolution
month = dataset.dates.batch_month_tensor_daily.to(cfg.device)
q_prime_corrected = phi_kan(
q_prime_daily,
month=month,
forcing=forcing_tensor,
q_prime_mean=q_mean,
q_prime_std=q_std,
)
# Upsample corrected daily → hourly (each day's value repeated 24x) for MC routing
T_hourly = len(routing_dataclass.dates.batch_hourly_time_range)
streamflow_predictions = q_prime_corrected.repeat_interleave(24, dim=0)[:T_hourly]
else:
streamflow_predictions = flow(
routing_dataclass=routing_dataclass, device=cfg.device, dtype=torch.float32
)

dmc_kwargs = {
"routing_dataclass": routing_dataclass,
"spatial_parameters": spatial_params,
@@ -159,10 +206,38 @@ def main(cfg: DictConfig) -> None:
seed=config.seed,
device=config.device,
)

phi_kan = None
q_prime_stats = None
if config.bias.enabled:
phi_kan = TemporalPhiKAN(
cfg=config.bias,
seed=config.seed,
device=config.device,
)

routing_model = dmc(cfg=config, device=cfg.device)
flow = streamflow(config)

-train(cfg=config, flow=flow, routing_model=routing_model, nn=nn_model)
forcings_reader = None
if config.bias.enabled:
from ddr.io.statistics import set_streamflow_statistics
from ddr.validation.enums import PhiInputs

q_prime_stats = set_streamflow_statistics(config, flow.ds)
if config.bias.phi_inputs == PhiInputs.FORCING:
assert config.bias.forcing_var is not None
forcings_reader = ForcingsReader(config, forcing_var_names=[config.bias.forcing_var])

train(
cfg=config,
flow=flow,
routing_model=routing_model,
nn=nn_model,
phi_kan=phi_kan,
q_prime_stats=q_prime_stats,
forcings_reader=forcings_reader,
)

train_elapsed = time.perf_counter() - start_time
log.info(f"Training complete in {train_elapsed / 60:.2f} minutes")
@@ -197,10 +272,27 @@ def main(cfg: DictConfig) -> None:
seed=test_config.seed,
device=test_config.device,
)

phi_kan = None
if test_config.bias.enabled:
phi_kan = TemporalPhiKAN(
cfg=test_config.bias,
seed=test_config.seed,
device=test_config.device,
)

routing_model = dmc(cfg=test_config, device=test_config.device)
flow = streamflow(test_config)

-_test(cfg=test_config, flow=flow, routing_model=routing_model, nn=nn_model)
_test(
cfg=test_config,
flow=flow,
routing_model=routing_model,
nn=nn_model,
phi_kan=phi_kan,
q_prime_stats=q_prime_stats,
forcings_reader=forcings_reader,
)

except KeyboardInterrupt:
log.info("Keyboard interrupt received")