From 5f34da7127e6424c9587119240d00f80cc3a5b8c Mon Sep 17 00:00:00 2001 From: Kevin Scott <151596+thekevinscott@users.noreply.github.com> Date: Tue, 16 May 2023 14:50:21 -0400 Subject: [PATCH] Make model updates to support jax to tfjs conversion --- maxim/models/maxim.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/maxim/models/maxim.py b/maxim/models/maxim.py index c6ff95c..3e5a098 100644 --- a/maxim/models/maxim.py +++ b/maxim/models/maxim.py @@ -81,10 +81,18 @@ class UpSampleRatio(nn.Module): @nn.compact def __call__(self, x): n, h, w, c = x.shape + if self.ratio < 1: + ratio = int(1 / self.ratio) + h = h // ratio + w = w // ratio + else: + h = h * self.ratio + w = w * self.ratio x = jax.image.resize( x, - shape=(n, int(h * self.ratio), int(w * self.ratio), c), - method="bilinear") + shape=(n, h, w, c), + method="bilinear", + ) x = Conv1x1(features=self.features, use_bias=self.use_bias)(x) return x