Skip to content
This repository was archived by the owner on Mar 31, 2025. It is now read-only.

Commit 619be74

Browse files
Fix linter errors.
1 parent 290abd4 commit 619be74

2 files changed

Lines changed: 8 additions & 11 deletions

File tree

objax/nn/layers.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ def __call__(self, x: JaxArray) -> JaxArray:
331331
self.avg.value += (self.avg.value - x) * (self.momentum - 1)
332332
return self.avg.value
333333

334+
334335
class SimpleRNN(Module):
335336
"""Simple Recurrent Neural Network (RNN) block."""
336337

@@ -363,12 +364,10 @@ def __init__(self,
363364
self.w_hh = TrainVar(w_init((self.nstate, self.nstate)))
364365
self.b_h = TrainVar(jn.zeros(self.nstate))
365366

367+
self.output_layer = Linear(self.nstate, self.num_outputs, w_init=w_init)
366368

367-
self.output_layer = Linear(self.nstate, self.num_outputs, w_init = w_init)
368-
369-
def __call__(self, inputs: JaxArray,
370-
initial_state: JaxArray = None,
371-
only_return_final = False) -> Tuple[JaxArray, JaxArray]:
369+
def __call__(self, inputs: JaxArray, initial_state: JaxArray = None,
370+
only_return_final: bool = False) -> Tuple[JaxArray, JaxArray]:
372371
"""Forward pass through RNN.
373372
374373
Args:
@@ -383,7 +382,7 @@ def __call__(self, inputs: JaxArray,
383382
"""
384383
outputs = []
385384

386-
if initial_state == None:
385+
if initial_state is None:
387386
state = jn.zeros((inputs.shape[0], self.nstate))
388387
else:
389388
state = initial_state

tests/simple_rnn.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,21 @@
1616

1717
import unittest
1818

19-
import numpy as np
2019
import jax.numpy as jn
21-
import tensorflow as tf
2220

23-
import objax
2421
from objax.nn.layers import SimpleRNN
2522
from objax.functional import one_hot
2623
from objax.functional.core.activation import relu
2724
from objax.nn.init import identity
28-
from objax.zoo.resnet_v2 import convert_keras_model, load_pretrained_weights_from_keras
25+
2926

3027
class TestSimpleRNN(unittest.TestCase):
3128

3229
def test_simple_rnn(self):
3330
nin = nout = 3
3431
batch_size = 1
3532
num_hiddens = 1
36-
model = SimpleRNN(num_hiddens, nin, nout, activation = relu, w_init = identity)
33+
model = SimpleRNN(num_hiddens, nin, nout, activation=relu, w_init=identity)
3734

3835
X = jn.arange(batch_size)
3936
X_one_hot = one_hot(X, nin)
@@ -48,5 +45,6 @@ def test_simple_rnn(self):
4845
Z, _ = model(X_one_hot, state)
4946
self.assertTrue(jn.array_equal(Z, jn.array([[3., 0., 0.]])))
5047

48+
5149
if __name__ == '__main__':
5250
unittest.main()

0 commit comments

Comments
 (0)