Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from typing import Union
from pydantic import BaseModel, model_validator, Field, AfterValidator
from GANDLF.configuration.default_config import DefaultParameters
from GANDLF.configuration.differential_privacy_config import DifferentialPrivacyConfig
from GANDLF.configuration.nested_training_config import NestedTraining
from GANDLF.configuration.optimizer_config import OptimizerConfig
from GANDLF.configuration.patch_sampler_config import PatchSamplerConfig
from GANDLF.configuration.scheduler_config import SchedulerConfig
from GANDLF.Configuration.default_config import DefaultParameters
from GANDLF.Configuration.differential_privacy_config import DifferentialPrivacyConfig
from GANDLF.Configuration.nested_training_config import NestedTraining
from GANDLF.Configuration.optimizer_config import OptimizerConfig
from GANDLF.Configuration.patch_sampler_config import PatchSamplerConfig
from GANDLF.Configuration.scheduler_config import SchedulerConfig
from GANDLF.utils import version_check
from importlib.metadata import version
from typing_extensions import Self, Literal, Annotated
from GANDLF.configuration.validators import (
validate_schedular,
from GANDLF.Configuration.validators import (
validate_scheduler,
validate_optimizer,
validate_loss_function,
validate_metrics,
Expand All @@ -22,7 +22,8 @@
validate_data_postprocessing_after_reverse_one_hot_encoding,
validate_differential_privacy,
)
from GANDLF.configuration.model_config import ModelConfig
from GANDLF.Configuration.model_config import ModelConfig
from GANDLF.Configuration.scheduler_config import base_triangle_config


class Version(BaseModel): # TODO: Maybe should be to another folder
Expand Down Expand Up @@ -65,8 +66,7 @@ class UserDefinedParameters(DefaultParameters):
default="", description="Parallel compute command."
)
scheduler: Union[str, SchedulerConfig] = Field(
description="Scheduler.", default=SchedulerConfig(type="triangle_modified")
)
description="Scheduler.")
optimizer: Union[str, OptimizerConfig] = Field(
description="Optimizer.", default=OptimizerConfig(type="adam")
) # TODO: Check it again for (opt)
Expand Down Expand Up @@ -103,7 +103,7 @@ def validate(self) -> Self:
self.parallel_compute_command
)
# validate scheduler
self.scheduler = validate_schedular(self.scheduler, self.learning_rate)
self.scheduler = validate_scheduler(self.scheduler, self.learning_rate)
# validate optimizer
self.optimizer = validate_optimizer(self.optimizer)
# validate patch_sampler
Expand Down
25 changes: 23 additions & 2 deletions GANDLF/configuration/utils.py → GANDLF/Configuration/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import logging
from typing import Optional, Union


from typing import Type
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, ValidationError, create_model
from pydantic_core import ErrorDetails


Expand Down Expand Up @@ -107,3 +106,25 @@ def handle_configuration_errors(e: ValidationError):
messages = extract_messages(convert_errors(e))
for message in messages:
logging.error(message)

def combine_models(base_model: Type[BaseModel], extra_model: Type[BaseModel]):
    """Build a new pydantic model whose fields are the union of two models.

    Fields are collected first from ``base_model`` and then from
    ``extra_model``; when both declare the same field name, the
    ``extra_model`` definition wins. The combined class keeps
    ``base_model``'s name.

    Args:
        base_model: Model providing the base set of fields.
        extra_model: Model whose fields are merged on top of the base fields.

    Returns:
        A dynamically created pydantic model class combining both field sets.
    """
    fields = {}
    # Later entries overwrite earlier ones, so extra_model takes precedence.
    # The original guard `default if default is not Ellipsis else ...` was a
    # no-op (it replaced Ellipsis with Ellipsis), so the default is passed
    # through unchanged.
    for model in (base_model, extra_model):
        for field_name, field_info in model.model_fields.items():
            fields[field_name] = (field_info.annotation, field_info.default)

    # create_model assembles the combined class at runtime.
    return create_model(base_model.__name__, **fields)


def add_config_at_the_end_of_values(schedulers_dict):
    """Map each value of the dict to the name of its config counterpart.

    Callable values are replaced by ``<callable name>_config``; every other
    value is stringified with any single or double quotes stripped out.

    Args:
        schedulers_dict: Mapping whose values are callables or plain values.

    Returns:
        A new dict with the same keys and the transformed string values.
    """

    def _config_name(value):
        # Callables contribute their __name__; anything else is stringified
        # and stripped of quote characters.
        if callable(value):
            return value.__name__ + "_config"
        return str(value).replace('"', "").replace("'", "")

    return {key: _config_name(value) for key, value in schedulers_dict.items()}
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import ast
import traceback
from copy import deepcopy
from sched import scheduler

