From f808cb5dee01871d6610d060dff4903ddb8553ca Mon Sep 17 00:00:00 2001 From: mknull Date: Fri, 1 Mar 2024 15:49:42 +0100 Subject: [PATCH] format with black=23.1.0 --- examples/bars-test/data.py | 2 +- test/blackbox_test.py | 4 ++-- test/bsc_test.py | 2 +- test/evo_test.py | 2 +- test/fullem_test.py | 4 ++-- test/tvae_test.py | 2 +- tvo/exp/_EStepConfig.py | 2 +- tvo/exp/_utils.py | 2 +- tvo/models/sssc.py | 2 +- tvo/utils/gen.py | 2 +- tvo/variational/_utils.py | 2 +- tvo/variational/fullem.py | 6 +++--- 12 files changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/bars-test/data.py b/examples/bars-test/data.py index d9321cd3..4041292e 100644 --- a/examples/bars-test/data.py +++ b/examples/bars-test/data.py @@ -30,7 +30,7 @@ def get_bars_gfs( """ assert no_bars % 2 == 0, "no_gen_fields must be a multiple of two" R = no_bars // 2 - D = R**2 + D = R ** 2 bg_amp = 0.0 W = bg_amp * to.ones((R, R, no_bars), dtype=precision) diff --git a/test/blackbox_test.py b/test/blackbox_test.py index f97190b2..3265abb4 100644 --- a/test/blackbox_test.py +++ b/test/blackbox_test.py @@ -43,7 +43,7 @@ class Setup: W_init = to.full((D, H), 1.0, dtype=precision, device=_device) sigma_init = to.tensor([1.0], dtype=precision, device=_device) - conf = {"D": D, "H": H, "S": 2**H, "Snew": 0, "batch_size": N, "precision": precision} + conf = {"D": D, "H": H, "S": 2 ** H, "Snew": 0, "batch_size": N, "precision": precision} m = BlackBoxBSC(H, D) m._theta["pi"][:] = pies_init m._theta["W"][:] = W_init.T @@ -61,7 +61,7 @@ class Setup: 0.0, np.log(1.0) - (1.0 / 2), np.log(1.0) - (1.0 / 2), - 2.0 * np.log(1.0) - (1.0 / 2) * 2.0**2, + 2.0 * np.log(1.0) - (1.0 / 2) * 2.0 ** 2, ], [-(1.0 / 2), np.log(1.0), np.log(1.0), 2.0 * np.log(1.0) - (1.0 / 2)], ], diff --git a/test/bsc_test.py b/test/bsc_test.py index bce6db87..b7deacf9 100644 --- a/test/bsc_test.py +++ b/test/bsc_test.py @@ -57,7 +57,7 @@ class Setup: 0.0, np.log(1.0) - (1.0 / 2), np.log(1.0) - (1.0 / 2), - 2.0 * np.log(1.0) - (1.0 / 2) * 2.0**2, + 2.0 * np.log(1.0) - (1.0 / 2) * 2.0 ** 2, ], [-(1.0 / 2), np.log(1.0), np.log(1.0), 2.0 * np.log(1.0) - (1.0 / 2)], ], diff --git a/test/evo_test.py b/test/evo_test.py index 7515b920..4d71bfdd 100644 --- a/test/evo_test.py +++ b/test/evo_test.py @@ -33,7 +33,7 @@ def log_joint(self, data: Tensor, states: Tensor, lpj: Tensor = None) -> Tensor: s_ids = to.empty((H,), dtype=to.int64, device=states.device) for h in range(H): - s_ids[h] = 2**h + s_ids[h] = 2 ** h return ( to.mul(states.to(dtype=to.int64), s_ids[None, None, :].expand(N, S, -1)) diff --git a/test/fullem_test.py b/test/fullem_test.py index 709812be..553b66b7 100644 --- a/test/fullem_test.py +++ b/test/fullem_test.py @@ -29,8 +29,8 @@ def setup(request): def test_init(setup): var_states = setup.var_states - assert var_states.K.shape == (setup.N, 2**setup.H, setup.H) - assert to.unique(var_states.K[0], dim=0).shape[0] == 2**setup.H + assert var_states.K.shape == (setup.N, 2 ** setup.H, setup.H) + assert to.unique(var_states.K[0], dim=0).shape[0] == 2 ** setup.H def test_update(setup): diff --git a/test/tvae_test.py b/test/tvae_test.py index 27a20809..a983751a 100644 --- a/test/tvae_test.py +++ b/test/tvae_test.py @@ -112,7 +112,7 @@ def true_free_energy(tvae_model, data, states): def test_lpj(simple_tvae): N = 2 D, H1, H0 = simple_tvae.net_shape - S = 2**H0 + S = 2 ** H0 states = fullem_for(simple_tvae, N=N) assert (H0, H1, D) == (2, 3, 1), "test assumes this shape for tvae but shape changed" assert states.K.shape == (N, S, H0) diff --git a/tvo/exp/_EStepConfig.py b/tvo/exp/_EStepConfig.py index 322b0cf4..271345bd 100644 --- a/tvo/exp/_EStepConfig.py +++ b/tvo/exp/_EStepConfig.py @@ -119,7 +119,7 @@ def as_dict(self) -> Dict[str, Any]: class FullEMConfig(EStepConfig): def __init__(self, n_latents: int): """Full EM configuration.""" - super().__init__(2**n_latents) + super().__init__(2 ** n_latents) def as_dict(self) -> Dict[str, Any]: return vars(self) diff --git a/tvo/exp/_utils.py b/tvo/exp/_utils.py index cb891fa5..9eda2cff 100644 --- a/tvo/exp/_utils.py +++ b/tvo/exp/_utils.py @@ -32,7 +32,7 @@ def make_var_states( RandomSampledVarStates, ]: if isinstance(conf, FullEMConfig): - assert conf.n_states == 2**H, "FullEMConfig and model have different H" + assert conf.n_states == 2 ** H, "FullEMConfig and model have different H" return FullEM(N, H, precision) elif isinstance(conf, FullEMSingleCauseConfig): assert conf.n_states == H, "FullEMSingleCauseConfig and model have different H" diff --git a/tvo/models/sssc.py b/tvo/models/sssc.py index a2522ffd..b6dc869c 100644 --- a/tvo/models/sssc.py +++ b/tvo/models/sssc.py @@ -442,7 +442,7 @@ def update_param_batch( self._my_sum_xpt_sz.add_(to.sum(batch_xpt_sz, dim=0)) # (H,) self._my_sum_xpt_sz_xpt_szT.add_(batch_xpt_sz.t() @ batch_xpt_sz) # (H, H) self._my_sum_xpt_szszT.add_(to.sum(batch_xpt_szszT, dim=0)) # (H, H) - self._my_sum_diag_yyT.add_(to.sum(batch**2, dim=0)) # (D,) + self._my_sum_diag_yyT.add_(to.sum(batch ** 2, dim=0)) # (D,) self._my_sum_y_szT.add_(batch.t() @ batch_xpt_sz) # (D, H) self._my_N.add_(batch_size) # (1,) if self._reformulated_psi_update: diff --git a/tvo/utils/gen.py b/tvo/utils/gen.py index 044a253a..8253e39e 100644 --- a/tvo/utils/gen.py +++ b/tvo/utils/gen.py @@ -28,7 +28,7 @@ def generate_bars( :returns: tensor containing the bars dictionary """ R = H // 2 - D = R**2 + D = R ** 2 W = bg_amp * to.ones((R, R, H), dtype=precision, device=tvo.get_device()) for i in range(R): diff --git a/tvo/variational/_utils.py b/tvo/variational/_utils.py index 9df5ba85..b1ca144b 100644 --- a/tvo/variational/_utils.py +++ b/tvo/variational/_utils.py @@ -106,7 +106,7 @@ def generate_unique_states( """ if device is None: device = tvo.get_device() - assert n_states <= 2**H, "n_states must be smaller than 2**H" + assert n_states <= 2 ** H, "n_states must be smaller than 2**H" n_samples = max(n_states // 2, 1) s_set = {tuple(s) for s in np.random.binomial(1, p=crowdedness / H, size=(n_samples, H))} diff --git a/tvo/variational/fullem.py b/tvo/variational/fullem.py index 2f655a71..c7432a74 100644 --- a/tvo/variational/fullem.py +++ b/tvo/variational/fullem.py @@ -21,8 +21,8 @@ def state_matrix(H: int, device: to.device = None): if device is None: device = tvo.get_device() - all_states = to.empty((2**H, H), dtype=to.uint8, device=device) - for state in range(2**H): + all_states = to.empty((2 ** H, H), dtype=to.uint8, device=device) + for state in range(2 ** H): bit_sequence = tuple(int(bit) for bit in f"{state:0{H}b}") all_states[state] = to.tensor(bit_sequence, dtype=to.uint8, device=device) return all_states @@ -43,7 +43,7 @@ def __init__(self, N: int, H: int, precision: to.dtype, K_init=None): for c in required_keys: assert c in conf and conf[c] is not None self.config = conf - self.lpj = to.empty((N, 2**H), dtype=precision, device=tvo.get_device()) + self.lpj = to.empty((N, 2 ** H), dtype=precision, device=tvo.get_device()) self.precision = precision self.K = state_matrix(H)[None, :, :].expand(N, -1, -1)