Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 84 additions & 13 deletions src/diffwofost/physical_models/engine.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,101 @@
import gc
from pathlib import Path
import torch
from pcse import signals
from pcse.base import BaseEngine
from pcse.base.variablekiosk import VariableKiosk
from pcse.engine import Engine
from pcse.engine import Engine as PcseEngine
from pcse.timer import Timer
from pcse.traitlets import Instance
from diffwofost.physical_models.config import Configuration
from diffwofost.physical_models.variablekiosk import VariableKiosk


class Engine(Engine):
class Engine(PcseEngine):
mconf = Instance(Configuration)

def __init__(
self,
parameterprovider,
weatherdataprovider,
agromanagement,
config: str | Path | Configuration,
parameterprovider=None,
weatherdataprovider=None,
agromanagement=None,
config: str | Path | Configuration | None = None,
external_states=None,
):
BaseEngine.__init__(self)

if config is None:
msg = "A model configuration must be provided when initializing the engine."
raise TypeError(msg)

# If a path is given, load the model configuration from a PCSE config file
if isinstance(config, str | Path):
self.mconf = Configuration.from_pcse_config_file(config)
else:
self.mconf = config

self.parameterprovider = parameterprovider
self._shape = _get_params_shape(self.parameterprovider)
self._default_external_states = external_states

if any(
item is not None for item in (parameterprovider, weatherdataprovider, agromanagement)
):
if not all(
item is not None
for item in (parameterprovider, weatherdataprovider, agromanagement)
):
msg = (
"parameterprovider, weatherdataprovider and agromanagement must all be "
"provided when setting up the engine."
)
raise TypeError(msg)
self.setup(
parameterprovider,
weatherdataprovider,
agromanagement,
external_states=external_states,
)

# Variable kiosk for registering and publishing variables
self.kiosk = VariableKiosk()
def _reset_runtime_state(self):
for component_name in ("crop", "soil"):
component = getattr(self, component_name, None)
if component is not None:
component._delete()
setattr(self, component_name, None)

gc.collect()

self.flag_terminate = False
self.flag_crop_finish = False
self.flag_crop_start = False
self.flag_crop_delete = False
self.flag_output = False
self.flag_summary_output = False

# Placeholder for variables to be saved during a model run
self._saved_output = []
self._saved_summary_output = []
self._saved_terminal_output = {}

# register handlers for starting/finishing the crop simulation, for
def setup(
self,
parameterprovider,
weatherdataprovider,
agromanagement,
external_states=None,
):
"""Set up the engine for a new simulation run."""
if external_states is None:
external_states = self._default_external_states
else:
self._default_external_states = external_states

self._reset_runtime_state()

self.parameterprovider = parameterprovider
self._shape = _get_params_shape(self.parameterprovider)

# Variable kiosk for registering and publishing variables
self.kiosk = VariableKiosk(external_states)

