From c95e876ef707696f62a78030db3ef0b63350512d Mon Sep 17 00:00:00 2001 From: SCiarella Date: Fri, 27 Mar 2026 09:25:45 +0100 Subject: [PATCH] Make `Engine` instances reusable --- src/diffwofost/physical_models/engine.py | 97 +++++++++++++++--- src/diffwofost/physical_models/utils.py | 70 +------------ .../physical_models/crop/test_assimilation.py | 10 +- .../crop/test_evapotranspiration.py | 10 +- .../crop/test_leaf_dynamics.py | 10 +- .../physical_models/crop/test_partitioning.py | 10 +- tests/physical_models/crop/test_phenology.py | 7 +- .../physical_models/crop/test_respiration.py | 10 +- .../crop/test_root_dynamics.py | 10 +- .../crop/test_stem_dynamics.py | 10 +- .../crop/test_storage_organ_dynamics.py | 10 +- tests/physical_models/crop/test_wofost72.py | 5 +- .../physical_models/soil/test_waterbalance.py | 10 +- tests/physical_models/test_engine.py | 99 ++++++++++++++----- tests/physical_models/test_utils.py | 11 +++ 15 files changed, 214 insertions(+), 165 deletions(-) diff --git a/src/diffwofost/physical_models/engine.py b/src/diffwofost/physical_models/engine.py index b367012..3130838 100644 --- a/src/diffwofost/physical_models/engine.py +++ b/src/diffwofost/physical_models/engine.py @@ -1,44 +1,101 @@ +import gc from pathlib import Path import torch from pcse import signals from pcse.base import BaseEngine -from pcse.base.variablekiosk import VariableKiosk -from pcse.engine import Engine +from pcse.engine import Engine as PcseEngine from pcse.timer import Timer from pcse.traitlets import Instance from diffwofost.physical_models.config import Configuration +from diffwofost.physical_models.variablekiosk import VariableKiosk -class Engine(Engine): +class Engine(PcseEngine): mconf = Instance(Configuration) def __init__( self, - parameterprovider, - weatherdataprovider, - agromanagement, - config: str | Path | Configuration, + parameterprovider=None, + weatherdataprovider=None, + agromanagement=None, + config: str | Path | Configuration | None = None, + external_states=None, ): BaseEngine.__init__(self) + if config is None: + msg = "A model configuration must be provided when initializing the engine." + raise TypeError(msg) + # If a path is given, load the model configuration from a PCSE config file if isinstance(config, str | Path): self.mconf = Configuration.from_pcse_config_file(config) else: self.mconf = config - self.parameterprovider = parameterprovider - self._shape = _get_params_shape(self.parameterprovider) + self._default_external_states = external_states + + if any( + item is not None for item in (parameterprovider, weatherdataprovider, agromanagement) + ): + if not all( + item is not None + for item in (parameterprovider, weatherdataprovider, agromanagement) + ): + msg = ( + "parameterprovider, weatherdataprovider and agromanagement must all be " + "provided when setting up the engine." + ) + raise TypeError(msg) + self.setup( + parameterprovider, + weatherdataprovider, + agromanagement, + external_states=external_states, + ) - # Variable kiosk for registering and publishing variables - self.kiosk = VariableKiosk() + def _reset_runtime_state(self): + for component_name in ("crop", "soil"): + component = getattr(self, component_name, None) + if component is not None: + component._delete() + setattr(self, component_name, None) + + gc.collect() + + self.flag_terminate = False + self.flag_crop_finish = False + self.flag_crop_start = False + self.flag_crop_delete = False + self.flag_output = False + self.flag_summary_output = False - # Placeholder for variables to be saved during a model run self._saved_output = [] self._saved_summary_output = [] self._saved_terminal_output = {} - # register handlers for starting/finishing the crop simulation, for + def setup( + self, + parameterprovider, + weatherdataprovider, + agromanagement, + external_states=None, + ): + """Set up the engine for a new simulation run.""" + if external_states is None: + external_states = self._default_external_states + else: + self._default_external_states = external_states + + self._reset_runtime_state() + + self.parameterprovider = parameterprovider + self._shape = _get_params_shape(self.parameterprovider) + + # Variable kiosk for registering and publishing variables + self.kiosk = VariableKiosk(external_states) + + # Register handlers for starting/finishing the crop simulation, for # handling output and terminating the system self._connect_signal(self._on_CROP_START, signal=signals.crop_start) self._connect_signal(self._on_CROP_FINISH, signal=signals.crop_finish) @@ -53,6 +110,7 @@ def __init__( # Timer: starting day, final day and model output self.timer = Timer(self.kiosk, start_date, end_date, self.mconf) self.day, _ = self.timer() + self.kiosk(self.day) # Driving variables self.weatherdataprovider = weatherdataprovider @@ -67,6 +125,7 @@ def __init__( # Calculate initial rates self.calc_rates(self.day, self.drv) + return self def _on_CROP_START( self, day, crop_name=None, variety_name=None, crop_start_type=None, crop_end_type=None @@ -86,6 +145,18 @@ def _on_CROP_START( ) self.crop = self.mconf.CROP(day, self.kiosk, self.parameterprovider, shape=self._shape) + def _finish_cropsimulation(self, day): + self.flag_crop_finish = False + + self.crop.finalize(day) + self._save_summary_output() + + if self.flag_crop_delete: + self.flag_crop_delete = False + self.crop._delete() + self.crop = None + gc.collect() + def _get_params_shape(parameterprovider): shape = () diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index e22787a..d7434af 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -12,23 +12,17 @@ import math from collections import namedtuple from collections.abc import Iterable -from pathlib import Path import torch import yaml from pcse import signals from pcse.base.parameter_providers import ParameterProvider from pcse.base.weather import WeatherDataContainer from pcse.base.weather import WeatherDataProvider -from pcse.engine import BaseEngine from pcse.settings import settings -from pcse.timer import Timer from pcse.traitlets import TraitType from pcse.util import doy from diffwofost.physical_models.config import ComputeConfig -from diffwofost.physical_models.config import Configuration from diffwofost.physical_models.engine import Engine -from diffwofost.physical_models.engine import _get_params_shape -from diffwofost.physical_models.variablekiosk import VariableKiosk logging.disable(logging.CRITICAL) @@ -36,65 +30,6 @@ class EngineTestHelper(Engine): """An engine which is purely for running the YAML unit tests.""" - def __init__( - self, - parameterprovider, - weatherdataprovider, - agromanagement, - config, - external_states=None, - ): - BaseEngine.__init__(self) - - # If a path is given, load the model configuration from a PCSE config file - if isinstance(config, str | Path): - self.mconf = Configuration.from_pcse_config_file(config) - else: - self.mconf = config - - self.parameterprovider = parameterprovider - self._shape = _get_params_shape(self.parameterprovider) - - # Variable kiosk for registering and publishing variables - self.kiosk = VariableKiosk(external_states) - - # Placeholder for variables to be saved during a model run - self._saved_output = list() - self._saved_summary_output = list() - self._saved_terminal_output = dict() - - # register handlers for starting/finishing the crop simulation, for - # handling output and terminating the system - self._connect_signal(self._on_CROP_START, signal=signals.crop_start) - self._connect_signal(self._on_CROP_FINISH, signal=signals.crop_finish) - self._connect_signal(self._on_OUTPUT, signal=signals.output) - self._connect_signal(self._on_TERMINATE, signal=signals.terminate) - - # Component for agromanagement - self.agromanager = self.mconf.AGROMANAGEMENT(self.kiosk, agromanagement) - start_date = self.agromanager.start_date - end_date = self.agromanager.end_date - - # Timer: starting day, final day and model output - self.timer = Timer(self.kiosk, start_date, end_date, self.mconf) - self.day, delt = self.timer() - # Update external states in the kiosk - self.kiosk(self.day) - - # Driving variables - self.weatherdataprovider = weatherdataprovider - self.drv = self._get_driving_variables(self.day) - - # Component for simulation of soil processes - if self.mconf.SOIL is not None: - self.soil = self.mconf.SOIL(self.day, self.kiosk, parameterprovider) - - # Call AgroManagement module for management actions at initialization - self.agromanager(self.day, self.drv) - - # Calculate initial rates - self.calc_rates(self.day, self.drv) - def _run(self): """Make one time step of the simulation.""" # Update timer @@ -135,9 +70,8 @@ def __init__(self, yaml_weather, meteo_range_checks=True): # instances with arrays. settings.METEO_RANGE_CHECKS = meteo_range_checks for weather in yaml_weather: - if "SNOWDEPTH" in weather: - weather.pop("SNOWDEPTH") - wdc = WeatherDataContainer(**weather) + weather_inputs = {k: v for k, v in weather.items() if k != "SNOWDEPTH"} + wdc = WeatherDataContainer(**weather_inputs) self._store_WeatherDataContainer(wdc, wdc.DAY) diff --git a/tests/physical_models/crop/test_assimilation.py b/tests/physical_models/crop/test_assimilation.py index 1ad2c6a..c9910ad 100644 --- a/tests/physical_models/crop/test_assimilation.py +++ b/tests/physical_models/crop/test_assimilation.py @@ -1,4 +1,3 @@ -import copy from unittest.mock import patch import pytest import torch @@ -26,11 +25,11 @@ def get_test_diff_assimilation_model(): prepare_engine_input(test_data, crop_model_params) ) return DiffAssimilation( - copy.deepcopy(crop_model_params_provider), + crop_model_params_provider, weather_data_provider, agro_management_inputs, assimilation_config, - copy.deepcopy(external_states), + external_states, ) @@ -49,17 +48,16 @@ def __init__( self.agro_management_inputs = agro_management_inputs self.config = config self.external_states = external_states + self.engine = EngineTestHelper(config=self.config, external_states=self.external_states) def forward(self, params_dict): for name, value in params_dict.items(): self.crop_model_params_provider.set_override(name, value, check=False) - engine = EngineTestHelper( + engine = self.engine.setup( self.crop_model_params_provider, self.weather_data_provider, self.agro_management_inputs, - self.config, - self.external_states, ) engine.run_till_terminate() results = engine.get_output() diff --git a/tests/physical_models/crop/test_evapotranspiration.py b/tests/physical_models/crop/test_evapotranspiration.py index 7518da7..7813c19 100644 --- a/tests/physical_models/crop/test_evapotranspiration.py +++ b/tests/physical_models/crop/test_evapotranspiration.py @@ -1,4 +1,3 @@ -import copy import datetime import warnings from types import SimpleNamespace @@ -97,11 +96,11 @@ def get_test_diff_evapotranspiration_model(device: str = "cpu"): prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False, device=device) ) return DiffEvapotranspiration( - copy.deepcopy(crop_model_params_provider), + crop_model_params_provider, weather_data_provider, agro_management_inputs, evapotranspiration_config, - copy.deepcopy(external_states), + external_states, device=device, ) @@ -123,6 +122,7 @@ def __init__( self.config = config self.external_states = external_states self.device = device + self.engine = EngineTestHelper(config=self.config, external_states=self.external_states) def forward(self, params_dict: dict[str, torch.Tensor]): for name, value in params_dict.items(): @@ -130,12 +130,10 @@ def forward(self, params_dict: dict[str, torch.Tensor]): value = value.to(self.device) self.crop_model_params_provider.set_override(name, value, check=False) - engine = EngineTestHelper( + engine = self.engine.setup( self.crop_model_params_provider, self.weather_data_provider, self.agro_management_inputs, - self.config, - self.external_states, ) engine.run_till_terminate() results = engine.get_output() diff --git a/tests/physical_models/crop/test_leaf_dynamics.py b/tests/physical_models/crop/test_leaf_dynamics.py index 2319792..7fbeed7 100644 --- a/tests/physical_models/crop/test_leaf_dynamics.py +++ b/tests/physical_models/crop/test_leaf_dynamics.py @@ -1,4 +1,3 @@ -import copy import warnings from unittest.mock import patch import pytest @@ -26,11 +25,11 @@ def get_test_diff_leaf_model(): prepare_engine_input(test_data, crop_model_params) ) return DiffLeafDynamics( - copy.deepcopy(crop_model_params_provider), + crop_model_params_provider, weather_data_provider, agro_management_inputs, leaf_dynamics_config, - copy.deepcopy(external_states), + external_states, ) @@ -49,18 +48,17 @@ def __init__( self.agro_management_inputs = agro_management_inputs self.config = config self.external_states = external_states + self.engine = EngineTestHelper(config=self.config, external_states=self.external_states) def forward(self, params_dict): # pass new value of parameters to the model for name, value in params_dict.items(): self.crop_model_params_provider.set_override(name, value, check=False) - engine = EngineTestHelper( + engine = self.engine.setup( self.crop_model_params_provider, self.weather_data_provider, self.agro_management_inputs, - self.config, - self.external_states, ) engine.run_till_terminate() results = engine.get_output() diff --git a/tests/physical_models/crop/test_partitioning.py b/tests/physical_models/crop/test_partitioning.py index 7b50c13..3990c2c 100644 --- a/tests/physical_models/crop/test_partitioning.py +++ b/tests/physical_models/crop/test_partitioning.py @@ -1,4 +1,3 @@ -import copy import warnings from unittest.mock import patch import pytest @@ -28,11 +27,11 @@ def get_test_diff_partitioning(): external_states, ) = prepare_engine_input(test_data, crop_model_params) return DiffPartitioning( - copy.deepcopy(crop_model_params_provider), + crop_model_params_provider, weather_data_provider, agro_management_inputs, partitioning_config, - copy.deepcopy(external_states), + external_states, ) @@ -51,18 +50,17 @@ def __init__( self.agro_management_inputs = agro_management_inputs self.config = config self.external_states = external_states + self.engine = EngineTestHelper(config=self.config, external_states=self.external_states) def forward(self, params_dict: dict[str, torch.Tensor]): # pass new value of parameters to the model for name, value in params_dict.items(): self.crop_model_params_provider.set_override(name, value, check=False) - engine = EngineTestHelper( + engine = self.engine.setup( self.crop_model_params_provider, self.weather_data_provider, self.agro_management_inputs, - self.config, - self.external_states, ) engine.run_till_terminate() results = engine.get_output() diff --git a/tests/physical_models/crop/test_phenology.py b/tests/physical_models/crop/test_phenology.py index 8ae274d..b988547 100644 --- a/tests/physical_models/crop/test_phenology.py +++ b/tests/physical_models/crop/test_phenology.py @@ -1,4 +1,3 @@ -import copy import warnings from unittest.mock import patch import pytest @@ -86,7 +85,7 @@ def get_test_diff_phenology_model(): prepare_engine_input(test_data, crop_model_params) ) return DiffPhenologyDynamics( - copy.deepcopy(crop_model_params_provider), + crop_model_params_provider, weather_data_provider, agro_management_inputs, phenology_config, @@ -106,17 +105,17 @@ def __init__( self.weather_data_provider = weather_data_provider self.agro_management_inputs = agro_management_inputs self.config = config + self.engine = EngineTestHelper(config=self.config) def forward(self, params_dict): # pass new value of parameters to the model for name, value in params_dict.items(): self.crop_model_params_provider.set_override(name, value, check=False) - engine = EngineTestHelper( + engine = self.engine.setup( self.crop_model_params_provider, self.weather_data_provider, self.agro_management_inputs, - self.config, ) engine.run_till_terminate() results = engine.get_output() diff --git a/tests/physical_models/crop/test_respiration.py b/tests/physical_models/crop/test_respiration.py index 6ab5e61..3b3714c 100644 --- a/tests/physical_models/crop/test_respiration.py +++ b/tests/physical_models/crop/test_respiration.py @@ -1,4 +1,3 @@ -import copy import warnings from unittest.mock import patch import pytest @@ -29,11 +28,11 @@ def get_test_diff_respiration_model(): external_states, ) = prepare_engine_input(test_data, crop_model_params) return DiffRespiration( - copy.deepcopy(crop_model_params_provider), + crop_model_params_provider, weather_data_provider, agro_management_inputs, respiration_config, - copy.deepcopy(external_states), + external_states, ) @@ -52,17 +51,16 @@ def __init__( self.agro_management_inputs = agro_management_inputs self.config = config self.external_states = external_states + self.engine = EngineTestHelper(config=self.config, external_states=self.external_states) def forward(self, params_dict): for name, value in params_dict.items(): self.crop_model_params_provider.set_override(name, value, check=False) - engine = EngineTestHelper( + engine = self.engine.setup( self.crop_model_params_provider, self.weather_data_provider, self.agro_management_inputs, - self.config, - self.external_states, ) engine.run_till_terminate() results = engine.get_output() diff --git a/tests/physical_models/crop/test_root_dynamics.py b/tests/physical_models/crop/test_root_dynamics.py index 50f9987..73c74a0 100644 --- a/tests/physical_models/crop/test_root_dynamics.py +++ b/tests/physical_models/crop/test_root_dynamics.py @@ -1,4 +1,3 @@ -import copy import warnings from unittest.mock import patch import pytest @@ -27,11 +26,11 @@ def get_test_diff_root_model(): prepare_engine_input(test_data, crop_model_params) ) return DiffRootDynamics( - copy.deepcopy(crop_model_params_provider), + crop_model_params_provider, weather_data_provider, agro_management_inputs, root_dynamics_config, - copy.deepcopy(external_states), + external_states, ) @@ -50,18 +49,17 @@ def __init__( self.agro_management_inputs = agro_management_inputs self.config = config self.external_states = external_states + self.engine = EngineTestHelper(config=self.config, external_states=self.external_states) def forward(self, params_dict): # pass new value of parameters to the model for name, value in params_dict.items(): self.crop_model_params_provider.set_override(name, value, check=False) - engine = EngineTestHelper( + engine = self.engine.setup( self.crop_model_params_provider, self.weather_data_provider, self.agro_management_inputs, - self.config, - self.external_states, ) engine.run_till_terminate() results = engine.get_output() diff --git a/tests/physical_models/crop/test_stem_dynamics.py b/tests/physical_models/crop/test_stem_dynamics.py index 593defa..3d478c6 100644 --- a/tests/physical_models/crop/test_stem_dynamics.py +++ b/tests/physical_models/crop/test_stem_dynamics.py @@ -1,4 +1,3 @@ -import copy import warnings from unittest.mock import patch import pytest @@ -87,11 +86,11 @@ def get_test_diff_stem_model(device: str = "cpu"): ) = _prepare_common_stem_inputs(test_data_url, device=device) return DiffStemDynamics( - copy.deepcopy(crop_model_params_provider), + crop_model_params_provider, weather_data_provider, agro_management_inputs, stem_dynamics_config, - copy.deepcopy(external_states), + external_states, device=device, ) @@ -113,18 +112,17 @@ def __init__( self.config = config self.external_states = external_states self.device = device + self.engine = EngineTestHelper(config=self.config, external_states=self.external_states) def forward(self, params_dict): # pass new value of parameters to the model for name, value in params_dict.items(): self.crop_model_params_provider.set_override(name, value, check=False) - engine = EngineTestHelper( + engine = self.engine.setup( self.crop_model_params_provider, self.weather_data_provider, self.agro_management_inputs, - self.config, - self.external_states, ) engine.run_till_terminate() results = engine.get_output() diff --git a/tests/physical_models/crop/test_storage_organ_dynamics.py b/tests/physical_models/crop/test_storage_organ_dynamics.py index 2962f3d..f23cecf 100644 --- a/tests/physical_models/crop/test_storage_organ_dynamics.py +++ b/tests/physical_models/crop/test_storage_organ_dynamics.py @@ -1,4 +1,3 @@ -import copy import warnings from unittest.mock import patch import pytest @@ -80,11 +79,11 @@ def get_test_diff_storage_model(device: str = "cpu"): ) = _prepare_common_storage_inputs(test_data_url) return DiffStorageDynamics( - copy.deepcopy(crop_model_params_provider), + crop_model_params_provider, weather_data_provider, agro_management_inputs, storage_dynamics_config, - copy.deepcopy(external_states), + external_states, device=device, ) @@ -106,18 +105,17 @@ def __init__( self.config = config self.external_states = external_states self.device = device + self.engine = EngineTestHelper(config=self.config, external_states=self.external_states) def forward(self, params_dict): # pass new value of parameters to the model for name, value in params_dict.items(): self.crop_model_params_provider.set_override(name, value, check=False) - engine = EngineTestHelper( + engine = self.engine.setup( self.crop_model_params_provider, self.weather_data_provider, self.agro_management_inputs, - self.config, - self.external_states, ) engine.run_till_terminate() results = engine.get_output() diff --git a/tests/physical_models/crop/test_wofost72.py b/tests/physical_models/crop/test_wofost72.py index 7ba94a1..d8584f7 100644 --- a/tests/physical_models/crop/test_wofost72.py +++ b/tests/physical_models/crop/test_wofost72.py @@ -133,18 +133,17 @@ def __init__( self.agro_management_inputs = agro_management_inputs self.config = config self.external_states = external_states + self.engine = EngineTestHelper(config=self.config, external_states=self.external_states) def forward(self, params_dict): # pass new value of parameters to the model for name, value in params_dict.items(): self.crop_model_params_provider.set_override(name, value, check=False) - engine = EngineTestHelper( + engine = self.engine.setup( self.crop_model_params_provider, self.weather_data_provider, self.agro_management_inputs, - self.config, - self.external_states, ) engine.run_till_terminate() results = engine.get_output() diff --git a/tests/physical_models/soil/test_waterbalance.py b/tests/physical_models/soil/test_waterbalance.py index fafc1a6..7d1e109 100644 --- a/tests/physical_models/soil/test_waterbalance.py +++ b/tests/physical_models/soil/test_waterbalance.py @@ -1,4 +1,3 @@ -import copy import warnings from unittest.mock import patch import pytest @@ -31,11 +30,11 @@ def get_test_diff_waterbalance_model(device: str = "cpu"): prepare_engine_input(test_data, crop_model_params, device=device) ) return DiffWaterbalancePP( - copy.deepcopy(crop_model_params_provider), + crop_model_params_provider, weather_data_provider, agro_management_inputs, waterbalance_config, - copy.deepcopy(external_states), + external_states, ) @@ -54,18 +53,17 @@ def __init__( self.agro_management_inputs = agro_management_inputs self.config = config self.external_states = external_states + self.engine = EngineTestHelper(config=self.config, external_states=self.external_states) def forward(self, params_dict): # Pass new parameter values to the model for name, value in params_dict.items(): self.crop_model_params_provider.set_override(name, value, check=False) - engine = EngineTestHelper( + engine = self.engine.setup( self.crop_model_params_provider, self.weather_data_provider, self.agro_management_inputs, - self.config, - self.external_states, ) engine.run_till_terminate() diff --git a/tests/physical_models/test_engine.py b/tests/physical_models/test_engine.py index fb73ae4..a116d5a 100644 --- a/tests/physical_models/test_engine.py +++ b/tests/physical_models/test_engine.py @@ -1,4 +1,5 @@ import pytest +import torch from diffwofost.physical_models.config import Configuration from diffwofost.physical_models.crop.phenology import DVS_Phenology from diffwofost.physical_models.engine import Engine @@ -12,37 +13,89 @@ ) +def _get_engine_inputs(): + test_data_url = f"{phy_data_folder}/test_phenology_wofost72_05.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "IDSL", + "DLO", + "DLC", + "DVSI", + "DVSEND", + "DTSMTB", + "VERNSAT", + "VERNBASE", + "VERNDVS", + ] + return test_data, prepare_engine_input(test_data, crop_model_params) + + @pytest.mark.usefixtures("fast_mode") class TestEngine: def test_engine(self): - test_data_url = f"{phy_data_folder}/test_phenology_wofost72_05.yaml" - test_data = get_test_data(test_data_url) - crop_model_params = [ - "TSUMEM", - "TBASEM", - "TEFFMX", - "TSUM1", - "TSUM2", - "IDSL", - "DLO", - "DLC", - "DVSI", - "DVSEND", - "DTSMTB", - "VERNSAT", - "VERNBASE", - "VERNDVS", - ] - (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( - prepare_engine_input(test_data, crop_model_params) + _, (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( + _get_engine_inputs() + ) + engine = Engine( + parameterprovider=crop_model_params_provider, + weatherdataprovider=weather_data_provider, + agromanagement=agro_management_inputs, + config=config, + ) + start_day = engine.day + engine.run(days=5) + + assert engine.day > start_day + + def test_engine_setup_reuses_instance(self): + _, (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( + _get_engine_inputs() ) + engine = Engine(config=config) + + engine.setup(crop_model_params_provider, weather_data_provider, agro_management_inputs) + start_day = engine.day + engine.run(days=5) + assert engine.day > start_day + + updated_dvsi = crop_model_params_provider["DVSI"] + torch.tensor( + 0.2, + dtype=crop_model_params_provider["DVSI"].dtype, + device=crop_model_params_provider["DVSI"].device, + ) + crop_model_params_provider.set_override("DVSI", updated_dvsi, check=False) + + returned_engine = engine.setup( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + ) + + assert returned_engine is engine + assert engine.flag_terminate is False + assert engine.day == start_day + assert torch.equal(engine.parameterprovider["DVSI"], updated_dvsi) + + engine.run(days=1) + assert engine.day > start_day + + def test_engine_preserves_parameter_overrides_after_run(self): + _, (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( + _get_engine_inputs() + ) + original_dvsi = crop_model_params_provider["DVSI"].clone() + engine = Engine( parameterprovider=crop_model_params_provider, weatherdataprovider=weather_data_provider, agromanagement=agro_management_inputs, config=config, ) - engine.run_till_terminate() - actual_results = engine.get_output() + engine.run(days=5) - assert len(actual_results) == len(test_data["ModelResults"]) + assert torch.equal(crop_model_params_provider["DVSI"], original_dvsi) diff --git a/tests/physical_models/test_utils.py b/tests/physical_models/test_utils.py index ed0c6ba..561f431 100644 --- a/tests/physical_models/test_utils.py +++ b/tests/physical_models/test_utils.py @@ -714,6 +714,17 @@ def test_batched_gradient_at_boundaries(self): class TestGetDrvParam: """Tests for _get_drv function.""" + def test_weather_provider_does_not_mutate_input_weather(self): + test_data_url = f"{phy_data_folder}/test_phenology_wofost72_24.yaml" + test_data = get_test_data(test_data_url) + weather_inputs = test_data["WeatherVariables"] + + assert any("SNOWDEPTH" in item for item in weather_inputs) + + WeatherDataProviderTestHelper(weather_inputs) + + assert any("SNOWDEPTH" in item for item in weather_inputs) + def test_float_broadcast(self): expected_shape = (3, 2) test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_05.yaml"