diff --git a/tvo/models/tvae.py b/tvo/models/tvae.py index 7bce5b2..b5fd85c 100644 --- a/tvo/models/tvae.py +++ b/tvo/models/tvae.py @@ -205,6 +205,9 @@ def __init__( activation: Callable = None, external_model: Optional[to.nn.Module] = None, optimizer: Optional[opt.Optimizer] = None, + clrmode: str = "triangular2", + *args, + **kwargs, ): """Create a TVAE model with Gaussian observables. @@ -306,6 +309,7 @@ def __init__( max_lr=max_lr, step_size_up=cycliclr_step_size_up, cycle_momentum=False, + mode=clrmode, ) # number of datapoints processed in a training epoch self._train_datapoints = to.tensor([0], dtype=to.int, device=tvo.get_device()) @@ -481,6 +485,8 @@ def __init__( activation: Callable = None, external_model: Optional[to.nn.Module] = None, optimizer: Optional[opt.Optimizer] = None, + *args, + **kwargs, ): """Create a TVAE model with Bernoulli observables.