from GANDLF.configuration.differential_privacy_config import DifferentialPrivacyConfig
from GANDLF.Configuration.differential_privacy_config import DifferentialPrivacyConfig
from GANDLF.data.post_process import postprocessing_after_reverse_one_hot_encoding
import numpy as np
import sys
from GANDLF.configuration.optimizer_config import OptimizerConfig
from GANDLF.configuration.patch_sampler_config import PatchSamplerConfig
from GANDLF.configuration.scheduler_config import SchedulerConfig
from GANDLF.configuration.utils import initialize_key
from GANDLF.Configuration.optimizer_config import OptimizerConfig
from GANDLF.Configuration.patch_sampler_config import PatchSamplerConfig
from GANDLF.Configuration.scheduler_config import SchedulerConfig, base_triangle_config
from GANDLF.Configuration.utils import initialize_key, combine_models
from GANDLF.metrics import surface_distance_ids


Expand Down Expand Up @@ -169,11 +170,14 @@ def validate_parallel_compute_command(value):
return value


def validate_schedular(value, learning_rate):
def validate_scheduler(value, learning_rate):
    """Normalize a scheduler setting into a fully-populated SchedulerConfig.

    Args:
        value: Either a scheduler type name (str) or a SchedulerConfig.
        learning_rate: Training learning rate, used to derive a default
            step size.

    Returns:
        A SchedulerConfig with step_size filled in and the
        base_triangle_config defaults (min_lr/max_lr) merged in.
    """
    # A bare string is shorthand for "scheduler of this type".
    if isinstance(value, str):
        value = SchedulerConfig(type=value)
    # Default step size is derived from the learning rate.
    if value.step_size is None:
        value.step_size = learning_rate / 5.0
    # Round-trip through a combined model so base_triangle_config's defaults
    # (min_lr/max_lr) get merged into the config; SchedulerConfig allows
    # extra fields, so the merged values survive the final rebuild.
    # NOTE(review): this applies the triangle defaults to every scheduler
    # type — confirm that is intended for non-triangle schedulers.
    schedulerConfigCombine = combine_models(SchedulerConfig,base_triangle_config)
    combineScheduler = schedulerConfigCombine(**value.model_dump())
    value = SchedulerConfig(**combineScheduler.model_dump())
    return value


Expand Down
6 changes: 3 additions & 3 deletions GANDLF/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import yaml
from pydantic import ValidationError

from GANDLF.configuration.parameters_config import Parameters
from GANDLF.configuration.exclude_parameters import exclude_parameters
from GANDLF.configuration.utils import handle_configuration_errors
from GANDLF.Configuration.parameters_config import Parameters
from GANDLF.Configuration.exclude_parameters import exclude_parameters
from GANDLF.Configuration.utils import handle_configuration_errors


def _parseConfig(
Expand Down
2 changes: 1 addition & 1 deletion GANDLF/configuration/model_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic import BaseModel, model_validator, Field, AliasChoices, ConfigDict
from typing_extensions import Self, Literal, Optional
from typing import Union
from GANDLF.configuration.validators import validate_class_list, validate_norm_type
from GANDLF.Configuration.validators import validate_class_list, validate_norm_type
from GANDLF.models import global_models_dict

# Define model architecture options
Expand Down
2 changes: 1 addition & 1 deletion GANDLF/configuration/parameters_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel, ConfigDict
from GANDLF.configuration.user_defined_config import UserDefinedParameters
from GANDLF.Configuration.user_defined_config import UserDefinedParameters


class ParametersConfiguration(BaseModel):
Expand Down
12 changes: 8 additions & 4 deletions GANDLF/configuration/scheduler_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@

TYPE_OPTIONS = Literal[tuple(global_schedulers_dict.keys())]

class base_triangle_config(BaseModel):
    # Default learning-rate bounds for the triangle/cyclical schedulers;
    # merged into SchedulerConfig via combine_models during validation.
    min_lr: float = Field(default= (10 ** -3))  # lower LR bound (0.001)
    max_lr: float = Field(default=1)  # upper LR bound



# It allows extra parameters
class SchedulerConfig(BaseModel):
model_config = ConfigDict(extra="allow")
type: TYPE_OPTIONS = Field(
description="triangle/triangle_modified use LambdaLR but triangular/triangular2/exp_range uses CyclicLR"
description="triangle/triangle_modified use LambdaLR but triangular/triangular2/exp_range uses CyclicLR",
default = "triangle"
)
# min_lr: 0.00001, #TODO: this should be defined ??
# max_lr: 1, #TODO: this should be defined ??
step_size: float = Field(description="step_size", default=None)
step_size: float = Field(description="step_size", default=None)
6 changes: 3 additions & 3 deletions GANDLF/schedulers/wrap_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def base_triangle(parameters):
This function parses the parameters from the config file and returns the appropriate object
"""

# pick defaults
parameters["scheduler"]["min_lr"] = parameters["scheduler"].get("min_lr", 10**-3)
parameters["scheduler"]["max_lr"] = parameters["scheduler"].get("max_lr", 1)
# # pick defaults
# parameters["scheduler"]["min_lr"] = parameters["scheduler"].get("min_lr", 10**-3)
# parameters["scheduler"]["max_lr"] = parameters["scheduler"].get("max_lr", 1)

clr = cyclical_lr(
parameters["scheduler"]["step_size"],
Expand Down
Loading
Loading