diff --git a/maxim/models/maxim.py b/maxim/models/maxim.py index c6ff95c..3e5a098 100644 --- a/maxim/models/maxim.py +++ b/maxim/models/maxim.py @@ -81,10 +81,18 @@ class UpSampleRatio(nn.Module): @nn.compact def __call__(self, x): n, h, w, c = x.shape + if self.ratio < 1: + ratio = int(1 / self.ratio) + h = h // ratio + w = w // ratio + else: + h = h * self.ratio + w = w * self.ratio x = jax.image.resize( x, - shape=(n, int(h * self.ratio), int(w * self.ratio), c), - method="bilinear") + shape=(n, h, w, c), + method="bilinear", + ) x = Conv1x1(features=self.features, use_bias=self.use_bias)(x) return x