From 47bf001b3d822c4e4df01f87367f4ae3ae932b26 Mon Sep 17 00:00:00 2001 From: Filippos Date: Wed, 1 Nov 2023 09:44:43 +0100 Subject: [PATCH 1/3] allow unknown named arguments to tvae --- tvo/models/tvae.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tvo/models/tvae.py b/tvo/models/tvae.py index 7bce5b25..3ce11701 100644 --- a/tvo/models/tvae.py +++ b/tvo/models/tvae.py @@ -205,6 +205,8 @@ def __init__( activation: Callable = None, external_model: Optional[to.nn.Module] = None, optimizer: Optional[opt.Optimizer] = None, + *args, + **kwargs, ): """Create a TVAE model with Gaussian observables. @@ -481,6 +483,8 @@ def __init__( activation: Callable = None, external_model: Optional[to.nn.Module] = None, optimizer: Optional[opt.Optimizer] = None, + *args, + **kwargs, ): """Create a TVAE model with Bernoulli observables. From 3c416b1a36dbe91527e68b1f7086a6a71504dd04 Mon Sep 17 00:00:00 2001 From: Filippos Date: Thu, 2 Nov 2023 11:23:51 +0100 Subject: [PATCH 2/3] add cycling learning modes to tvae, default at triangular2 --- tvo/models/tvae.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tvo/models/tvae.py b/tvo/models/tvae.py index 3ce11701..1ca12150 100644 --- a/tvo/models/tvae.py +++ b/tvo/models/tvae.py @@ -205,6 +205,7 @@ def __init__( activation: Callable = None, external_model: Optional[to.nn.Module] = None, optimizer: Optional[opt.Optimizer] = None, + clrmode: str='triangular2', *args, **kwargs, ): @@ -308,6 +309,7 @@ def __init__( max_lr=max_lr, step_size_up=cycliclr_step_size_up, cycle_momentum=False, + mode=clrmode, ) # number of datapoints processed in a training epoch self._train_datapoints = to.tensor([0], dtype=to.int, device=tvo.get_device()) From 97297cb187e07573b767fe25a613c6323658b86d Mon Sep 17 00:00:00 2001 From: Filippos Date: Thu, 2 Nov 2023 12:37:39 +0100 Subject: [PATCH 3/3] formatting --- tvo/models/tvae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tvo/models/tvae.py b/tvo/models/tvae.py index 1ca12150..b5fd85c7 100644 --- a/tvo/models/tvae.py +++ b/tvo/models/tvae.py @@ -205,7 +205,7 @@ def __init__( activation: Callable = None, external_model: Optional[to.nn.Module] = None, optimizer: Optional[opt.Optimizer] = None, - clrmode: str='triangular2', + clrmode: str = "triangular2", *args, **kwargs, ):