Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 210 additions & 101 deletions examples/trainNet.py

Large diffs are not rendered by default.

4 changes: 0 additions & 4 deletions prnn/utils/Architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def __init__(
output_size (int): Specially defined output size. Default: obs_size
"""
super(pRNN, self).__init__()

# pRNN architecture parameters
self.predOffset = predOffset
self.actOffset = actOffset
Expand Down Expand Up @@ -128,9 +127,6 @@ def __init__(
self.W_hs = self.rnn.cell.weight_hs

self.neuralTimescale = neuralTimescale
self.sparsity = (
cell_kwargs["sparsity"] if "sparsity" in cell_kwargs else f
) # backwards compatibility

with torch.no_grad():
self.W.add_(torch.eye(hidden_size).mul_(1 - 1 / self.neuralTimescale).to_sparse())
Expand Down
4 changes: 2 additions & 2 deletions prnn/utils/predictiveNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"thRNN_8win": thRNN_8win,
"thRNN_9win": thRNN_9win,
"thRNN_10win": thRNN_10win,
"thRNN_0win_noLN": thRNN_0win_noLN,
"thRNN_0win_prevAct": thRNN_0win_prevAct,
"thRNN_1win_prevAct": thRNN_1win_prevAct,
"thRNN_2win_prevAct": thRNN_2win_prevAct,
Expand Down Expand Up @@ -174,7 +175,6 @@ def __init__(
# get all constructor arguments and save them separately in trainArgs for later access...

self.trainArgs = trainArgs

input_args = locals()
input_args.pop("self")
input_args.pop("trainArgs")
Expand Down Expand Up @@ -700,7 +700,7 @@ def resetOptimizer(
if update_alg == "eg":
group["update_alg"] = update_alg
group["lr"] = eg_lr
group["weight_decay"] = eg_weight_decay
group["weight_decay"] = eg_weight_decay * eg_lr

# TODO: convert these to general.savePkl and general.loadPkl (follow SpatialTuningAnalysis.py)
def saveNet(self, savename, savefolder="", cpu=False):
Expand Down
69 changes: 34 additions & 35 deletions prnn/utils/thetaRNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def sparse_lognormal_(
with torch.no_grad():
tensor.mul_(0.0)
return

fan = torch.nn.init._calculate_correct_fan(tensor, mode) * sparsity
std = gain / math.sqrt(fan)
# std /= (1+mean_std_ratio**2)**0.5 # Adjust for multiplication with bernoulli
Expand Down Expand Up @@ -517,38 +516,38 @@ def forward(

import torch.nn.functional as F

if __name__ == "__main__":

def test_script_thrnn_layer(seq_len, input_size, hidden_size, trunc, theta):
"""
Compares thetaRNNLayer output to PyTorch native LSTM output.
"""
inp = torch.randn(1, seq_len, input_size)
inp = F.pad(inp, (0, 0, 0, theta))
state = torch.randn(1, 1, hidden_size)
internal = torch.zeros(theta + 1, seq_len + theta, hidden_size)
rnn = thetaRNNLayer(RNNCell, trunc, input_size, hidden_size)
out, out_state = rnn(
inp, internal, state, theta=theta
) # NEED TO CHANGE FROM (inp, internal, state, theta=theta) to (inp, state, internal, theta=theta)
# out, out_state = rnn(inp, state, internal, theta=theta)

# Control: pytorch native LSTM
rnn_ctl = nn.RNN(input_size, hidden_size, batch_first=True, bias=False, nonlinearity="relu")

for rnn_param, custom_param in zip(rnn_ctl.all_weights[0], rnn.parameters()):
assert rnn_param.shape == custom_param.shape
with torch.no_grad():
rnn_param.copy_(custom_param)
rnn_out, rnn_out_state = rnn_ctl(inp, state)

# Check the output matches rnn default for theta=0
assert (out[0, :, :] - rnn_out).abs().max() < 1e-5
assert (out_state - rnn_out_state).abs().max() < 1e-5

# Check the theta prediction matches the rnn output when input is withheld
assert (out[:, -theta - 1, 0] - rnn_out[0, -theta - 1 :, 0]).abs().max() < 1e-5

return out, rnn_out, inp, rnn


test_script_thrnn_layer(5, 3, 7, 10, 4)
def validate_script_thrnn_layer(seq_len, input_size, hidden_size, trunc, theta):
    """
    Compare thetaRNNLayer output against a PyTorch-native ``nn.RNN`` control.

    NOTE(review): the control model below is ``nn.RNN(nonlinearity="relu")``,
    not an LSTM — earlier docstring text saying "LSTM" was inaccurate.

    Args:
        seq_len: Number of real (unpadded) timesteps in the test input.
        input_size: Input feature dimension.
        hidden_size: Hidden-state dimension.
        trunc: Truncation argument forwarded to ``thetaRNNLayer``.
        theta: Number of prediction steps; the input is zero-padded by
            ``theta`` timesteps at the end.

    Returns:
        Tuple of (custom layer output, control RNN output, padded input,
        the constructed ``thetaRNNLayer``) for further inspection.
    """
    inp = torch.randn(1, seq_len, input_size)
    # Zero-pad theta extra timesteps so the layer can roll predictions
    # forward past the real input.
    inp = F.pad(inp, (0, 0, 0, theta))
    state = torch.randn(1, 1, hidden_size)
    internal = torch.zeros(theta + 1, seq_len + theta, hidden_size)
    rnn = thetaRNNLayer(RNNCell, trunc, input_size, hidden_size)
    out, out_state = rnn(
        inp, internal, state, theta=theta
    )  # NEED TO CHANGE FROM (inp, internal, state, theta=theta) to (inp, state, internal, theta=theta)
    # out, out_state = rnn(inp, state, internal, theta=theta)

    # Control: PyTorch-native RNN with matching configuration (no bias, relu).
    rnn_ctl = nn.RNN(input_size, hidden_size, batch_first=True, bias=False, nonlinearity="relu")

    # Copy the custom layer's weights into the control RNN so both models
    # compute with identical parameters. Assumes rnn.parameters() yields
    # weights in the same (ih, hh) order as all_weights[0] — the shape
    # assert guards against a silent mismatch.
    for rnn_param, custom_param in zip(rnn_ctl.all_weights[0], rnn.parameters()):
        assert rnn_param.shape == custom_param.shape
        with torch.no_grad():
            rnn_param.copy_(custom_param)
    rnn_out, rnn_out_state = rnn_ctl(inp, state)

    # Check the output matches rnn default for theta=0
    assert (out[0, :, :] - rnn_out).abs().max() < 1e-5
    assert (out_state - rnn_out_state).abs().max() < 1e-5

    # Check the theta prediction matches the rnn output when input is withheld.
    # NOTE(review): the left side is a single timestep and the right side a
    # slice — the comparison relies on broadcasting; confirm this is the
    # intended alignment.
    assert (out[:, -theta - 1, 0] - rnn_out[0, -theta - 1 :, 0]).abs().max() < 1e-5

    return out, rnn_out, inp, rnn

# Smoke-run the validation at import time with small dimensions.
validate_script_thrnn_layer(5, 3, 7, 10, 4)
9 changes: 5 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@
"importlib_metadata",
"ruamel.yaml",
"gymnasium==0.29.1",
"pytest==8.4.2",
]

setup(
author="Daniel Levenstein",
author_email='daniel.levenstein@mila.quebec',
python_requires='>=3.9',
name='prnn',
version='v0.1',
author_email="daniel.levenstein@mila.quebec",
python_requires=">=3.9",
name="prnn",
version="v0.1",
packages=find_packages(),
install_requires=dependencies,
description="Python Library for Predictive RNNs Modeling Hippocampal Representation and Replay",
Expand Down
16 changes: 0 additions & 16 deletions test/test_imports.py

This file was deleted.

24 changes: 24 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# tests/conftest.py
import pytest
import shutil
from pathlib import Path

REPO_ROOT = Path(__file__).parent.parent


@pytest.fixture(scope="session", autouse=True)
def cleanup_test_nets():
    """Remove stale ``nets/tmp/`` before the test session and again after it.

    Session-scoped and autouse, so it brackets the whole pytest run:
    leftovers from a previous run are deleted up front, every test then
    executes at the ``yield`` point, and whatever this session wrote is
    deleted afterwards.
    """

    def _remove_tmp_nets():
        # Best-effort removal of the temporary nets directory, if present.
        tmp_nets_dir = REPO_ROOT / "nets" / "tmp"
        if tmp_nets_dir.exists():
            shutil.rmtree(tmp_nets_dir)
            print(f"\nCleaned up {tmp_nets_dir}")

    _remove_tmp_nets()  # clear stale output from a previous session
    yield  # all tests run here
    _remove_tmp_nets()  # clean up what this session produced

14 changes: 14 additions & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import importlib.util
import pytest


@pytest.mark.parametrize(
    "package_name",
    [
        "prnn.utils.predictiveNet",
        "prnn.analysis.trajectoryAnalysis",
    ],
)
def test_import(package_name):
    """Smoke-test that each core prnn module is locatable and importable.

    ``find_spec`` only proves the module exists on ``sys.path``; actually
    importing it additionally catches failures raised at import time
    (syntax errors, failing top-level code, missing dependencies), which
    the spec lookup alone would silently miss.
    """
    spec = importlib.util.find_spec(package_name)
    assert spec is not None, f"Failed to import {package_name}"
    # Execute the module's top-level code to verify it imports cleanly.
    importlib.import_module(package_name)
Loading