# Register handlers for starting/finishing the crop simulation, for
# handling output and terminating the system
self._connect_signal(self._on_CROP_START, signal=signals.crop_start)
self._connect_signal(self._on_CROP_FINISH, signal=signals.crop_finish)
Expand All @@ -53,6 +110,7 @@ def __init__(
# Timer: starting day, final day and model output
self.timer = Timer(self.kiosk, start_date, end_date, self.mconf)
self.day, _ = self.timer()
self.kiosk(self.day)

# Driving variables
self.weatherdataprovider = weatherdataprovider
Expand All @@ -67,6 +125,7 @@ def __init__(

# Calculate initial rates
self.calc_rates(self.day, self.drv)
return self

def _on_CROP_START(
self, day, crop_name=None, variety_name=None, crop_start_type=None, crop_end_type=None
Expand All @@ -86,6 +145,18 @@ def _on_CROP_START(
)
self.crop = self.mconf.CROP(day, self.kiosk, self.parameterprovider, shape=self._shape)

def _finish_cropsimulation(self, day):
self.flag_crop_finish = False

self.crop.finalize(day)
self._save_summary_output()

if self.flag_crop_delete:
self.flag_crop_delete = False
self.crop._delete()
self.crop = None
gc.collect()


def _get_params_shape(parameterprovider):
shape = ()
Expand Down
70 changes: 2 additions & 68 deletions src/diffwofost/physical_models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,89 +12,24 @@
import math
from collections import namedtuple
from collections.abc import Iterable
from pathlib import Path
import torch
import yaml
from pcse import signals
from pcse.base.parameter_providers import ParameterProvider
from pcse.base.weather import WeatherDataContainer
from pcse.base.weather import WeatherDataProvider
from pcse.engine import BaseEngine
from pcse.settings import settings
from pcse.timer import Timer
from pcse.traitlets import TraitType
from pcse.util import doy
from diffwofost.physical_models.config import ComputeConfig
from diffwofost.physical_models.config import Configuration
from diffwofost.physical_models.engine import Engine
from diffwofost.physical_models.engine import _get_params_shape
from diffwofost.physical_models.variablekiosk import VariableKiosk

logging.disable(logging.CRITICAL)


class EngineTestHelper(Engine):
"""An engine which is purely for running the YAML unit tests."""

def __init__(
self,
parameterprovider,
weatherdataprovider,
agromanagement,
config,
external_states=None,
):
BaseEngine.__init__(self)

# If a path is given, load the model configuration from a PCSE config file
if isinstance(config, str | Path):
self.mconf = Configuration.from_pcse_config_file(config)
else:
self.mconf = config

self.parameterprovider = parameterprovider
self._shape = _get_params_shape(self.parameterprovider)

# Variable kiosk for registering and publishing variables
self.kiosk = VariableKiosk(external_states)

# Placeholder for variables to be saved during a model run
self._saved_output = list()
self._saved_summary_output = list()
self._saved_terminal_output = dict()

# register handlers for starting/finishing the crop simulation, for
# handling output and terminating the system
self._connect_signal(self._on_CROP_START, signal=signals.crop_start)
self._connect_signal(self._on_CROP_FINISH, signal=signals.crop_finish)
self._connect_signal(self._on_OUTPUT, signal=signals.output)
self._connect_signal(self._on_TERMINATE, signal=signals.terminate)

# Component for agromanagement
self.agromanager = self.mconf.AGROMANAGEMENT(self.kiosk, agromanagement)
start_date = self.agromanager.start_date
end_date = self.agromanager.end_date

# Timer: starting day, final day and model output
self.timer = Timer(self.kiosk, start_date, end_date, self.mconf)
self.day, delt = self.timer()
# Update external states in the kiosk
self.kiosk(self.day)

# Driving variables
self.weatherdataprovider = weatherdataprovider
self.drv = self._get_driving_variables(self.day)

# Component for simulation of soil processes
if self.mconf.SOIL is not None:
self.soil = self.mconf.SOIL(self.day, self.kiosk, parameterprovider)

# Call AgroManagement module for management actions at initialization
self.agromanager(self.day, self.drv)

# Calculate initial rates
self.calc_rates(self.day, self.drv)

def _run(self):
"""Make one time step of the simulation."""
# Update timer
Expand Down Expand Up @@ -135,9 +70,8 @@ def __init__(self, yaml_weather, meteo_range_checks=True):
# instances with arrays.
settings.METEO_RANGE_CHECKS = meteo_range_checks
for weather in yaml_weather:
if "SNOWDEPTH" in weather:
weather.pop("SNOWDEPTH")
wdc = WeatherDataContainer(**weather)
weather_inputs = {k: v for k, v in weather.items() if k != "SNOWDEPTH"}
wdc = WeatherDataContainer(**weather_inputs)
self._store_WeatherDataContainer(wdc, wdc.DAY)


Expand Down
10 changes: 4 additions & 6 deletions tests/physical_models/crop/test_assimilation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
from unittest.mock import patch
import pytest
import torch
Expand Down Expand Up @@ -26,11 +25,11 @@ def get_test_diff_assimilation_model():
prepare_engine_input(test_data, crop_model_params)
)
return DiffAssimilation(
copy.deepcopy(crop_model_params_provider),
crop_model_params_provider,
weather_data_provider,
agro_management_inputs,
assimilation_config,
copy.deepcopy(external_states),
external_states,
)


Expand All @@ -49,17 +48,16 @@ def __init__(
self.agro_management_inputs = agro_management_inputs
self.config = config
self.external_states = external_states
self.engine = EngineTestHelper(config=self.config, external_states=self.external_states)

def forward(self, params_dict):
for name, value in params_dict.items():
self.crop_model_params_provider.set_override(name, value, check=False)

engine = EngineTestHelper(
engine = self.engine.setup(
self.crop_model_params_provider,
self.weather_data_provider,
self.agro_management_inputs,
self.config,
self.external_states,
)
engine.run_till_terminate()
results = engine.get_output()
Expand Down
10 changes: 4 additions & 6 deletions tests/physical_models/crop/test_evapotranspiration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import datetime
import warnings
from types import SimpleNamespace
Expand Down Expand Up @@ -97,11 +96,11 @@ def get_test_diff_evapotranspiration_model(device: str = "cpu"):
prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False, device=device)
)
return DiffEvapotranspiration(
copy.deepcopy(crop_model_params_provider),
crop_model_params_provider,
weather_data_provider,
agro_management_inputs,
evapotranspiration_config,
copy.deepcopy(external_states),
external_states,
device=device,
)

Expand All @@ -123,19 +122,18 @@ def __init__(
self.config = config
self.external_states = external_states
self.device = device
self.engine = EngineTestHelper(config=self.config, external_states=self.external_states)

def forward(self, params_dict: dict[str, torch.Tensor]):
for name, value in params_dict.items():
if isinstance(value, torch.Tensor) and value.device.type != self.device:
value = value.to(self.device)
self.crop_model_params_provider.set_override(name, value, check=False)

engine = EngineTestHelper(
engine = self.engine.setup(
self.crop_model_params_provider,
self.weather_data_provider,
self.agro_management_inputs,
self.config,
self.external_states,
)
engine.run_till_terminate()
results = engine.get_output()
Expand Down
10 changes: 4 additions & 6 deletions tests/physical_models/crop/test_leaf_dynamics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import warnings
from unittest.mock import patch
import pytest
Expand Down Expand Up @@ -26,11 +25,11 @@ def get_test_diff_leaf_model():
prepare_engine_input(test_data, crop_model_params)
)
return DiffLeafDynamics(
copy.deepcopy(crop_model_params_provider),
crop_model_params_provider,
weather_data_provider,
agro_management_inputs,
leaf_dynamics_config,
copy.deepcopy(external_states),
external_states,
)


Expand All @@ -49,18 +48,17 @@ def __init__(
self.agro_management_inputs = agro_management_inputs
self.config = config
self.external_states = external_states
self.engine = EngineTestHelper(config=self.config, external_states=self.external_states)

def forward(self, params_dict):
# pass new value of parameters to the model
for name, value in params_dict.items():
self.crop_model_params_provider.set_override(name, value, check=False)

engine = EngineTestHelper(
engine = self.engine.setup(
self.crop_model_params_provider,
self.weather_data_provider,
self.agro_management_inputs,
self.config,
self.external_states,
)
engine.run_till_terminate()
results = engine.get_output()
Expand Down
Loading
Loading