diff --git a/src/diffwofost/physical_models/base/__init__.py b/src/diffwofost/physical_models/base/__init__.py index 615e40c..10ed5be 100644 --- a/src/diffwofost/physical_models/base/__init__.py +++ b/src/diffwofost/physical_models/base/__init__.py @@ -1,5 +1,5 @@ -from .states_rates import TensorParamTemplate -from .states_rates import TensorRatesTemplate -from .states_rates import TensorStatesTemplate +from diffwofost.physical_models.base.states_rates import TensorParamTemplate +from diffwofost.physical_models.base.states_rates import TensorRatesTemplate +from diffwofost.physical_models.base.states_rates import TensorStatesTemplate __all__ = ["TensorParamTemplate", "TensorRatesTemplate", "TensorStatesTemplate"] diff --git a/src/diffwofost/physical_models/base/states_rates.py b/src/diffwofost/physical_models/base/states_rates.py index 1ef6af5..fe65577 100644 --- a/src/diffwofost/physical_models/base/states_rates.py +++ b/src/diffwofost/physical_models/base/states_rates.py @@ -2,8 +2,8 @@ from pcse.base import RatesTemplate from pcse.base import StatesTemplate from pcse.traitlets import HasTraits -from ..traitlets import Tensor -from ..utils import AfgenTrait +from diffwofost.physical_models.traitlets import Tensor +from diffwofost.physical_models.utils import AfgenTrait class TensorContainer(HasTraits): diff --git a/src/diffwofost/physical_models/crop/wofost72.py b/src/diffwofost/physical_models/crop/wofost72.py index 8d15fbd..70b3dc3 100644 --- a/src/diffwofost/physical_models/crop/wofost72.py +++ b/src/diffwofost/physical_models/crop/wofost72.py @@ -12,16 +12,22 @@ from diffwofost.physical_models.base import TensorRatesTemplate from diffwofost.physical_models.base import TensorStatesTemplate from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.crop.assimilation import WOFOST72_Assimilation as Assimilation +from diffwofost.physical_models.crop.evapotranspiration import ( + EvapotranspirationWrapper as Evapotranspiration, +) +from diffwofost.physical_models.crop.leaf_dynamics import WOFOST_Leaf_Dynamics as Leaf_Dynamics +from diffwofost.physical_models.crop.partitioning import DVS_Partitioning as Partitioning +from diffwofost.physical_models.crop.phenology import DVS_Phenology as Phenology +from diffwofost.physical_models.crop.respiration import ( + WOFOST_Maintenance_Respiration as MaintenanceRespiration, +) +from diffwofost.physical_models.crop.root_dynamics import WOFOST_Root_Dynamics as Root_Dynamics +from diffwofost.physical_models.crop.stem_dynamics import WOFOST_Stem_Dynamics as Stem_Dynamics +from diffwofost.physical_models.crop.storage_organ_dynamics import ( + WOFOST_Storage_Organ_Dynamics as Storage_Organ_Dynamics, +) from diffwofost.physical_models.traitlets import Tensor -from .assimilation import WOFOST72_Assimilation as Assimilation -from .evapotranspiration import EvapotranspirationWrapper as Evapotranspiration -from .leaf_dynamics import WOFOST_Leaf_Dynamics as Leaf_Dynamics -from .partitioning import DVS_Partitioning as Partitioning -from .phenology import DVS_Phenology as Phenology -from .respiration import WOFOST_Maintenance_Respiration as MaintenanceRespiration -from .root_dynamics import WOFOST_Root_Dynamics as Root_Dynamics -from .stem_dynamics import WOFOST_Stem_Dynamics as Stem_Dynamics -from .storage_organ_dynamics import WOFOST_Storage_Organ_Dynamics as Storage_Organ_Dynamics class Wofost72(SimulationObject): diff --git a/src/diffwofost/physical_models/engine.py b/src/diffwofost/physical_models/engine.py index d570f0e..b367012 100644 --- a/src/diffwofost/physical_models/engine.py +++ b/src/diffwofost/physical_models/engine.py @@ -6,7 +6,7 @@ from pcse.engine import Engine from pcse.timer import Timer from pcse.traitlets import Instance -from .config import Configuration +from diffwofost.physical_models.config import Configuration class Engine(Engine): diff --git a/src/diffwofost/physical_models/traitlets.py b/src/diffwofost/physical_models/traitlets.py index efda289..987c3b8 100644 --- a/src/diffwofost/physical_models/traitlets.py +++ b/src/diffwofost/physical_models/traitlets.py @@ -1,7 +1,7 @@ import torch from traitlets_pcse import TraitType from traitlets_pcse import Undefined -from .config import ComputeConfig +from diffwofost.physical_models.config import ComputeConfig class Tensor(TraitType): diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index 480b22e..e22787a 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -1,8 +1,6 @@ """This file contains code that is required to run the YAML unit tests. It contains: - - VariableKioskTestHelper: A subclass of the VariableKiosk that can use externally - forced states/rates - EngineTestHelper: engine specifically for running the YAML tests. - WeatherDataProviderTestHelper: a weatherdata provides that takes the weather inputs from the YAML file. @@ -19,7 +17,6 @@ import yaml from pcse import signals from pcse.base.parameter_providers import ParameterProvider -from pcse.base.variablekiosk import VariableKiosk from pcse.base.weather import WeatherDataContainer from pcse.base.weather import WeatherDataProvider from pcse.engine import BaseEngine @@ -27,69 +24,15 @@ from pcse.timer import Timer from pcse.traitlets import TraitType from pcse.util import doy -from .config import ComputeConfig -from .config import Configuration -from .engine import Engine -from .engine import _get_params_shape +from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.config import Configuration +from diffwofost.physical_models.engine import Engine +from diffwofost.physical_models.engine import _get_params_shape +from diffwofost.physical_models.variablekiosk import VariableKiosk logging.disable(logging.CRITICAL) -class VariableKioskTestHelper(VariableKiosk): - """Variable Kiosk for testing purposes which allows to use external states.""" - - external_state_list = None - - def __init__(self, external_state_list=None): - super().__init__() - self.current_externals = {} - if external_state_list: - self.external_state_list = external_state_list - - def __call__(self, day): - """Sets the external state/rate variables for the current day. - - Returns True if the list of external state/rate variables is exhausted, - otherwise False. - """ - if self.external_state_list: - current_externals = self.external_state_list.pop(0) - forcing_day = current_externals.pop("DAY") - msg = "Failure updating VariableKiosk with external states: days are not matching!" - assert forcing_day == day, msg - self.current_externals.clear() - self.current_externals.update(current_externals) - if len(self.external_state_list) == 0: - return True - - return False - - def is_external_state(self, item): - """Returns True if the item is an external state.""" - return item in self.current_externals - - def __getattr__(self, item): - """Allow use of attribute notation. - - eg "kiosk.LAI" on published rates or states. - """ - if item in self.current_externals: - return self.current_externals[item] - else: - return dict.__getitem__(self, item) - - def __getitem__(self, item): - """Override __getitem__ to first look in external states.""" - if item in self.current_externals: - return self.current_externals[item] - else: - return dict.__getitem__(self, item) - - def __contains__(self, key): - """Override __contains__ to first look in external states.""" - return key in self.current_externals or dict.__contains__(self, key) - - class EngineTestHelper(Engine): """An engine which is purely for running the YAML unit tests.""" @@ -113,7 +56,7 @@ def __init__( self._shape = _get_params_shape(self.parameterprovider) # Variable kiosk for registering and publishing variables - self.kiosk = VariableKioskTestHelper(external_states) + self.kiosk = VariableKiosk(external_states) # Placeholder for variables to be saved during a model run self._saved_output = list() @@ -157,10 +100,10 @@ def _run(self): # Update timer self.day, delt = self.timer() - # When the list of external states is exhausted the VariableKioskTestHelper will - # return True signalling the end of the test - stop_test = self.kiosk(self.day) - if stop_test: + self.kiosk(self.day) + # When the list of external states is exhausted, send crop_finish to + # end the test run + if self.kiosk.external_states_exhausted: self._send_signal( signal=signals.crop_finish, day=self.day, finish_type="maturity", crop_delete=False ) diff --git a/src/diffwofost/physical_models/variablekiosk.py b/src/diffwofost/physical_models/variablekiosk.py new file mode 100644 index 0000000..8363d84 --- /dev/null +++ b/src/diffwofost/physical_models/variablekiosk.py @@ -0,0 +1,73 @@ +from pcse.base.variablekiosk import VariableKiosk as _PcseVariableKiosk + + +class VariableKiosk(_PcseVariableKiosk): + """Extends pcse's VariableKiosk with support for external dependencies. + + The external_state_list parameter accepts a list of per-day dicts, each + containing a ``"DAY"`` key and the variable values to inject for that day. + Calling the kiosk with a day (``kiosk(day)``) advances to the next entry + and makes those variables available via normal attribute/item access. + + All original VariableKiosk behaviour (registering, publishing, flushing) + is inherited unchanged from pcse. + """ + + def __init__(self, external_state_list=None): + super().__init__() + self.current_externals = {} + self._last_called_day = None + # Build a day-keyed dict for O(1) lookup + self._external_states = {} + if external_state_list is not None: + self._external_states = { + item["DAY"]: {k: v for k, v in item.items() if k != "DAY"} + for item in list(external_state_list) + } + + def __call__(self, day): + """Set the external state/rate variables for the current day. + + If the day has an entry in the external state list, its values are + injected into ``current_externals``. If the day has no entry, + ``current_externals`` is cleared so the module falls back to normally + registered kiosk variables. Does nothing when no list was provided. + Always returns False; use ``external_states_exhausted`` to check whether + the last entry has been passed. + """ + self._last_called_day = day + if self._external_states: + self.current_externals.clear() + if day in self._external_states: + self.current_externals.update(self._external_states[day]) + return False + + @property + def external_states_exhausted(self): + """True when the simulation has advanced past the last external state entry.""" + if not self._external_states or self._last_called_day is None: + return False + return self._last_called_day >= max(self._external_states.keys()) + + def is_external_state(self, item): + """Returns True if the item is an external state.""" + return item in self.current_externals + + def __contains__(self, item): + """Checks external states first, then the published kiosk variables.""" + current_externals = self.__dict__.get("current_externals", {}) + return item in current_externals or dict.__contains__(self, item) + + def __getitem__(self, item): + """Look in external states before falling back to published variables.""" + current_externals = self.__dict__.get("current_externals", {}) + if item in current_externals: + return current_externals[item] + return dict.__getitem__(self, item) + + def __getattr__(self, item): + """Allow attribute notation (e.g. ``kiosk.LAI``), checking externals first.""" + current_externals = self.__dict__.get("current_externals", {}) + if item in current_externals: + return current_externals[item] + return dict.__getitem__(self, item) diff --git a/tests/physical_models/base/test_variablekiosk.py b/tests/physical_models/base/test_variablekiosk.py new file mode 100644 index 0000000..9ac4785 --- /dev/null +++ b/tests/physical_models/base/test_variablekiosk.py @@ -0,0 +1,212 @@ +import datetime +import pytest +from pcse.base.variablekiosk import VariableKiosk as PcseVariableKiosk +from diffwofost.physical_models.variablekiosk import VariableKiosk + +DAY1 = datetime.date(2000, 1, 1) +DAY2 = datetime.date(2000, 1, 2) +DAY3 = datetime.date(2000, 1, 3) + + +def _make_external_states(): + return [ + {"DAY": DAY1, "LAI": 0.5, "DVS": 0.1}, + {"DAY": DAY2, "LAI": 1.0, "DVS": 0.2}, + {"DAY": DAY3, "LAI": 1.5, "DVS": 0.3}, + ] + + +@pytest.mark.usefixtures("fast_mode") +class TestVariableKioskIsSubclassOfPcse: + def test_is_instance_of_pcse_variablekiosk(self): + """Must satisfy the pcse Instance(VariableKiosk) trait used in BaseEngine.""" + kiosk = VariableKiosk() + assert isinstance(kiosk, PcseVariableKiosk) + + +@pytest.mark.usefixtures("fast_mode") +class TestVariableKioskInit: + def test_init_without_external_states(self): + kiosk = VariableKiosk() + assert kiosk.current_externals == {} + + def test_init_with_external_states_stores_copy(self): + ext = _make_external_states() + kiosk = VariableKiosk(ext) + assert len(kiosk._external_states) == 3 + + def test_init_makes_independent_copy_of_list(self): + ext = _make_external_states() + kiosk = VariableKiosk(ext) + ext.clear() + assert len(kiosk._external_states) == 3 + + +@pytest.mark.usefixtures("fast_mode") +class TestVariableKioskCall: + def test_call_populates_current_externals(self): + kiosk = VariableKiosk(_make_external_states()) + kiosk(DAY1) + assert kiosk.current_externals == {"LAI": 0.5, "DVS": 0.1} + + def test_call_advances_on_each_day(self): + kiosk = VariableKiosk(_make_external_states()) + kiosk(DAY1) + kiosk(DAY2) + assert kiosk.current_externals == {"LAI": 1.0, "DVS": 0.2} + + def test_call_always_returns_false(self): + kiosk = VariableKiosk(_make_external_states()) + assert kiosk(DAY1) is False + assert kiosk(DAY2) is False + assert kiosk(DAY3) is False + + def test_call_returns_false_without_external_list(self): + kiosk = VariableKiosk() + assert kiosk(DAY1) is False + + def test_call_raises_on_day_mismatch(self): + """A day not present in the list simply clears externals — no error.""" + kiosk = VariableKiosk(_make_external_states()) + kiosk(DAY1) + missing_day = datetime.date(1999, 6, 15) + kiosk(missing_day) # should not raise + assert kiosk.current_externals == {} + + def test_external_states_exhausted_false_before_last_entry(self): + kiosk = VariableKiosk(_make_external_states()) + kiosk(DAY1) + assert kiosk.external_states_exhausted is False + + def test_external_states_exhausted_true_after_last_entry(self): + kiosk = VariableKiosk(_make_external_states()) + kiosk(DAY1) + kiosk(DAY2) + kiosk(DAY3) + assert kiosk.external_states_exhausted is True + + def test_external_states_exhausted_false_without_external_list(self): + kiosk = VariableKiosk() + assert kiosk.external_states_exhausted is False + + def test_original_list_is_not_modified(self): + ext = _make_external_states() + kiosk = VariableKiosk(ext) + kiosk(DAY1) + kiosk(DAY2) + kiosk(DAY3) + # The stored dict must be intact even after full consumption + assert len(kiosk._external_states) == 3 + + def test_sparse_external_states_injects_on_matching_days(self): + """Only days present in the list inject externals; gaps clear current_externals.""" + ext = [ + {"DAY": DAY1, "LAI": 0.5}, + {"DAY": DAY3, "LAI": 1.5}, # DAY2 intentionally absent + ] + kiosk = VariableKiosk(ext) + kiosk(DAY1) + assert kiosk.current_externals == {"LAI": 0.5} + kiosk(DAY2) # no entry → externals cleared + assert kiosk.current_externals == {} + kiosk(DAY3) + assert kiosk.current_externals == {"LAI": 1.5} + + def test_external_states_exhausted_false_on_gap_day_before_last_entry(self): + """A gap day before the last entry does not signal finished.""" + ext = [ + {"DAY": DAY1, "LAI": 0.5}, + {"DAY": DAY3, "LAI": 1.5}, + ] + kiosk = VariableKiosk(ext) + kiosk(DAY2) # gap day, max is DAY3 + assert kiosk.external_states_exhausted is False + + +@pytest.mark.usefixtures("fast_mode") +class TestVariableKioskExternalAccess: + def test_getitem_returns_external_variable(self): + kiosk = VariableKiosk(_make_external_states()) + kiosk(DAY1) + assert kiosk["LAI"] == 0.5 + + def test_getattr_returns_external_variable(self): + kiosk = VariableKiosk(_make_external_states()) + kiosk(DAY1) + assert kiosk.LAI == 0.5 + + def test_contains_finds_external_variable(self): + kiosk = VariableKiosk(_make_external_states()) + kiosk(DAY1) + assert "LAI" in kiosk + + def test_is_external_state_returns_true_for_external(self): + kiosk = VariableKiosk(_make_external_states()) + kiosk(DAY1) + assert kiosk.is_external_state("LAI") is True + + def test_is_external_state_returns_false_for_non_external(self): + kiosk = VariableKiosk(_make_external_states()) + kiosk(DAY1) + assert kiosk.is_external_state("NONEXISTENT") is False + + def test_is_external_state_returns_false_before_first_call(self): + kiosk = VariableKiosk(_make_external_states()) + assert kiosk.is_external_state("LAI") is False + + def test_external_shadows_published_variable(self): + """An external state with the same name as a published variable takes precedence.""" + kiosk = VariableKiosk(_make_external_states()) + oid = 42 + kiosk.register_variable(oid, "LAI", type="S", publish=True) + kiosk.set_variable(oid, "LAI", 99.0) + assert kiosk["LAI"] == 99.0 # before external update: published value + kiosk(DAY1) + assert kiosk["LAI"] == 0.5 # after: external overrides + assert kiosk.LAI == 0.5 + + def test_externals_cleared_between_days(self): + """current_externals only holds variables from the most recent day.""" + ext = [ + {"DAY": DAY1, "LAI": 0.5}, + {"DAY": DAY2, "DVS": 0.2}, # LAI absent on day 2 + ] + kiosk = VariableKiosk(ext) + kiosk(DAY1) + assert "LAI" in kiosk.current_externals + kiosk(DAY2) + assert "LAI" not in kiosk.current_externals + assert "DVS" in kiosk.current_externals + + +@pytest.mark.usefixtures("fast_mode") +class TestVariableKioskInheritedBehaviour: + def test_register_and_set_published_variable(self): + kiosk = VariableKiosk() + oid = 1 + kiosk.register_variable(oid, "DVS", type="S", publish=True) + kiosk.set_variable(oid, "DVS", 1.0) + assert kiosk["DVS"] == 1.0 + + def test_flush_rates_clears_published_rates(self): + kiosk = VariableKiosk() + oid = 1 + kiosk.register_variable(oid, "DVR", type="R", publish=True) + kiosk.set_variable(oid, "DVR", 0.05) + kiosk.flush_rates() + assert "DVR" not in kiosk + + def test_flush_states_clears_published_states(self): + kiosk = VariableKiosk() + oid = 1 + kiosk.register_variable(oid, "DVS", type="S", publish=True) + kiosk.set_variable(oid, "DVS", 0.5) + kiosk.flush_states() + assert "DVS" not in kiosk + + def test_variable_exists(self): + kiosk = VariableKiosk() + oid = 1 + kiosk.register_variable(oid, "DVS", type="S") + assert kiosk.variable_exists("DVS") is True + assert kiosk.variable_exists("LAI") is False