From ac6f0fcbf6cbbb221e1752d421d3f8a2649d1cdf Mon Sep 17 00:00:00 2001 From: SabrinaDu Date: Fri, 6 Feb 2026 15:57:55 -0500 Subject: [PATCH 1/7] Revert previous two commits (74c17843 and 563bcbee). Will add them as PR instead. --- examples/trainNet.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/examples/trainNet.py b/examples/trainNet.py index 6ce70d1b..c15dc008 100644 --- a/examples/trainNet.py +++ b/examples/trainNet.py @@ -23,20 +23,15 @@ import matplotlib.pyplot as plt import torch import random -import warnings from types import SimpleNamespace -import wandb -# Ignore UserWarnings -warnings.filterwarnings("ignore", category=UserWarning) # Parse arguments + parser = argparse.ArgumentParser() ## General parameters -parser.add_argument('--wandb', action='store_true', default=True) -parser.add_argument('--no_wandb', dest='wandb', action='store_false') parser.add_argument( "--env", @@ -243,14 +238,7 @@ args = parser.parse_args() -# Wandb -if args.wandb: - run = wandb.init( - entity="blake-richards", - project="curious-george", -) -# Train savename = args.pRNNtype + "-" + args.namext + "-s" + str(args.seed) figfolder = "nets/" + args.savefolder + "/trainfigs/" + savename analysisfolder = "nets/" + args.savefolder + "/analysis/" + savename @@ -304,7 +292,6 @@ bias_lr=args.bias_lr, dataloader=args.withDataLoader, trainArgs=SimpleNamespace(**args.__dict__), - wandb_log=args.wandb, **architecture_kwargs, ) # allows values in trainArgs to be accessible @@ -347,7 +334,7 @@ # Calculate initial spatial metrics etc print("Training Baseline") predictiveNet.useDataLoader = False - predictiveNet.trainingEpoch(env, agent, sequence_duration=sequence_duration, num_trials=num_trials) + predictiveNet.trainingEpoch(env, agent, sequence_duration=sequence_duration, num_trials=1) predictiveNet.useDataLoader = args.withDataLoader print("Calculating INITIAL Spatial Representation...") place_fields, SI, decoder = predictiveNet.calculateSpatialRepresentation( From e924fb08f150f0bfbacf93a06eb55be241fbae65 Mon Sep 17 00:00:00 2001 From: Meghan Date: Wed, 18 Feb 2026 14:49:39 -0500 Subject: [PATCH 2/7] added unittests for trainnet integration --- examples/trainNet.py | 296 +++++++++++----- prnn/utils/Architectures.py | 1 - prnn/utils/predictiveNet.py | 4 +- prnn/utils/thetaRNN.py | 73 ++-- tests/conftest.py | 24 ++ tests/test.py | 651 ++++++++++++++++++++++++++++++++++++ tests/test_trainnet.py | 603 +++++++++++++++++++++++++++++++++ 7 files changed, 1525 insertions(+), 127 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test.py create mode 100644 tests/test_trainnet.py diff --git a/examples/trainNet.py b/examples/trainNet.py index c15dc008..81a45a31 100644 --- a/examples/trainNet.py +++ b/examples/trainNet.py @@ -105,6 +105,8 @@ parser.add_argument("-s", "--seed", default=8, type=int, help="Random Seed? (Default: 8)") +parser.add_argument("--bptttrunc", default=None, type=float) + parser.add_argument( "--lr", default=2e-3, @@ -178,46 +180,86 @@ "--numworkers", default=1, type=int, help="Number of dataloader workers (Default: 1)" ) -parser.add_argument( - "--sparsity", - default=0.5, - type=float, - help="Activation sparsity (via layer norm, irrelevant for non-LN networks) (Default: 0.5)", -) - # Additional architecture kwargs -parser.add_argument("--use_FF", default=False, type=bool, help="Make network Feed Forward only?") +parser.add_argument( + "--use_FF", + action="store_const", + const=True, + help="Use feedforward network (disable recurrence)", +) # default is false if flag is not used and does not compromise partial use_FF flags parser.add_argument( - "--mask_actions", default=False, type=bool, help="Mask actions from model input as well?" -) + "--mask_actions", + action="store_const", + const=True, + help="Mask actions from model input as well?", +) # similar to use_FF, if flag is not used then defaults to false and wont compromise partial class flags parser.add_argument( - "--actOffset", default=0, type=int, help="Number of timesteps to offset actions by (backwards)" + "--actOffset", + default=None, # defined at the predictiveNet level to avoid overriding partials + type=int, + help="Number of timesteps to offset actions by (backwards)", ) parser.add_argument( "--k", - default=5, + default=None, # defined at the predictiveNet level to avoid overriding partials type=int, help="Number of predictions; i.e. number of future timesteps to mask or number of rollouts", ) -parser.add_argument("--rollout_action", default="full", type=str, help="Action structure") +parser.add_argument( + "--rollout_action", + default=None, # defaults to None to avoid overriding partials + type=str, + help="Action structure, options first, hold, full", +) + +parser.add_argument("--continuousTheta", action="store_true", dest="continuousTheta") +parser.add_argument("--no-continuousTheta", action="store_false", dest="continuousTheta") +parser.set_defaults(continuousTheta=None) +# weight init parameters parser.add_argument( - "--continuousTheta", - default=False, - type=bool, - help="Carry over hidden state from the kth rollout to the t+1'th timestep?", + "--init", + default=None, + type=str, + help="Weight initialization schema to use, options: xavier, log_normal, gamma", +) +# log_normal init parameters +parser.add_argument( + "--sparsity", + default=None, + type=float, + help="Activation sparsity (via layer norm, irrelevant for non-LN networks) (Default: 0.5)", ) +parser.add_argument( + "--mean_std_ratio", + default=None, + type=float, + help="Mean:STD ratio of the log normal distrubtion of the intialized weights, only used if init==log_normak", +) +# gamma init +parser.add_argument( + "--alpha", + default=None, + type=float, + help="Shape (alpha or k) of the gamma disitrbution of the initialized weights, only used if init==gamma", +) +parser.add_argument( + "--theta", + default=None, + type=float, + help="Scale (theta or beta) of the gamma disitrbution of the initialized weights, only used if init==gamma", +) # EG params parser.add_argument( "--eg_weight_decay", - default=1e-6, + default=5e-6, type=float, help="Weight Decay for Exponentiated Gradient Descent (Default: 1e-6)", ) @@ -236,28 +278,105 @@ help="Learning Rate for Biases", ) +# DivNorm Params + +parser.add_argument( + "--target_mean", + default=0.7, + type=float, + help="Target mean for activations for divisive normalization RNN", +) + +parser.add_argument( + "--train_divnorm", + action="store_true", + default=False, + help="Whether to train the divisive normalization parameters: k_div and sigma", +) + +parser.add_argument( + "--k_div", default=1, type=float, help="Parameter in divisive normalization" +) # TODO: add more descriptive language here + +parser.add_argument( + "--sigma", default=1, type=float, help="Parameter in divisive normalization" +) # TODO: add more descriptive language here + +parser.add_argument( + "--test", + action="store_true", + default=False, + help="for running test code", +) + args = parser.parse_args() savename = args.pRNNtype + "-" + args.namext + "-s" + str(args.seed) figfolder = "nets/" + args.savefolder + "/trainfigs/" + savename analysisfolder = "nets/" + args.savefolder + "/analysis/" + savename + + +# === Always include (universal parameters) === architecture_kwargs = { "dropp": args.dropout, - "use_FF": args.use_FF, - "mask_actions": args.mask_actions, - "actOffset": args.actOffset, - "k": args.k, - "rollout_action": args.rollout_action, - "continuousTheta": args.continuousTheta, + "neuralTimescale": args.neuralTimescale, + "bptttrunc": args.bptttrunc, } +architecture_kwargs = {k: v for k, v in architecture_kwargs.items() if v is not None} + + +# === Conditionally include (architecture-specific with None check) === + +# Boolean flags +if args.use_FF is not None: + architecture_kwargs["use_FF"] = args.use_FF + +if args.mask_actions is not None: + architecture_kwargs["mask_actions"] = args.mask_actions + +if args.continuousTheta is not None: + architecture_kwargs["continuousTheta"] = args.continuousTheta + +# Numeric parameters +if args.k is not None: + architecture_kwargs["k"] = args.k + +if args.actOffset is not None: + architecture_kwargs["actOffset"] = args.actOffset + +if args.rollout_action is not None: + architecture_kwargs["rollout_action"] = args.rollout_action + +# Weight initialization +if args.init is not None: + architecture_kwargs["init"] = args.init + +if args.sparsity is not None: + architecture_kwargs["sparsity"] = args.sparsity + +if args.mean_std_ratio is not None: + architecture_kwargs["mean_std_ratio"] = args.mean_std_ratio + +if args.alpha is not None: + architecture_kwargs["alpha"] = args.alpha + +if args.theta is not None: + architecture_kwargs["theta"] = args.theta + +# === Always include (DivNorm parameters) - they aren't harmful === +architecture_kwargs["target_mean"] = args.target_mean +architecture_kwargs["train_divnorm"] = args.train_divnorm +architecture_kwargs["k_div"] = args.k_div +architecture_kwargs["sigma"] = args.sigma + +# === Cell override === if args.cell is not None: if args.cell not in CELL_TYPES.keys(): raise ValueError( f"Cell type '{args.cell}' not recognized. " f"Available cell types: {list(CELL_TYPES.keys())}" ) - architecture_kwargs["cell"] = CELL_TYPES[ args.cell ] # this way default setting for cell will used for each pRNNtype unless overridden here @@ -321,75 +440,78 @@ # %% Training Epoch # Consider these as "trainingparameters" class/dictionary -numepochs = args.numepochs -sequence_duration = args.seqdur -num_trials = args.numtrials -if args.withDataLoader: - batchsize = args.batchsize +if args.test: + predictiveNet.saveNet(args.savefolder + savename) else: - batchsize = 1 - -predictiveNet.trainingCompleted = False -if predictiveNet.numTrainingTrials == -1: - # Calculate initial spatial metrics etc - print("Training Baseline") - predictiveNet.useDataLoader = False - predictiveNet.trainingEpoch(env, agent, sequence_duration=sequence_duration, num_trials=1) - predictiveNet.useDataLoader = args.withDataLoader - print("Calculating INITIAL Spatial Representation...") - place_fields, SI, decoder = predictiveNet.calculateSpatialRepresentation( - env, - agent, - trainDecoder=True, - saveTrainingData=True, - bitsec=False, - calculatesRSA=True, - sleepstd=0.03, - ) - predictiveNet.plotTuningCurvePanel(savename=savename, savefolder=figfolder) - print("Calculating INITIAL Decoding Performance...") - predictiveNet.calculateDecodingPerformance( - env, agent, decoder, savename=savename, savefolder=figfolder, saveTrainingData=True - ) - # predictiveNet.plotDelayDist(env, agent, decoder) + numepochs = args.numepochs + sequence_duration = args.seqdur + num_trials = args.numtrials + if args.withDataLoader: + batchsize = args.batchsize + else: + batchsize = 1 + + predictiveNet.trainingCompleted = False + if predictiveNet.numTrainingTrials == -1: + # Calculate initial spatial metrics etc + print("Training Baseline") + predictiveNet.useDataLoader = False + predictiveNet.trainingEpoch(env, agent, sequence_duration=sequence_duration, num_trials=1) + predictiveNet.useDataLoader = args.withDataLoader + print("Calculating INITIAL Spatial Representation...") + place_fields, SI, decoder = predictiveNet.calculateSpatialRepresentation( + env, + agent, + trainDecoder=True, + saveTrainingData=True, + bitsec=False, + calculatesRSA=True, + sleepstd=0.03, + ) + predictiveNet.plotTuningCurvePanel(savename=savename, savefolder=figfolder) + print("Calculating INITIAL Decoding Performance...") + predictiveNet.calculateDecodingPerformance( + env, agent, decoder, savename=savename, savefolder=figfolder, saveTrainingData=True + ) + # predictiveNet.plotDelayDist(env, agent, decoder) -if hasattr(predictiveNet, "numTrainingEpochs") is False: - predictiveNet.numTrainingEpochs = int(predictiveNet.numTrainingTrials / num_trials) + if hasattr(predictiveNet, "numTrainingEpochs") is False: + predictiveNet.numTrainingEpochs = int(predictiveNet.numTrainingTrials / num_trials) -progress = tqdm(total=numepochs, desc="Training Epochs") # tdqm status bar + progress = tqdm(total=numepochs, desc="Training Epochs") # tdqm status bar -while predictiveNet.numTrainingEpochs < numepochs: # run through all epochs - print(f"Training Epoch {predictiveNet.numTrainingEpochs}") - predictiveNet.trainingEpoch( - env, agent, sequence_duration=sequence_duration, num_trials=num_trials - ) - print("Calculating Spatial Representation...") - place_fields, SI, decoder = predictiveNet.calculateSpatialRepresentation( - env, - agent, - trainDecoder=True, - trainHDDecoder=False, - saveTrainingData=True, - bitsec=False, - calculatesRSA=True, - sleepstd=0.03, - ) - print("Calculating Decoding Performance...") - predictiveNet.calculateDecodingPerformance( - env, agent, decoder, savename=savename, savefolder=figfolder, saveTrainingData=True - ) - predictiveNet.plotLearningCurve(savename=savename, savefolder=figfolder, incDecode=True) - predictiveNet.plotTuningCurvePanel(savename=savename, savefolder=figfolder) - plt.show() - plt.close("all") - predictiveNet.saveNet(args.savefolder + savename) + while predictiveNet.numTrainingEpochs < numepochs: # run through all epochs + print(f"Training Epoch {predictiveNet.numTrainingEpochs}") + predictiveNet.trainingEpoch( + env, agent, sequence_duration=sequence_duration, num_trials=num_trials + ) + print("Calculating Spatial Representation...") + place_fields, SI, decoder = predictiveNet.calculateSpatialRepresentation( + env, + agent, + trainDecoder=True, + trainHDDecoder=False, + saveTrainingData=True, + bitsec=False, + calculatesRSA=True, + sleepstd=0.03, + ) + print("Calculating Decoding Performance...") + predictiveNet.calculateDecodingPerformance( + env, agent, decoder, savename=savename, savefolder=figfolder, saveTrainingData=True + ) + predictiveNet.plotLearningCurve(savename=savename, savefolder=figfolder, incDecode=True) + predictiveNet.plotTuningCurvePanel(savename=savename, savefolder=figfolder) + plt.show() + plt.close("all") + predictiveNet.saveNet(args.savefolder + savename) - progress.update(1) + progress.update(1) -progress.close() + progress.close() -predictiveNet.trainingCompleted = True -TrainingFigure(predictiveNet, savename=savename, savefolder=figfolder) + predictiveNet.trainingCompleted = True + TrainingFigure(predictiveNet, savename=savename, savefolder=figfolder) # If the user doesn't want to save all that training data, delete it except the last one if args.saveTrainData is False: diff --git a/prnn/utils/Architectures.py b/prnn/utils/Architectures.py index 5c8fce2f..fa4b09eb 100644 --- a/prnn/utils/Architectures.py +++ b/prnn/utils/Architectures.py @@ -80,7 +80,6 @@ def __init__( output_size (int): Specially defined output size. Default: obs_size """ super(pRNN, self).__init__() - # pRNN architecture parameters self.predOffset = predOffset self.actOffset = actOffset diff --git a/prnn/utils/predictiveNet.py b/prnn/utils/predictiveNet.py index 5874db03..acbd6cc7 100644 --- a/prnn/utils/predictiveNet.py +++ b/prnn/utils/predictiveNet.py @@ -71,6 +71,7 @@ "thRNN_8win": thRNN_8win, "thRNN_9win": thRNN_9win, "thRNN_10win": thRNN_10win, + "thRNN_0win_noLN": thRNN_0win_noLN, "thRNN_0win_prevAct": thRNN_0win_prevAct, "thRNN_1win_prevAct": thRNN_1win_prevAct, "thRNN_2win_prevAct": thRNN_2win_prevAct, @@ -174,7 +175,6 @@ def __init__( # get all constructor arguments and save them separately in trainArgs for later access... self.trainArgs = trainArgs - input_args = locals() input_args.pop("self") input_args.pop("trainArgs") @@ -700,7 +700,7 @@ def resetOptimizer( if update_alg == "eg": group["update_alg"] = update_alg group["lr"] = eg_lr - group["weight_decay"] = eg_weight_decay + group["weight_decay"] = eg_weight_decay * eg_lr # TODO: convert these to general.savePkl and general.loadPkl (follow SpatialTuningAnalysis.py) def saveNet(self, savename, savefolder="", cpu=False): diff --git a/prnn/utils/thetaRNN.py b/prnn/utils/thetaRNN.py index 118ff651..a68e92b2 100644 --- a/prnn/utils/thetaRNN.py +++ b/prnn/utils/thetaRNN.py @@ -78,7 +78,6 @@ def sparse_lognormal_( with torch.no_grad(): tensor.mul_(0.0) return - fan = torch.nn.init._calculate_correct_fan(tensor, mode) * sparsity std = gain / math.sqrt(fan) # std /= (1+mean_std_ratio**2)**0.5 # Adjust for multiplication with bernoulli @@ -243,7 +242,7 @@ def __init__( tau_a (_type_, optional): Decay in adaptation. Defaults to 8.. """ # initialize class attributes and weights - super().__init__(input_size, hidden_size, actfun, init) + super().__init__(input_size, hidden_size, actfun, init, *args, **kwargs) self.b = b self.tau_a = tau_a @@ -285,7 +284,7 @@ def __init__( mu (int, optional): Mean for LayerNorm. Defaults to 0. sig (int, optional): Std dev for LayerNorm. Defaults to 1. """ - super().__init__(input_size, hidden_size, actfun, init) + super().__init__(input_size, hidden_size, actfun, init, *args, **kwargs) # set up layernorm self.layernorm = LayerNorm(hidden_size, mu, sig) self.layernorm.mu = Parameter(torch.zeros(self.hidden_size) + self.layernorm.mu) @@ -517,38 +516,38 @@ def forward( import torch.nn.functional as F +if __name__ == "__main__": -def test_script_thrnn_layer(seq_len, input_size, hidden_size, trunc, theta): - """ - Compares thetaRNNLayer output to PyTorch native LSTM output. - """ - inp = torch.randn(1, seq_len, input_size) - inp = F.pad(inp, (0, 0, 0, theta)) - state = torch.randn(1, 1, hidden_size) - internal = torch.zeros(theta + 1, seq_len + theta, hidden_size) - rnn = thetaRNNLayer(RNNCell, trunc, input_size, hidden_size) - out, out_state = rnn( - inp, internal, state, theta=theta - ) # NEED TO CHANGE FROM (inp, internal, state, theta=theta) to (inp, state, internal, theta=theta) - # out, out_state = rnn(inp, state, internal, theta=theta) - - # Control: pytorch native LSTM - rnn_ctl = nn.RNN(input_size, hidden_size, batch_first=True, bias=False, nonlinearity="relu") - - for rnn_param, custom_param in zip(rnn_ctl.all_weights[0], rnn.parameters()): - assert rnn_param.shape == custom_param.shape - with torch.no_grad(): - rnn_param.copy_(custom_param) - rnn_out, rnn_out_state = rnn_ctl(inp, state) - - # Check the output matches rnn default for theta=0 - assert (out[0, :, :] - rnn_out).abs().max() < 1e-5 - assert (out_state - rnn_out_state).abs().max() < 1e-5 - - # Check the theta prediction matches the rnn output when input is withheld - assert (out[:, -theta - 1, 0] - rnn_out[0, -theta - 1 :, 0]).abs().max() < 1e-5 - - return out, rnn_out, inp, rnn - - -test_script_thrnn_layer(5, 3, 7, 10, 4) + def validate_script_thrnn_layer(seq_len, input_size, hidden_size, trunc, theta): + """ + Compares thetaRNNLayer output to PyTorch native LSTM output. + """ + inp = torch.randn(1, seq_len, input_size) + inp = F.pad(inp, (0, 0, 0, theta)) + state = torch.randn(1, 1, hidden_size) + internal = torch.zeros(theta + 1, seq_len + theta, hidden_size) + rnn = thetaRNNLayer(RNNCell, trunc, input_size, hidden_size) + out, out_state = rnn( + inp, internal, state, theta=theta + ) # NEED TO CHANGE FROM (inp, internal, state, theta=theta) to (inp, state, internal, theta=theta) + # out, out_state = rnn(inp, state, internal, theta=theta) + + # Control: pytorch native LSTM + rnn_ctl = nn.RNN(input_size, hidden_size, batch_first=True, bias=False, nonlinearity="relu") + + for rnn_param, custom_param in zip(rnn_ctl.all_weights[0], rnn.parameters()): + assert rnn_param.shape == custom_param.shape + with torch.no_grad(): + rnn_param.copy_(custom_param) + rnn_out, rnn_out_state = rnn_ctl(inp, state) + + # Check the output matches rnn default for theta=0 + assert (out[0, :, :] - rnn_out).abs().max() < 1e-5 + assert (out_state - rnn_out_state).abs().max() < 1e-5 + + # Check the theta prediction matches the rnn output when input is withheld + assert (out[:, -theta - 1, 0] - rnn_out[0, -theta - 1 :, 0]).abs().max() < 1e-5 + + return out, rnn_out, inp, rnn + + validate_script_thrnn_layer(5, 3, 7, 10, 4) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..6bbf6e16 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,24 @@ +# tests/conftest.py +import pytest +import shutil +from pathlib import Path + +REPO_ROOT = Path(__file__).parent.parent + + +@pytest.fixture(scope="session", autouse=True) +def cleanup_test_nets(): + """Delete nets/tmp/ before the entire test session completes if there are old ones.""" + tmp_nets_dir = REPO_ROOT / "nets" / "tmp" + if tmp_nets_dir.exists(): + shutil.rmtree(tmp_nets_dir) + print(f"\nCleaned up {tmp_nets_dir}") + + """Delete nets/tmp/ after the entire test session completes.""" + yield # all tests run here + + + if tmp_nets_dir.exists(): + shutil.rmtree(tmp_nets_dir) + print(f"\nCleaned up {tmp_nets_dir}") + diff --git a/tests/test.py b/tests/test.py new file mode 100644 index 00000000..594fdb32 --- /dev/null +++ b/tests/test.py @@ -0,0 +1,651 @@ +""" +test_architecture_flow.py + +Tests that verify: +1. Partial function arguments are correctly passed through +2. Argparse arguments flow correctly to architectures and cells +3. No unintended overrides occur +""" + +import torch +import pytest +from types import SimpleNamespace +import sys + +# sys.path.append("../") # Adjust path as needed + +from prnn.utils.predictiveNet import PredictiveNet, CELL_TYPES, netOptions +from prnn.utils.env import make_env +from prnn.utils.Architectures import * +from prnn.utils.thetaRNN import * + + +# ============================================================================ +# FIXTURES +# ============================================================================ + + +@pytest.fixture +def mock_env(): + """Create a simple test environment""" + env = make_env("LRoom-18x18-v0", "farama-minigrid", "SpeedHD") + return env + + +@pytest.fixture +def base_architecture_kwargs(): + """Base kwargs that should work with any architecture""" + return { + "dropp": 0.15, + "neuralTimescale": 2, + "bptttrunc": 50, + } + + +# ============================================================================ +# TEST 1: PARTIAL PRESETS ARE PRESERVED (No Argparse Override) +# ============================================================================ + + +class TestPartialPresets: + """Test that partial function presets are preserved when no override is given""" + + def test_cell_values(self, mock_env, base_architecture_kwargs): + test_cases = [ + # ("thRNN_0win_noLN", RNNCell), + ("thRNN_0win", LayerNormRNNCell), + ("thcycRNN_5win_holdc_adapt", AdaptingLayerNormRNNCell), + ("AutoencoderFF", RNNCell), + ("AutoencoderFF_LN", LayerNormRNNCell), + ("lognRNN_rollout", LayerNormRNNCell), + ] + + for pRNNtype, expected_cell in test_cases: + net = PredictiveNet( + mock_env, pRNNtype=pRNNtype, hidden_size=100, **base_architecture_kwargs + ) + actual_cell = type(net.pRNN.rnn.cell) + + assert actual_cell == expected_cell, ( + f"{pRNNtype}: Expected cell={expected_cell.__name__}, " + f"got cell={actual_cell.__name__}" + ) + + # Masked specific presets + def test_masked_k_values(self, mock_env, base_architecture_kwargs): + """Test that different masked RNN variants have correct k values""" + test_cases = [ + ("thRNN_0win", 0), + ("thRNN_1win", 1), + ("thRNN_5win", 5), + ("thRNN_10win", 10), + ] + + for pRNNtype, expected_k in test_cases: + net = PredictiveNet( + mock_env, pRNNtype=pRNNtype, hidden_size=100, **base_architecture_kwargs + ) + actual_k = len(net.pRNN.inMask) - 1 + assert actual_k == expected_k, f"{pRNNtype}: Expected k={expected_k}, got k={actual_k}" + + def test_actOffset_values(self, mock_env, base_architecture_kwargs): + """Test that _prevAct variants have correct actOffset""" + test_cases = [ + ("thRNN_0win_prevAct", 1), + ("thRNN_5win_prevAct", 1), + ("thcycRNN_5win_holdc_prevAct", 1), + ] + + for ( + pRNNtype, + expected_actOffset, + ) in test_cases: + net = PredictiveNet( + mock_env, pRNNtype=pRNNtype, hidden_size=100, **base_architecture_kwargs + ) + assert net.pRNN.actOffset == expected_actOffset, ( + f"{pRNNtype}: Expected actOffset={expected_actOffset}, got {net.pRNN.actOffset}" + ) + + def test_masked_mask_actions(self, mock_env, base_architecture_kwargs): + """Test that _mask variants have mask_actions=True""" + test_cases = [ + ("thRNN_1win_mask", [True, False], 1), + ("thRNN_5win_mask", [True, False, False, False, False, False], 5), + ] + + for pRNNtype, expected_mask, expected_k in test_cases: + net = PredictiveNet( + mock_env, pRNNtype=pRNNtype, hidden_size=100, **base_architecture_kwargs + ) + assert (net.pRNN.actMask == expected_mask).all(), ( + f"{pRNNtype}: Expected actMask={expected_mask}, got {net.pRNN.actMask}" + ) + actual_k = len(net.pRNN.inMask) - 1 + assert actual_k == expected_k + + # Rollout Specific presets + def test_rollout_continuousTheta(self, mock_env, base_architecture_kwargs): + """Test that rollout variants have correct continuousTheta""" + test_cases = [ + ("thcycRNN_5win_hold", False), + ("thcycRNN_5win_holdc", True), + ("thcycRNN_5win_first", False), + ("thcycRNN_5win_firstc", True), + ] + + for pRNNtype, expected_continuous in test_cases: + net = PredictiveNet( + mock_env, pRNNtype=pRNNtype, hidden_size=100, **base_architecture_kwargs + ) + assert net.pRNN.rnn.continuousTheta == expected_continuous, ( + f"{pRNNtype}: Expected continuousTheta={expected_continuous}, got {net.pRNN.rnn.continuousTheta}" + ) + + def test_rollout_rollout_action(self, mock_env, base_architecture_kwargs): + """Test that rollout variants have correct rollout_action""" + test_cases = [ + ("thcycRNN_5win_hold", "hold"), # hold = hold + ("thcycRNN_5win_first", False), # first = False + ("thcycRNN_5win_full", True), # full = True + ] + + for pRNNtype, expected_action in test_cases: + net = PredictiveNet( + mock_env, pRNNtype=pRNNtype, hidden_size=100, **base_architecture_kwargs + ) + assert net.pRNN.actionTheta == expected_action, ( + f"{pRNNtype}: Expected rollout_action={expected_action}, got {net.pRNN.actionTheta}" + ) + + def test_rollout_k_values(self, mock_env, base_architecture_kwargs): + """Test that different masked RNN variants have correct k values""" + test_cases = [("thcycRNN_3win", 3), ("lognRNN_rollout", 5)] + + for pRNNtype, expected_k in test_cases: + net = PredictiveNet( + mock_env, pRNNtype=pRNNtype, hidden_size=100, **base_architecture_kwargs + ) + actual_k = net.pRNN.k + assert actual_k == expected_k, f"{pRNNtype}: Expected k={expected_k}, got k={actual_k}" + + # lognRNN presets + def test_lognRNN_init_and_sparsity(self, mock_env, base_architecture_kwargs): + """Test that lognRNN variants have correct init and sparsity""" + test_cases = [ + ("lognRNN_mask"), + ("lognRNN_rollout"), + ] + + for pRNNtype in test_cases: + net = PredictiveNet( + mock_env, pRNNtype=pRNNtype, hidden_size=100, **base_architecture_kwargs + ) + # Check that log_normal init was used (weights should follow log-normal) + assert (net.pRNN.W >= 0).all(), ( + "Expected postiive lognormal distribution, but weights are negative" + ) + sparsity = (net.pRNN.W == 0).float().mean().item() * 100 + expected_sparsity = 95.0 + tolerance = 5.0 + + assert abs(sparsity - expected_sparsity) < tolerance, ( + f"Expected ~{expected_sparsity}% zeros (sparsity=0.05), but got {sparsity:.2f}% zeros" + ) + + def test_autoencoder_use_FF(self, mock_env, base_architecture_kwargs): + """Test that Autoencoder variants have correct use_FF setting""" + test_cases = [ + ("AutoencoderFF", False), # use_FF = true , therefor W does not require grad + ("AutoencoderRec", True), # use_FF = false , therefor W does require grad + ("AutoencoderFFPred", False), + ] + + for pRNNtype, expected_use_FF in test_cases: + net = PredictiveNet( + mock_env, pRNNtype=pRNNtype, hidden_size=100, **base_architecture_kwargs + ) + W_found = False + for name, param in net.pRNN.named_parameters(): + if name == "W": + W_found = True + assert param.requires_grad == expected_use_FF, ( + f"{pRNNtype}: Expected W.requires_grad={expected_use_FF}, " + f"got {param.requires_grad}" + ) + break + + assert W_found, f"{pRNNtype}: W parameter not found in named_parameters()" + + def test_predOffset_values(self, mock_env, base_architecture_kwargs): + """Test that _prevAct variants have correct actOffset""" + test_cases = [ + ("AutoencoderFF", 0), + ("AutoencoderPred", 1), + ] + + for ( + pRNNtype, + expected_predOffset, + ) in test_cases: + net = PredictiveNet( + mock_env, pRNNtype=pRNNtype, hidden_size=100, **base_architecture_kwargs + ) + assert net.pRNN.predOffset == expected_predOffset, ( + f"{pRNNtype}: Expected predOffset={expected_predOffset}, got {net.pRNN.predOffset}" + ) + + +# ============================================================================ +# TEST 2: ARGPARSE OVERRIDES WORK CORRECTLY +# ============================================================================ + + +# class TestArgparseOverrides: +# """Test that argparse arguments correctly override partial presets""" + +# def test_k_override(self, mock_env, base_architecture_kwargs): +# """Test that explicit k argument overrides partial preset""" +# # thRNN_5win has k=5 by default +# kwargs = {**base_architecture_kwargs, "k": 10} +# net = PredictiveNet(mock_env, pRNNtype="thRNN_5win", hidden_size=100, **kwargs) +# assert net.pRNN.k == 10, f"Expected k=10 (override), got k={net.pRNN.k}" + +# def test_actOffset_override(self, mock_env, base_architecture_kwargs): +# """Test that explicit actOffset overrides partial preset""" +# # thRNN_5win_prevAct has actOffset=1 by default +# kwargs = {**base_architecture_kwargs, "actOffset": 3} +# net = PredictiveNet(mock_env, pRNNtype="thRNN_5win_prevAct", hidden_size=100, **kwargs) +# assert net.pRNN.actOffset == 3, f"Expected actOffset=3 (override), got {net.pRNN.actOffset}" + +# def test_continuousTheta_override(self, mock_env, base_architecture_kwargs): +# """Test that explicit continuousTheta overrides partial preset""" +# # thcycRNN_5win_hold has continuousTheta=False by default +# kwargs = {**base_architecture_kwargs, "continuousTheta": True} +# net = PredictiveNet(mock_env, pRNNtype="thcycRNN_5win_hold", hidden_size=100, **kwargs) +# assert net.pRNN.continuousTheta == True, ( +# f"Expected continuousTheta=True (override), got {net.pRNN.continuousTheta}" +# ) + +# def test_rollout_action_override(self, mock_env, base_architecture_kwargs): +# """Test that explicit rollout_action overrides partial preset""" +# # thcycRNN_5win_hold has rollout_action="hold" by default +# kwargs = {**base_architecture_kwargs, "rollout_action": "first"} +# net = PredictiveNet(mock_env, pRNNtype="thcycRNN_5win_hold", hidden_size=100, **kwargs) +# assert net.pRNN.rollout_action == "first", ( +# f"Expected rollout_action='first' (override), got {net.pRNN.rollout_action}" +# ) + +# def test_use_FF_override(self, mock_env, base_architecture_kwargs): +# """Test that explicit use_FF overrides partial preset""" +# # AutoencoderRec has use_FF=False by default +# kwargs = {**base_architecture_kwargs, "use_FF": True} +# net = PredictiveNet(mock_env, pRNNtype="AutoencoderRec", hidden_size=100, **kwargs) +# assert net.pRNN.use_FF == True, f"Expected use_FF=True (override), got {net.pRNN.use_FF}" + +# def test_mask_actions_override(self, mock_env, base_architecture_kwargs): +# """Test that explicit mask_actions overrides partial preset""" +# # thRNN_5win has mask_actions=False by default +# kwargs = {**base_architecture_kwargs, "mask_actions": True} +# net = PredictiveNet(mock_env, pRNNtype="thRNN_5win", hidden_size=100, **kwargs) +# assert net.pRNN.mask_actions == True, ( +# f"Expected mask_actions=True (override), got {net.pRNN.mask_actions}" +# ) + + +# # ============================================================================ +# # TEST 3: NONE VALUES DON'T OVERRIDE PARTIALS +# # ============================================================================ + + +# class TestNoneDoesNotOverride: +# """Test that None values in kwargs don't override partial presets""" + +# def test_k_none_preserves_partial(self, mock_env, base_architecture_kwargs): +# """Test that k=None doesn't override partial preset""" +# kwargs = {**base_architecture_kwargs, "k": None} +# net = PredictiveNet(mock_env, pRNNtype="thRNN_5win", hidden_size=100, **kwargs) +# assert net.pRNN.k == 5, f"k=None should preserve partial preset (5), got {net.pRNN.k}" + +# def test_continuousTheta_none_preserves_partial(self, mock_env, base_architecture_kwargs): +# """Test that continuousTheta=None doesn't override partial preset""" +# kwargs = {**base_architecture_kwargs, "continuousTheta": None} +# net = PredictiveNet(mock_env, pRNNtype="thcycRNN_5win_holdc", hidden_size=100, **kwargs) +# assert net.pRNN.continuousTheta == True, ( +# f"continuousTheta=None should preserve partial preset (True), got {net.pRNN.continuousTheta}" +# ) + +# def test_init_none_preserves_partial(self, mock_env, base_architecture_kwargs): +# """Test that init=None doesn't override lognRNN partial preset""" +# kwargs = {**base_architecture_kwargs, "init": None} +# net = PredictiveNet(mock_env, pRNNtype="lognRNN_mask", hidden_size=100, **kwargs) +# # Can't directly check init, but weights should still be log-normal +# W_hh = net.pRNN.W.detach().cpu().numpy().flatten() +# assert W_hh.mean() > 0 # Basic check for log-normal + + +# # ============================================================================ +# # TEST 4: CELL-SPECIFIC PARAMETERS FLOW CORRECTLY +# # ============================================================================ + + +# class TestCellParameters: +# """Test that cell-specific parameters reach the cell correctly""" + +# def test_divnorm_parameters_reach_cell(self, mock_env, base_architecture_kwargs): +# """Test that DivNorm parameters are correctly passed to cell""" +# kwargs = { +# **base_architecture_kwargs, +# "target_mean": 0.8, +# "k_div": 2.0, +# "sigma": 1.5, +# "train_divnorm": False, +# } + +# net = PredictiveNet( +# mock_env, pRNNtype="Masked", hidden_size=100, cell="DivNormRNNCell", **kwargs +# ) + +# cell = net.pRNN.rnn.cell + +# # Check parameters reached the cell +# assert hasattr(cell, "divnorm"), "Cell should have divnorm module" +# assert hasattr(cell, "target_mean"), "Cell should have target_mean" +# assert cell.target_mean == 0.8, f"Expected target_mean=0.8, got {cell.target_mean}" + +# # Check k_div and sigma values +# assert cell.divnorm.k_div.item() == 2.0, ( +# f"Expected k_div=2.0, got {cell.divnorm.k_div.item()}" +# ) +# assert cell.divnorm.sigma.item() == 1.5, ( +# f"Expected sigma=1.5, got {cell.divnorm.sigma.item()}" +# ) + +# def test_divnorm_trainable_parameters(self, mock_env, base_architecture_kwargs): +# """Test that train_divnorm correctly makes parameters trainable""" +# # Test with train_divnorm=True +# kwargs_trainable = { +# **base_architecture_kwargs, +# "target_mean": 0.7, +# "k_div": 1.0, +# "sigma": 1.0, +# "train_divnorm": True, +# } + +# net = PredictiveNet( +# mock_env, pRNNtype="Masked", hidden_size=100, cell="DivNormRNNCell", **kwargs_trainable +# ) + +# cell = net.pRNN.rnn.cell +# assert isinstance(cell.divnorm.k_div, torch.nn.Parameter), ( +# "k_div should be a Parameter when train_divnorm=True" +# ) +# assert isinstance(cell.divnorm.sigma, torch.nn.Parameter), ( +# "sigma should be a Parameter when train_divnorm=True" +# ) + +# # Test with train_divnorm=False +# kwargs_fixed = { +# **base_architecture_kwargs, +# "target_mean": 0.7, +# "k_div": 1.0, +# "sigma": 1.0, +# "train_divnorm": False, +# } + +# net_fixed = PredictiveNet( +# mock_env, pRNNtype="Masked", hidden_size=100, cell="DivNormRNNCell", **kwargs_fixed +# ) + +# cell_fixed = net_fixed.pRNN.rnn.cell +# assert not isinstance(cell_fixed.divnorm.k_div, torch.nn.Parameter), ( +# "k_div should be a buffer when train_divnorm=False" +# ) +# assert not isinstance(cell_fixed.divnorm.sigma, torch.nn.Parameter), ( +# "sigma should be a buffer when train_divnorm=False" +# ) + +# def test_cell_override_works(self, mock_env, base_architecture_kwargs): +# """Test that cell override argument works""" +# # Test with LayerNormRNNCell +# net_ln = PredictiveNet( +# mock_env, +# pRNNtype="Masked", +# hidden_size=100, +# cell="LayerNormRNNCell", +# **base_architecture_kwargs, +# ) + +# assert net_ln.pRNN.rnn.cell.__class__.__name__ == "LayerNormRNNCell", ( +# f"Expected LayerNormRNNCell, got {net_ln.pRNN.rnn.cell.__class__.__name__}" +# ) + +# # Test with DivNormRNNCell +# kwargs_divnorm = { +# **base_architecture_kwargs, +# "target_mean": 0.7, +# "train_divnorm": False, +# } +# net_dn = PredictiveNet( +# mock_env, pRNNtype="Masked", hidden_size=100, cell="DivNormRNNCell", **kwargs_divnorm +# ) + +# assert net_dn.pRNN.rnn.cell.__class__.__name__ == "DivNormRNNCell", ( +# f"Expected DivNormRNNCell, got {net_dn.pRNN.rnn.cell.__class__.__name__}" +# ) + +# def test_sparsity_reaches_layernorm_cell(self, mock_env, base_architecture_kwargs): +# """Test that sparsity parameter reaches LayerNormRNNCell""" +# kwargs = { +# **base_architecture_kwargs, +# "sparsity": 0.3, +# } + +# net = PredictiveNet( +# mock_env, +# pRNNtype="lognRNN_mask", # Uses LayerNormRNNCell +# hidden_size=100, +# **kwargs, +# ) + +# cell = net.pRNN.rnn.cell +# assert hasattr(cell, "f"), "LayerNormRNNCell should have sparsity (f)" +# # sparsity should equal cell.f +# assert cell.f == 0.3, f"Expected f=0.3, got {cell.f}" + + +# # ============================================================================ +# # TEST 5: OPTIMIZER PARAMETER GROUPS +# # ============================================================================ + + +# class TestOptimizerParameterGroups: +# """Test that optimizer parameter groups are correctly configured""" + +# def test_divnorm_in_optimizer_when_trainable(self, mock_env, base_architecture_kwargs): +# """Test that k_div and sigma are in optimizer when train_divnorm=True""" +# kwargs = { +# **base_architecture_kwargs, +# "target_mean": 0.7, +# "train_divnorm": True, +# } + +# net = PredictiveNet( +# mock_env, +# pRNNtype="Masked", +# hidden_size=100, +# cell="DivNormRNNCell", +# trainBias=False, +# **kwargs, +# ) + +# # Check optimizer has k_div and sigma groups +# param_group_names = [g["name"] for g in net.optimizer.param_groups] +# assert "k_divnorm" in param_group_names, "k_divnorm should be in optimizer parameter groups" +# assert "sigma_divnorm" in param_group_names, ( +# "sigma_divnorm should be in optimizer parameter groups" +# ) + +# def test_divnorm_not_in_optimizer_when_fixed(self, mock_env, base_architecture_kwargs): +# """Test that k_div and sigma are NOT in optimizer when train_divnorm=False""" +# kwargs = { +# **base_architecture_kwargs, +# "target_mean": 0.7, +# "train_divnorm": False, +# } + +# net = PredictiveNet( +# mock_env, +# pRNNtype="Masked", +# hidden_size=100, +# cell="DivNormRNNCell", +# trainBias=False, +# **kwargs, +# ) + +# # Check optimizer does NOT have k_div and sigma groups +# param_group_names = [g["name"] for g in net.optimizer.param_groups] +# assert "k_divnorm" not in param_group_names, ( +# "k_divnorm should NOT be in optimizer when train_divnorm=False" +# ) +# assert "sigma_divnorm" not in param_group_names, ( +# "sigma_divnorm should NOT be in optimizer when train_divnorm=False" +# ) + +# def test_bias_in_optimizer_when_trainable(self, mock_env, base_architecture_kwargs): +# """Test that bias is in optimizer when trainBias=True""" +# net = PredictiveNet( +# mock_env, pRNNtype="Masked", hidden_size=100, trainBias=True, **base_architecture_kwargs +# ) + +# param_group_names = [g["name"] for g in net.optimizer.param_groups] +# assert "biases" in param_group_names, ( +# "biases should be in optimizer parameter groups when trainBias=True" +# ) + +# def test_eg_parameter_groups_configured(self, mock_env, base_architecture_kwargs): +# """Test that EG is correctly configured for parameter groups""" +# kwargs = { +# **base_architecture_kwargs, +# "target_mean": 0.7, +# "train_divnorm": True, +# } + +# net = PredictiveNet( +# mock_env, +# pRNNtype="Masked", +# hidden_size=100, +# cell="DivNormRNNCell", +# eg_lr=1e-3, +# eg_weight_decay=1e-6, +# **kwargs, +# ) + +# # Check that positive parameter groups have update_alg="eg" +# for group in net.optimizer.param_groups: +# if group["name"] in ["RecurrentWeights", "k_divnorm", "sigma_divnorm"]: +# # These should have EG (assuming they're all positive) +# assert group.get("update_alg") == "eg", ( +# f"{group['name']} should use EG update algorithm" +# ) +# assert group["lr"] == 1e-3, f"{group['name']} should have eg_lr=1e-3" +# # Check weight decay scaling +# expected_wd = 1e-6 * 1e-3 # eg_weight_decay * eg_lr +# assert abs(group["weight_decay"] - expected_wd) < 1e-12, ( +# f"{group['name']} weight_decay should be {expected_wd}" +# ) + + +# ============================================================================ +# TEST 6: INTEGRATION TESTS (Full Workflow) +# ============================================================================ + + +# class TestIntegration: +# """Integration tests for full workflow""" + +# def test_train_with_divnorm_trainable(self, mock_env, base_architecture_kwargs): +# """Test that training works with trainable DivNorm parameters""" +# kwargs = { +# **base_architecture_kwargs, +# "target_mean": 0.7, +# "k_div": 1.0, +# "sigma": 1.0, +# "train_divnorm": True, +# } + +# net = PredictiveNet( +# mock_env, +# pRNNtype="Masked", +# hidden_size=50, # Small for speed +# cell="DivNormRNNCell", +# **kwargs, +# ) + +# # Get initial values +# k_div_initial = net.pRNN.rnn.cell.divnorm.k_div.item() +# sigma_initial = net.pRNN.rnn.cell.divnorm.sigma.item() + +# # Run a few training steps +# from prnn.utils.agent import create_agent + +# agent = create_agent("RandomActionAgent", mock_env) + +# # Just run a couple iterations to make sure it doesn't crash +# for _ in range(2): +# obs, act, _, _ = net.collectObservationSequence(mock_env, agent, tsteps=50) +# net.trainStep(obs, act) + +# # Check that parameters changed (they should if trainable) +# k_div_final = net.pRNN.rnn.cell.divnorm.k_div.item() +# sigma_final = net.pRNN.rnn.cell.divnorm.sigma.item() + +# # At least one should have changed +# assert k_div_initial != k_div_final or sigma_initial != sigma_final, ( +# "Trainable DivNorm parameters should change during training" +# ) + +# def test_all_partial_architectures_instantiate(self, mock_env, base_architecture_kwargs): +# """Smoke test: all partial architectures can be instantiated""" +# architectures_to_test = [ +# "NextStep", +# "Masked", +# "Rollout", +# "AutoencoderFF", +# "AutoencoderRec", +# "AutoencoderPred", +# "AutoencoderFFPred", +# "AutoencoderFF_LN", +# "AutoencoderRec_LN", +# "AutoencoderPred_LN", +# "AutoencoderFFPred_LN", +# "thRNN_0win", +# "thRNN_5win", +# "thRNN_10win", +# "thRNN_0win_prevAct", +# "thRNN_5win_prevAct", +# "thRNN_1win_mask", +# "thRNN_5win_mask", +# "thcycRNN_5win", +# "thcycRNN_5win_hold", +# "thcycRNN_5win_holdc", +# "lognRNN_mask", +# "lognRNN_rollout", +# ] + +# for pRNNtype in architectures_to_test: +# try: +# net = PredictiveNet( +# mock_env, pRNNtype=pRNNtype, hidden_size=50, **base_architecture_kwargs +# ) +# assert net is not None, f"{pRNNtype} failed to instantiate" +# except Exception as e: +# pytest.fail(f"{pRNNtype} raised exception: {e}") + + +# ============================================================================ +# RUN TESTS +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) diff --git a/tests/test_trainnet.py b/tests/test_trainnet.py new file mode 100644 index 00000000..f19acab5 --- /dev/null +++ b/tests/test_trainnet.py @@ -0,0 +1,603 @@ +# test_trainnet_cli.py +""" +End-to-end tests for trainNet.py command-line interface. +Tests the actual user experience - no imports from trainNet.py needed. +""" + +import subprocess +import pytest +from pathlib import Path +from prnn.utils.thetaRNN import RNNCell, LayerNormRNNCell + +REPO_ROOT = Path(__file__).parent.parent +TRAIN_SCRIPT = REPO_ROOT / "examples" / "trainNet.py" + + +@pytest.fixture +def run_trainnet(tmp_path): + """ + Helper fixture to run trainNet.py via subprocess. + Returns a function that runs training and returns the saved network. + """ + + def _run(pRNNtype, **cli_args): + """ + Run trainNet.py with given arguments. + + Args: + pRNNtype: The RNN architecture type + **cli_args: Additional CLI arguments as key=value + + Returns: + Loaded PredictiveNet object + """ + save_subfolder = "tmp/" + tmp_path.name + "/" + # Build command + cmd = [ + "python", + str(TRAIN_SCRIPT), + "--pRNNtype", + pRNNtype, + "--hidden_size", + "100", + "--numepochs", + "0", # Minimal for speed + "--numtrials", + "0", # Minimal for speed + "--savefolder", + save_subfolder, + "--namext", + "test", + "--test", + "--noDataLoader", + ] + + # Add any extra CLI arguments + for key, value in cli_args.items(): + if isinstance(value, bool): + if value: # Only add flag if True + cmd.append(f"--{key}") + else: + cmd.extend([f"--{key}", str(value)]) + + # Run trainNet.py + result = subprocess.run( + cmd, + capture_output=True, + text=True, # Adjust path as needed + cwd=REPO_ROOT, + ) + + # Check for errors + assert result.returncode == 0, ( + f"trainNet.py failed\n" + f"Command: {' '.join(cmd)}\n" + f"Stdout: {result.stdout}\n" + f"Stderr: {result.stderr}" + ) + seed = cli_args.get("seed", 8) + savename = f"{pRNNtype}-test-s{seed}" + pkl_file = REPO_ROOT / "nets" / save_subfolder / (savename + ".pkl") + + if not pkl_file.exists(): + # Helpful diagnostics + nets_dir = REPO_ROOT / "nets" / save_subfolder + if nets_dir.exists(): + found = [f.name for f in nets_dir.glob("*")] + listing = "\n ".join(found) or "(empty)" + msg = ( + f"Expected: {pkl_file}\n" + f"Directory exists but contains:\n {listing}\n" + f"Stdout: {result.stdout}\n" + f"Stderr: {result.stderr}" + ) + else: + # Walk nets/ to see what DID get created + nets_root = REPO_ROOT / "nets" + found_any = list(nets_root.rglob("*.pkl")) if nets_root.exists() else [] + msg = ( + f"Expected: {pkl_file}\n" + f"Directory does not exist: {nets_dir}\n" + f"All .pkl files under nets/: {found_any}\n" + f"Stdout: {result.stdout}\n" + f"Stderr: {result.stderr}" + ) + pytest.fail(msg) + + # loadNet likely expects the same string saveNet received (without "nets/" prefix and without ".pkl") + load_string = str( + Path("tmp") / tmp_path.name / savename + ) # e.g. "test_thRNN_0win_00/thRNN-test-s8" + + try: + from prnn.utils.predictiveNet import PredictiveNet + + net = PredictiveNet.loadNet(load_string) + return net + except Exception as e: + pytest.fail( + f"loadNet failed for string '{load_string}'\n" + f"pkl_file confirmed exists: {pkl_file.exists()}\n" + f"cwd: {Path.cwd()}\n" + f"Error: {e}" + ) + + return _run + + +# ============================================================================ +# TEST PARTIAL PRESETS ARE PRESERVED +# ============================================================================ + + +class TestPartialPresetsViaCLI: + """Test that CLI preserves partial function presets""" + + # Masked + def test_thRNN_0win_noLN_k_and_cell(self, run_trainnet): + """Critical: thRNN_0win should preserve k=0, not override with default""" + net = run_trainnet("thRNN_0win_noLN") + actual_k = len(net.pRNN.inMask) - 1 + actual_cell = type(net.pRNN.rnn.cell) + assert actual_k == 0, f"Expected k=0, got {actual_k}" + assert actual_cell == RNNCell, f"Expected RNNCell, got {actual_cell}" + + def test_thRNN_10win_k_and_cell(self, run_trainnet): + """Critical: thRNN_0win should preserve k=0, not override with default""" + net = run_trainnet("thRNN_10win") + actual_k = len(net.pRNN.inMask) - 1 + actual_cell = type(net.pRNN.rnn.cell) + assert actual_k == 10, f"Expected k=10, got {actual_k}" + assert actual_cell == LayerNormRNNCell, f"Expected LayerNormRNNCell, got {actual_cell}" + + def test_thRNN_prevAct_preserves_actOffset(self, run_trainnet): + """Test _prevAct variants preserve actOffset=1""" + net = run_trainnet("thRNN_6win_prevAct") + actual_k = len(net.pRNN.inMask) - 1 + actual_cell = type(net.pRNN.rnn.cell) + assert actual_k == 6, f"Expected k=6, got {actual_k}" + assert net.pRNN.actOffset == 1, f"Expected actOffset=1, got {net.pRNN.actOffset}" + assert actual_cell == LayerNormRNNCell, f"Expected LayerNormRNNCell, got {actual_cell}" + + def test_thRNN_mask_preserves_mask_actions(self, run_trainnet): + """Test _mask variants preserve mask_actions""" + net = run_trainnet("thRNN_4win_mask") + expected_mask = [True, False, False, False, False] + actual_k = len(net.pRNN.inMask) - 1 + actual_cell = type(net.pRNN.rnn.cell) + assert actual_k == 4, f"Expected k=4, got {actual_k}" + assert actual_cell == LayerNormRNNCell, f"Expected LayerNormRNNCell, got {actual_cell}" + assert list(net.pRNN.actMask) == expected_mask, ( + f"Expected actMask={expected_mask}, got {list(net.pRNN.actMask)}" + ) + + # Rollouts + def test_rollout_k(self, run_trainnet): + net = run_trainnet("thcycRNN_3win") + assert net.pRNN.k == 3, f"Expected k=3, got {net.pRNN.k}" + + def test_rollout_action_and_continuousTheta(self, run_trainnet): + """Test rollout variants preserve continuousTheta with rollout_action""" + # holdc = True + net = run_trainnet("thcycRNN_5win_fullc") + assert net.pRNN.rnn.continuousTheta, ( + f"Expected continuousTheta=True, got {net.pRNN.rnn.continuousTheta}" + ) + assert net.pRNN.actionTheta, f"Expected actionTheta=True, got {net.pRNN.actionTheta}" + + # hold = False + net = run_trainnet("thcycRNN_5win_hold") + assert not net.pRNN.rnn.continuousTheta, ( + f"Expected continuousTheta=False, got {net.pRNN.rnn.continuousTheta}" + ) + assert net.pRNN.actionTheta == "hold", ( + f"Expected actionTheta='hold', got {net.pRNN.actionTheta}" + ) + + def test_rollout_action_conttheta_actoffset(self, run_trainnet): + net = run_trainnet("thcycRNN_5win_firstc_prevAct") + actual_cell = type(net.pRNN.rnn.cell) + assert net.pRNN.rnn.continuousTheta, ( + f"Expected continuousTheta=True, got {net.pRNN.rnn.continuousTheta}" + ) + assert not net.pRNN.actionTheta, f"Expected actionTheta=False, got {net.pRNN.actionTheta}" + assert net.pRNN.actOffset == 1, f"Expected actOffset=1, got {net.pRNN.actOffset}" + assert actual_cell == LayerNormRNNCell, f"Expected LayerNormRNNCell, got {actual_cell}" + + # NextStep + def test_autoencoder_use_FF(self, run_trainnet): + """Test that Autoencoder variants freeze W for use_FF (fast forward)""" + net = run_trainnet("AutoencoderFF") + actual_cell = type(net.pRNN.rnn.cell) + W_found = False + for name, param in net.pRNN.named_parameters(): + if name == "W": + W_found = True + assert not param.requires_grad, ( + f"Expected W.requires_grad=False, got {param.requires_grad}" + ) + break + + assert W_found, "W parameter not found" + assert net.pRNN.predOffset == 0, f"Expected predOffset=0, got {net.pRNN.predOffset}" + assert actual_cell == RNNCell, f"Expected RNNCell, got {actual_cell}" + net = run_trainnet("AutoencoderPred_LN") + actual_cell = type(net.pRNN.rnn.cell) + W_found = False + for name, param in net.pRNN.named_parameters(): + if name == "W": + W_found = True + assert param.requires_grad, ( + f"Expected W.requires_grad=True, got {param.requires_grad}" + ) + break + + assert W_found, "W parameter not found" + assert net.pRNN.predOffset == 1, f"Expected predOffset=1, got {net.pRNN.predOffset}" + assert actual_cell == LayerNormRNNCell, f"Expected LayerNormCell, got {actual_cell}" + + # Test LognRNNs + def test_lognRNNs(self, run_trainnet): + net = run_trainnet("lognRNN_mask") + actual_k = len(net.pRNN.inMask) - 1 + actual_cell = type(net.pRNN.rnn.cell) + + assert actual_k == 5, f"Expected k=5 got {actual_k}" + assert actual_cell == LayerNormRNNCell, f"Expected RNNCell, got {actual_cell}" + assert (net.pRNN.W >= 0).all(), ( + "Expected postiive lognormal distribution, but recurrent weights are negative" + ) + assert (net.pRNN.W_in >= 0).all(), ( + "Expected postiive lognormal distribution, but input weights are negative" + ) + sparsity_w = (net.pRNN.W == 0).float().mean().item() * 100 + sparsity_w_in = (net.pRNN.W_in == 0).float().mean().item() * 100 + expected_sparsity = 95.0 + tolerance = 5.0 + + assert abs(sparsity_w - expected_sparsity) < tolerance, ( + f"Expected ~{expected_sparsity}% zeros (sparsity=0.05), but got {sparsity_w:.2f}% zeros for recurrent weights" + ) + assert abs(sparsity_w_in - expected_sparsity) < tolerance, ( + f"Expected ~{expected_sparsity}% zeros (sparsity=0.05), but got {sparsity_w_in:.2f}% zeros for input weights" + ) + net = run_trainnet("lognRNN_rollout") + actual_k = net.pRNN.k + actual_cell = type(net.pRNN.rnn.cell) + + assert actual_k == 5, f"Expected k=5 got {actual_k}" + assert actual_cell == LayerNormRNNCell, f"Expected RNNCell, got {actual_cell}" + assert (net.pRNN.W >= 0).all(), ( + "Expected postiive lognormal distribution, but recurrent weights are negative" + ) + assert (net.pRNN.W_in >= 0).all(), ( + "Expected postiive lognormal distribution, but input weights are negative" + ) + sparsity_w = (net.pRNN.W == 0).float().mean().item() * 100 + sparsity_w_in = (net.pRNN.W_in == 0).float().mean().item() * 100 + expected_sparsity = 95.0 + tolerance = 5.0 + + assert abs(sparsity_w - expected_sparsity) < tolerance, ( + f"Expected ~{expected_sparsity}% zeros (sparsity=0.05), but got {sparsity_w:.2f}% zeros for recurrent weights" + ) + assert abs(sparsity_w_in - expected_sparsity) < tolerance, ( + f"Expected ~{expected_sparsity}% zeros (sparsity=0.05), but got {sparsity_w_in:.2f}% zeros for input weights" + ) + assert not net.pRNN.rnn.continuousTheta, ( + f"Expected continuousTheta=False, got {net.pRNN.rnn.continuousTheta}" + ) + assert net.pRNN.actionTheta, f"Expected actionTheta=True, got {net.pRNN.actionTheta}" + + +# ============================================================================ +# TEST CLI OVERRIDES WORK +# ============================================================================ + + +class TestCLIOverrides: + """Test that command-line arguments override partial presets""" + + def test_k_override(self, run_trainnet): + """Test --k=10 overrides thRNN_5win's k=5""" + net = run_trainnet("thRNN_5win", k=10) + actual_k = len(net.pRNN.inMask) - 1 + assert actual_k == 10, f"Expected k=10 (override), got {actual_k}" + + def test_actOffset_override(self, run_trainnet): + """Test --actOffset=3 overrides preset""" + net = run_trainnet("thRNN_5win_prevAct", actOffset=3) + assert net.pRNN.actOffset == 3, f"Expected actOffset=3 (override), got {net.pRNN.actOffset}" + + def test_continuousTheta_override(self, run_trainnet): + """Test --continuousTheta overrides preset""" + # thcycRNN_5win_hold has continuousTheta=False by default + net = run_trainnet("thcycRNN_5win_hold", continuousTheta=True) + assert net.pRNN.rnn.continuousTheta == True, ( + f"Expected continuousTheta=True (override), got {net.pRNN.rnn.continuousTheta}" + ) + + def test_cell_override(self, run_trainnet): + """Test --cell overrides partial preset""" + # thRNN_0win has LayerNormRNNCell by default + net = run_trainnet("thRNN_0win", cell="RNNCell") + assert type(net.pRNN.rnn.cell) == RNNCell, ( + f"Expected RNNCell (override), got {type(net.pRNN.rnn.cell).__name__}" + ) + + +# ============================================================================ +# TEST DIVNORM PARAMETERS +# ============================================================================ + + +# class TestDivNormViaCLI: +# """Test that DivNorm parameters flow through CLI correctly""" + +# def test_divnorm_parameters_trainable(self, run_trainnet): +# """Test --train_divnorm makes parameters trainable""" +# net = run_trainnet( +# "Masked", +# cell="DivNormRNNCell", +# target_mean=0.8, +# k_div=2.0, +# sigma=1.5, +# train_divnorm=True, +# ) + +# cell = net.pRNN.rnn.cell + +# # Check values reached the cell +# # assert cell.target_mean == 0.8, f"Expected target_mean=0.8, got {cell.target_mean}" +# assert cell.divnorm.k_div.item() == 2.0, ( +# f"Expected k_div=2.0, got {cell.divnorm.k_div.item()}" +# ) +# assert cell.divnorm.sigma.item() == 1.5, ( +# f"Expected sigma=1.5, got {cell.divnorm.sigma.item()}" +# ) + +# # Check they're trainable +# import torch.nn as nn + +# assert isinstance(cell.divnorm.k_div, nn.Parameter), ( +# "k_div should be a Parameter when train_divnorm=True" +# ) +# assert isinstance(cell.divnorm.sigma, nn.Parameter), ( +# "sigma should be a Parameter when train_divnorm=True" +# ) + +# def test_divnorm_parameters_fixed(self, run_trainnet): +# """Test that without --train_divnorm, parameters are fixed""" +# net = run_trainnet( +# "Masked", +# cell="DivNormRNNCell", +# target_mean=0.7, +# k_div=1.0, +# sigma=1.0, +# # train_divnorm=False is default +# ) + +# cell = net.pRNN.rnn.cell + +# # Check they're NOT trainable (buffers, not Parameters) +# import torch.nn as nn + +# assert not isinstance(cell.divnorm.k_div, nn.Parameter), ( +# "k_div should be a buffer when train_divnorm=False" +# ) +# assert not isinstance(cell.divnorm.sigma, nn.Parameter), ( +# "sigma should be a buffer when train_divnorm=False" +# ) + + +# ============================================================================ +# TEST DIVNORM PARAMETERS +# ============================================================================ + + +# class TestDivNormViaCLI: +# """Test that DivNorm parameters flow through CLI correctly""" + +# def test_divnorm_parameters_trainable(self, run_trainnet): +# """Test --train_divnorm makes parameters trainable""" +# net = run_trainnet( +# "Masked", +# cell="DivNormRNNCell", +# target_mean=0.8, +# k_div=2.0, +# sigma=1.5, +# train_divnorm=True, +# ) + +# cell = net.pRNN.rnn.cell + +# # Check values reached the cell +# # assert cell.target_mean == 0.8, f"Expected target_mean=0.8, got {cell.target_mean}" +# assert cell.divnorm.k_div.item() == 2.0, ( +# f"Expected k_div=2.0, got {cell.divnorm.k_div.item()}" +# ) +# assert cell.divnorm.sigma.item() == 1.5, ( +# f"Expected sigma=1.5, got {cell.divnorm.sigma.item()}" +# ) + +# # Check they're trainable +# import torch.nn as nn + +# assert isinstance(cell.divnorm.k_div, nn.Parameter), ( +# "k_div should be a Parameter when train_divnorm=True" +# ) +# assert isinstance(cell.divnorm.sigma, nn.Parameter), ( +# "sigma should be a Parameter when train_divnorm=True" +# ) + +# def test_divnorm_parameters_fixed(self, run_trainnet): +# """Test that without --train_divnorm, parameters are fixed""" +# net = run_trainnet( +# "Masked", +# cell="DivNormRNNCell", +# target_mean=0.7, +# k_div=1.0, +# sigma=1.0, +# # train_divnorm=False is default +# ) + +# cell = net.pRNN.rnn.cell + +# # Check they're NOT trainable (buffers, not Parameters) +# import torch.nn as nn + +# assert not isinstance(cell.divnorm.k_div, nn.Parameter), ( +# "k_div should be a buffer when train_divnorm=False" +# ) +# assert not isinstance(cell.divnorm.sigma, nn.Parameter), ( +# "sigma should be a buffer when train_divnorm=False" +# ) + +# ============================================================================ +# TEST LR OPTIMIZER CONFIGURATION +# ============================================================================ + + +class TestLRCLI: + """Test that learning rate configurations work correctly via CLI""" + + def test_eg_lr_configured(self, run_trainnet): + """Test --eg_lr configures EG optimizer with scaled weight decay""" + net = run_trainnet("lognRNN_mask", eg_lr=1e-3, eg_weight_decay=1e-6) + + # Check that EG parameter groups exist + has_eg = False + W_group = None + # W_out = None + W_in = None + for group in net.optimizer.param_groups: + if group.get("update_alg") == "eg": + if group["name"] == "RecurrentWeights": + W_group = group + # if group["name"] == "OutputWeights": + # W_out = group + if group["name"] == "InputWeights": + W_in = group + has_eg = True + assert group["lr"] == 1e-3, f"Expected eg_lr=1e-3, got {group['lr']}" + + # Check weight decay is scaled by lr + expected_wd = 1e-6 * 1e-3 + assert abs(group["weight_decay"] - expected_wd) < 1e-12, ( + f"Expected weight_decay={expected_wd}, got {group['weight_decay']}" + ) + + assert W_group is not None, "Recurring weights not properly updated to EG" + assert W_in is not None, "Input weights not properly updated to EG" + # assert W_out is not None, "Output weights not properly updated to EG" + assert has_eg, "No parameter groups configured for EG" + + def test_bias_lr_uses_gd_not_eg(self, run_trainnet): + """Test that --bias_lr configures bias learning rate and doesn't use EG""" + net = run_trainnet( + "thRNN_5win", + trainBias=True, + lr=2e-3, + weight_decay=3e-3, + bias_lr=2.0, + eg_lr=1e-3, + eg_weight_decay=1e-6, + ) + + # Find the bias parameter group + bias_group = None + for group in net.optimizer.param_groups: + if group["name"] == "biases": + bias_group = group + break + + # Check bias group exists + assert bias_group is not None, "Bias parameter group not found in optimizer" + + # Check bias_lr is applied + expected_bias_lr = 2.0 * 2e-3 + assert bias_group["lr"] == expected_bias_lr, ( + f"Expected bias lr={expected_bias_lr}, got {bias_group['lr']}" + ) + + # Check biases DON'T use EG (should use GD) + assert bias_group.get("update_alg") != "eg", ( + f"Biases should use GD, not EG. Got update_alg={bias_group.get('update_alg')}" + ) + + # Biases should have standard weight decay (not EG scaled) + # For GD: weight_decay is scaled by base_lr * bias_lr * weigth_decay + # Check it's NOT scaled by eg_lr + assert bias_group["weight_decay"] == 2.0 * 2e-3 * 3e-3, ( + "Biases should not use EG-scaled weight decay" + ) + + +# def test_divnorm_trainable_uses_eg(self, run_trainnet): +# """Test that trainable DivNorm parameters (k_div, sigma) use EG with scaled weight decay""" +# net = run_trainnet( +# "Masked", +# cell="DivNormRNNCell", +# target_mean=0.7, +# k_div=1.0, +# sigma=1.0, +# train_divnorm=True, +# eg_lr=1e-3, +# eg_weight_decay=1e-6, +# ) + +# # Find k_div and sigma parameter groups +# k_div_group = None +# sigma_group = None + +# for group in net.optimizer.param_groups: +# if group["name"] == "k_divnorm": +# k_div_group = group +# elif group["name"] == "sigma_divnorm": +# sigma_group = group + +# # Check both groups exist +# assert k_div_group is not None, "k_divnorm parameter group not found in optimizer" +# assert sigma_group is not None, "sigma_divnorm parameter group not found in optimizer" + +# # Check k_div uses EG +# assert k_div_group.get("update_alg") == "eg", ( +# f"k_div should use EG, got update_alg={k_div_group.get('update_alg')}" +# ) + +# # Check sigma uses EG +# assert sigma_group.get("update_alg") == "eg", ( +# f"sigma should use EG, got update_alg={sigma_group.get('update_alg')}" +# ) + +# # Check k_div has EG lr +# assert k_div_group["lr"] == 1e-3, f"k_div should use eg_lr=1e-3, got {k_div_group['lr']}" + +# # Check sigma has EG lr +# assert sigma_group["lr"] == 1e-3, f"sigma should use eg_lr=1e-3, got {sigma_group['lr']}" + +# # Check k_div has scaled weight decay (eg_weight_decay * eg_lr) +# expected_wd = 1e-6 * 1e-3 +# assert abs(k_div_group["weight_decay"] - expected_wd) < 1e-12, ( +# f"k_div should use scaled weight_decay={expected_wd}, got {k_div_group['weight_decay']}" +# ) + +# # Check sigma has scaled weight decay (eg_weight_decay * eg_lr) +# assert abs(sigma_group["weight_decay"] - expected_wd) < 1e-12, ( +# f"sigma should use scaled weight_decay={expected_wd}, got {sigma_group['weight_decay']}" +# ) + +# # Verify parameters are actually trainable (Parameters, not buffers) +# import torch.nn as nn + +# cell = net.pRNN.rnn.cell +# assert isinstance(cell.divnorm.k_div, nn.Parameter), ( +# "k_div should be a Parameter when train_divnorm=True" +# ) +# assert isinstance(cell.divnorm.sigma, nn.Parameter), ( +# "sigma should be a Parameter when train_divnorm=True" +# ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) From 532eac364e3d83a256eba63916ed92533807d0f4 Mon Sep 17 00:00:00 2001 From: Meghan Date: Wed, 18 Feb 2026 15:52:03 -0500 Subject: [PATCH 3/7] unsure --- tests/{test.py => test_predictivenet.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test.py => test_predictivenet.py} (100%) diff --git a/tests/test.py b/tests/test_predictivenet.py similarity index 100% rename from tests/test.py rename to tests/test_predictivenet.py From 2102704bac9cc76623108ca694fb6b532a8456e7 Mon Sep 17 00:00:00 2001 From: Meghan Date: Wed, 18 Feb 2026 16:09:07 -0500 Subject: [PATCH 4/7] got rid of unused tests in predictivenet --- tests/test_predictivenet.py | 176 ------------------------------------ 1 file changed, 176 deletions(-) diff --git a/tests/test_predictivenet.py b/tests/test_predictivenet.py index 594fdb32..a469e2a4 100644 --- a/tests/test_predictivenet.py +++ b/tests/test_predictivenet.py @@ -236,94 +236,6 @@ def test_predOffset_values(self, mock_env, base_architecture_kwargs): ) -# ============================================================================ -# TEST 2: ARGPARSE OVERRIDES WORK CORRECTLY -# ============================================================================ - - -# class TestArgparseOverrides: -# """Test that argparse arguments correctly override partial presets""" - -# def test_k_override(self, mock_env, base_architecture_kwargs): -# """Test that explicit k argument overrides partial preset""" -# # thRNN_5win has k=5 by default -# kwargs = {**base_architecture_kwargs, "k": 10} -# net = PredictiveNet(mock_env, pRNNtype="thRNN_5win", hidden_size=100, **kwargs) -# assert net.pRNN.k == 10, f"Expected k=10 (override), got k={net.pRNN.k}" - -# def test_actOffset_override(self, mock_env, base_architecture_kwargs): -# """Test that explicit actOffset overrides partial preset""" -# # thRNN_5win_prevAct has actOffset=1 by default -# kwargs = {**base_architecture_kwargs, "actOffset": 3} -# net = PredictiveNet(mock_env, pRNNtype="thRNN_5win_prevAct", hidden_size=100, **kwargs) -# assert net.pRNN.actOffset == 3, f"Expected actOffset=3 (override), got {net.pRNN.actOffset}" - -# def test_continuousTheta_override(self, mock_env, base_architecture_kwargs): -# """Test that explicit continuousTheta overrides partial preset""" -# # thcycRNN_5win_hold has continuousTheta=False by default -# kwargs = {**base_architecture_kwargs, "continuousTheta": True} -# net = PredictiveNet(mock_env, pRNNtype="thcycRNN_5win_hold", hidden_size=100, **kwargs) -# assert net.pRNN.continuousTheta == True, ( -# f"Expected continuousTheta=True (override), got {net.pRNN.continuousTheta}" -# ) - -# def test_rollout_action_override(self, mock_env, base_architecture_kwargs): -# """Test that explicit rollout_action overrides partial preset""" -# # thcycRNN_5win_hold has rollout_action="hold" by default -# kwargs = {**base_architecture_kwargs, "rollout_action": "first"} -# net = PredictiveNet(mock_env, pRNNtype="thcycRNN_5win_hold", hidden_size=100, **kwargs) -# assert net.pRNN.rollout_action == "first", ( -# f"Expected rollout_action='first' (override), got {net.pRNN.rollout_action}" -# ) - -# def test_use_FF_override(self, mock_env, base_architecture_kwargs): -# """Test that explicit use_FF overrides partial preset""" -# # AutoencoderRec has use_FF=False by default -# kwargs = {**base_architecture_kwargs, "use_FF": True} -# net = PredictiveNet(mock_env, pRNNtype="AutoencoderRec", hidden_size=100, **kwargs) -# assert net.pRNN.use_FF == True, f"Expected use_FF=True (override), got {net.pRNN.use_FF}" - -# def test_mask_actions_override(self, mock_env, base_architecture_kwargs): -# """Test that explicit mask_actions overrides partial preset""" -# # thRNN_5win has mask_actions=False by default -# kwargs = {**base_architecture_kwargs, "mask_actions": True} -# net = PredictiveNet(mock_env, pRNNtype="thRNN_5win", hidden_size=100, **kwargs) -# assert net.pRNN.mask_actions == True, ( -# f"Expected mask_actions=True (override), got {net.pRNN.mask_actions}" -# ) - - -# # ============================================================================ -# # TEST 3: NONE VALUES DON'T OVERRIDE PARTIALS -# # ============================================================================ - - -# class TestNoneDoesNotOverride: -# """Test that None values in kwargs don't override partial presets""" - -# def test_k_none_preserves_partial(self, mock_env, base_architecture_kwargs): -# """Test that k=None doesn't override partial preset""" -# kwargs = {**base_architecture_kwargs, "k": None} -# net = PredictiveNet(mock_env, pRNNtype="thRNN_5win", hidden_size=100, **kwargs) -# assert net.pRNN.k == 5, f"k=None should preserve partial preset (5), got {net.pRNN.k}" - -# def test_continuousTheta_none_preserves_partial(self, mock_env, base_architecture_kwargs): -# """Test that continuousTheta=None doesn't override partial preset""" -# kwargs = {**base_architecture_kwargs, "continuousTheta": None} -# net = PredictiveNet(mock_env, pRNNtype="thcycRNN_5win_holdc", hidden_size=100, **kwargs) -# assert net.pRNN.continuousTheta == True, ( -# f"continuousTheta=None should preserve partial preset (True), got {net.pRNN.continuousTheta}" -# ) - -# def test_init_none_preserves_partial(self, mock_env, base_architecture_kwargs): -# """Test that init=None doesn't override lognRNN partial preset""" -# kwargs = {**base_architecture_kwargs, "init": None} -# net = PredictiveNet(mock_env, pRNNtype="lognRNN_mask", hidden_size=100, **kwargs) -# # Can't directly check init, but weights should still be log-normal -# W_hh = net.pRNN.W.detach().cpu().numpy().flatten() -# assert W_hh.mean() > 0 # Basic check for log-normal - - # # ============================================================================ # # TEST 4: CELL-SPECIFIC PARAMETERS FLOW CORRECTLY # # ============================================================================ @@ -555,94 +467,6 @@ def test_predOffset_values(self, mock_env, base_architecture_kwargs): # f"{group['name']} weight_decay should be {expected_wd}" # ) - -# ============================================================================ -# TEST 6: INTEGRATION TESTS (Full Workflow) -# ============================================================================ - - -# class TestIntegration: -# """Integration tests for full workflow""" - -# def test_train_with_divnorm_trainable(self, mock_env, base_architecture_kwargs): -# """Test that training works with trainable DivNorm parameters""" -# kwargs = { -# **base_architecture_kwargs, -# "target_mean": 0.7, -# "k_div": 1.0, -# "sigma": 1.0, -# "train_divnorm": True, -# } - -# net = PredictiveNet( -# mock_env, -# pRNNtype="Masked", -# hidden_size=50, # Small for speed -# cell="DivNormRNNCell", -# **kwargs, -# ) - -# # Get initial values -# k_div_initial = net.pRNN.rnn.cell.divnorm.k_div.item() -# sigma_initial = net.pRNN.rnn.cell.divnorm.sigma.item() - -# # Run a few training steps -# from prnn.utils.agent import create_agent - -# agent = create_agent("RandomActionAgent", mock_env) - -# # Just run a couple iterations to make sure it doesn't crash -# for _ in range(2): -# obs, act, _, _ = net.collectObservationSequence(mock_env, agent, tsteps=50) -# net.trainStep(obs, act) - -# # Check that parameters changed (they should if trainable) -# k_div_final = net.pRNN.rnn.cell.divnorm.k_div.item() -# sigma_final = net.pRNN.rnn.cell.divnorm.sigma.item() - -# # At least one should have changed -# assert k_div_initial != k_div_final or sigma_initial != sigma_final, ( -# "Trainable DivNorm parameters should change during training" -# ) - -# def test_all_partial_architectures_instantiate(self, mock_env, base_architecture_kwargs): -# """Smoke test: all partial architectures can be instantiated""" -# architectures_to_test = [ -# "NextStep", -# "Masked", -# "Rollout", -# "AutoencoderFF", -# "AutoencoderRec", -# "AutoencoderPred", -# "AutoencoderFFPred", -# "AutoencoderFF_LN", -# "AutoencoderRec_LN", -# "AutoencoderPred_LN", -# "AutoencoderFFPred_LN", -# "thRNN_0win", -# "thRNN_5win", -# "thRNN_10win", -# "thRNN_0win_prevAct", -# "thRNN_5win_prevAct", -# "thRNN_1win_mask", -# "thRNN_5win_mask", -# "thcycRNN_5win", -# "thcycRNN_5win_hold", -# "thcycRNN_5win_holdc", -# "lognRNN_mask", -# "lognRNN_rollout", -# ] - -# for pRNNtype in architectures_to_test: -# try: -# net = PredictiveNet( -# mock_env, pRNNtype=pRNNtype, hidden_size=50, **base_architecture_kwargs -# ) -# assert net is not None, f"{pRNNtype} failed to instantiate" -# except Exception as e: -# pytest.fail(f"{pRNNtype} raised exception: {e}") - - # ============================================================================ # RUN TESTS # ============================================================================ From 40424456d7ef8ebf5117a0ef6eda3639232e5ea1 Mon Sep 17 00:00:00 2001 From: Meghan Date: Wed, 18 Feb 2026 16:23:24 -0500 Subject: [PATCH 5/7] all tests now pass, and all divnorm tests have been removed --- tests/test_predictivenet.py | 325 +++++++++++------------------------- tests/test_trainnet.py | 131 --------------- 2 files changed, 95 insertions(+), 361 deletions(-) diff --git a/tests/test_predictivenet.py b/tests/test_predictivenet.py index a469e2a4..3d08055b 100644 --- a/tests/test_predictivenet.py +++ b/tests/test_predictivenet.py @@ -236,236 +236,101 @@ def test_predOffset_values(self, mock_env, base_architecture_kwargs): ) -# # ============================================================================ -# # TEST 4: CELL-SPECIFIC PARAMETERS FLOW CORRECTLY -# # ============================================================================ - - -# class TestCellParameters: -# """Test that cell-specific parameters reach the cell correctly""" - -# def test_divnorm_parameters_reach_cell(self, mock_env, base_architecture_kwargs): -# """Test that DivNorm parameters are correctly passed to cell""" -# kwargs = { -# **base_architecture_kwargs, -# "target_mean": 0.8, -# "k_div": 2.0, -# "sigma": 1.5, -# "train_divnorm": False, -# } - -# net = PredictiveNet( -# mock_env, pRNNtype="Masked", hidden_size=100, cell="DivNormRNNCell", **kwargs -# ) - -# cell = net.pRNN.rnn.cell - -# # Check parameters reached the cell -# assert hasattr(cell, "divnorm"), "Cell should have divnorm module" -# assert hasattr(cell, "target_mean"), "Cell should have target_mean" -# assert cell.target_mean == 0.8, f"Expected target_mean=0.8, got {cell.target_mean}" - -# # Check k_div and sigma values -# assert cell.divnorm.k_div.item() == 2.0, ( -# f"Expected k_div=2.0, got {cell.divnorm.k_div.item()}" -# ) -# assert cell.divnorm.sigma.item() == 1.5, ( -# f"Expected sigma=1.5, got {cell.divnorm.sigma.item()}" -# ) - -# def test_divnorm_trainable_parameters(self, mock_env, base_architecture_kwargs): -# """Test that train_divnorm correctly makes parameters trainable""" -# # Test with train_divnorm=True -# kwargs_trainable = { -# **base_architecture_kwargs, -# "target_mean": 0.7, -# "k_div": 1.0, -# "sigma": 1.0, -# "train_divnorm": True, -# } - -# net = PredictiveNet( -# mock_env, pRNNtype="Masked", hidden_size=100, cell="DivNormRNNCell", **kwargs_trainable -# ) - -# cell = net.pRNN.rnn.cell -# assert isinstance(cell.divnorm.k_div, torch.nn.Parameter), ( -# "k_div should be a Parameter when train_divnorm=True" -# ) -# assert isinstance(cell.divnorm.sigma, torch.nn.Parameter), ( -# "sigma should be a Parameter when train_divnorm=True" -# ) - -# # Test with train_divnorm=False -# kwargs_fixed = { -# **base_architecture_kwargs, -# "target_mean": 0.7, -# "k_div": 1.0, -# "sigma": 1.0, -# "train_divnorm": False, -# } - -# net_fixed = PredictiveNet( -# mock_env, pRNNtype="Masked", hidden_size=100, cell="DivNormRNNCell", **kwargs_fixed -# ) - -# cell_fixed = net_fixed.pRNN.rnn.cell -# assert not isinstance(cell_fixed.divnorm.k_div, torch.nn.Parameter), ( -# "k_div should be a buffer when train_divnorm=False" -# ) -# assert not isinstance(cell_fixed.divnorm.sigma, torch.nn.Parameter), ( -# "sigma should be a buffer when train_divnorm=False" -# ) - -# def test_cell_override_works(self, mock_env, base_architecture_kwargs): -# """Test that cell override argument works""" -# # Test with LayerNormRNNCell -# net_ln = PredictiveNet( -# mock_env, -# pRNNtype="Masked", -# hidden_size=100, -# cell="LayerNormRNNCell", -# **base_architecture_kwargs, -# ) - -# assert net_ln.pRNN.rnn.cell.__class__.__name__ == "LayerNormRNNCell", ( -# f"Expected LayerNormRNNCell, got {net_ln.pRNN.rnn.cell.__class__.__name__}" -# ) - -# # Test with DivNormRNNCell -# kwargs_divnorm = { -# **base_architecture_kwargs, -# "target_mean": 0.7, -# "train_divnorm": False, -# } -# net_dn = PredictiveNet( -# mock_env, pRNNtype="Masked", hidden_size=100, cell="DivNormRNNCell", **kwargs_divnorm -# ) - -# assert net_dn.pRNN.rnn.cell.__class__.__name__ == "DivNormRNNCell", ( -# f"Expected DivNormRNNCell, got {net_dn.pRNN.rnn.cell.__class__.__name__}" -# ) - -# def test_sparsity_reaches_layernorm_cell(self, mock_env, base_architecture_kwargs): -# """Test that sparsity parameter reaches LayerNormRNNCell""" -# kwargs = { -# **base_architecture_kwargs, -# "sparsity": 0.3, -# } - -# net = PredictiveNet( -# mock_env, -# pRNNtype="lognRNN_mask", # Uses LayerNormRNNCell -# hidden_size=100, -# **kwargs, -# ) - -# cell = net.pRNN.rnn.cell -# assert hasattr(cell, "f"), "LayerNormRNNCell should have sparsity (f)" -# # sparsity should equal cell.f -# assert cell.f == 0.3, f"Expected f=0.3, got {cell.f}" - - -# # ============================================================================ -# # TEST 5: OPTIMIZER PARAMETER GROUPS -# # ============================================================================ - - -# class TestOptimizerParameterGroups: -# """Test that optimizer parameter groups are correctly configured""" - -# def test_divnorm_in_optimizer_when_trainable(self, mock_env, base_architecture_kwargs): -# """Test that k_div and sigma are in optimizer when train_divnorm=True""" -# kwargs = { -# **base_architecture_kwargs, -# "target_mean": 0.7, -# "train_divnorm": True, -# } - -# net = PredictiveNet( -# mock_env, -# pRNNtype="Masked", -# hidden_size=100, -# cell="DivNormRNNCell", -# trainBias=False, -# **kwargs, -# ) - -# # Check optimizer has k_div and sigma groups -# param_group_names = [g["name"] for g in net.optimizer.param_groups] -# assert "k_divnorm" in param_group_names, "k_divnorm should be in optimizer parameter groups" -# assert "sigma_divnorm" in param_group_names, ( -# "sigma_divnorm should be in optimizer parameter groups" -# ) - -# def test_divnorm_not_in_optimizer_when_fixed(self, mock_env, base_architecture_kwargs): -# """Test that k_div and sigma are NOT in optimizer when train_divnorm=False""" -# kwargs = { -# **base_architecture_kwargs, -# "target_mean": 0.7, -# "train_divnorm": False, -# } - -# net = PredictiveNet( -# mock_env, -# pRNNtype="Masked", -# hidden_size=100, -# cell="DivNormRNNCell", -# trainBias=False, -# **kwargs, -# ) - -# # Check optimizer does NOT have k_div and sigma groups -# param_group_names = [g["name"] for g in net.optimizer.param_groups] -# assert "k_divnorm" not in param_group_names, ( -# "k_divnorm should NOT be in optimizer when train_divnorm=False" -# ) -# assert "sigma_divnorm" not in param_group_names, ( -# "sigma_divnorm should NOT be in optimizer when train_divnorm=False" -# ) - -# def test_bias_in_optimizer_when_trainable(self, mock_env, base_architecture_kwargs): -# """Test that bias is in optimizer when trainBias=True""" -# net = PredictiveNet( -# mock_env, pRNNtype="Masked", hidden_size=100, trainBias=True, **base_architecture_kwargs -# ) - -# param_group_names = [g["name"] for g in net.optimizer.param_groups] -# assert "biases" in param_group_names, ( -# "biases should be in optimizer parameter groups when trainBias=True" -# ) - -# def test_eg_parameter_groups_configured(self, mock_env, base_architecture_kwargs): -# """Test that EG is correctly configured for parameter groups""" -# kwargs = { -# **base_architecture_kwargs, -# "target_mean": 0.7, -# "train_divnorm": True, -# } - -# net = PredictiveNet( -# mock_env, -# pRNNtype="Masked", -# hidden_size=100, -# cell="DivNormRNNCell", -# eg_lr=1e-3, -# eg_weight_decay=1e-6, -# **kwargs, -# ) - -# # Check that positive parameter groups have update_alg="eg" -# for group in net.optimizer.param_groups: -# if group["name"] in ["RecurrentWeights", "k_divnorm", "sigma_divnorm"]: -# # These should have EG (assuming they're all positive) -# assert group.get("update_alg") == "eg", ( -# f"{group['name']} should use EG update algorithm" -# ) -# assert group["lr"] == 1e-3, f"{group['name']} should have eg_lr=1e-3" -# # Check weight decay scaling -# expected_wd = 1e-6 * 1e-3 # eg_weight_decay * eg_lr -# assert abs(group["weight_decay"] - expected_wd) < 1e-12, ( -# f"{group['name']} weight_decay should be {expected_wd}" -# ) +# ============================================================================ +# TEST 4: CELL-SPECIFIC PARAMETERS FLOW CORRECTLY +# ============================================================================ + + +class TestCellParameters: + """Test that cell-specific parameters reach the cell correctly""" + + def test_cell_override_works(self, mock_env, base_architecture_kwargs): + """Test that cell override argument works""" + # Test with LayerNormRNNCell + net_ln = PredictiveNet( + mock_env, + pRNNtype="Masked", + hidden_size=100, + cell=LayerNormRNNCell, + **base_architecture_kwargs, + ) + + assert net_ln.pRNN.rnn.cell.__class__.__name__ == "LayerNormRNNCell", ( + f"Expected LayerNormRNNCell, got {net_ln.pRNN.rnn.cell.__class__.__name__}" + ) + + def test_sparsity_reaches_layernorm_cell(self, mock_env, base_architecture_kwargs): + """Test that sparsity parameter reaches LayerNormRNNCell""" + kwargs = { + **base_architecture_kwargs, + "sparsity": 0.3, + } + + net = PredictiveNet( + mock_env, + pRNNtype="lognRNN_mask", # Uses LayerNormRNNCell + hidden_size=100, + **kwargs, + ) + + sparsity_w = (net.pRNN.W == 0).float().mean().item() * 100 + expected_sparsity = 70.0 + tolerance = 5.0 + assert abs(sparsity_w - expected_sparsity) < tolerance, ( + f"Expected ~{expected_sparsity}% zeros (sparsity=0.05), but got {sparsity_w:.2f}% zeros for recurrent weights" + ) + + +# ============================================================================ +# TEST 5: OPTIMIZER PARAMETER GROUPS +# ============================================================================ + + +class TestOptimizerParameterGroups: + """Test that optimizer parameter groups are correctly configured""" + + def test_bias_in_optimizer_when_trainable(self, mock_env, base_architecture_kwargs): + """Test that bias is in optimizer when trainBias=True""" + net = PredictiveNet( + mock_env, pRNNtype="Masked", hidden_size=100, trainBias=True, **base_architecture_kwargs + ) + + param_group_names = [g["name"] for g in net.optimizer.param_groups] + assert "biases" in param_group_names, ( + "biases should be in optimizer parameter groups when trainBias=True" + ) + + def test_eg_parameter_groups_configured(self, mock_env, base_architecture_kwargs): + """Test that EG is correctly configured for parameter groups""" + kwargs = { + **base_architecture_kwargs, + } + + net = PredictiveNet( + mock_env, + pRNNtype="Masked", + hidden_size=100, + cell=LayerNormRNNCell, + init="log_normal", + eg_lr=1e-3, + eg_weight_decay=1e-6, + **kwargs, + ) + + # Check that positive parameter groups have update_alg="eg" + for group in net.optimizer.param_groups: + if group["name"] in ["RecurrentWeights", "InputWeights"]: + # These should have EG (assuming they're all positive) + assert group.get("update_alg") == "eg", ( + f"{group['name']} should use EG update algorithm" + ) + assert group["lr"] == 1e-3, f"{group['name']} should have eg_lr=1e-3" + # Check weight decay scaling + expected_wd = 1e-6 * 1e-3 # eg_weight_decay * eg_lr + assert abs(group["weight_decay"] - expected_wd) < 1e-12, ( + f"{group['name']} weight_decay should be {expected_wd}" + ) + # ============================================================================ # RUN TESTS diff --git a/tests/test_trainnet.py b/tests/test_trainnet.py index f19acab5..fcd0a4c3 100644 --- a/tests/test_trainnet.py +++ b/tests/test_trainnet.py @@ -326,133 +326,6 @@ def test_cell_override(self, run_trainnet): ) -# ============================================================================ -# TEST DIVNORM PARAMETERS -# ============================================================================ - - -# class TestDivNormViaCLI: -# """Test that DivNorm parameters flow through CLI correctly""" - -# def test_divnorm_parameters_trainable(self, run_trainnet): -# """Test --train_divnorm makes parameters trainable""" -# net = run_trainnet( -# "Masked", -# cell="DivNormRNNCell", -# target_mean=0.8, -# k_div=2.0, -# sigma=1.5, -# train_divnorm=True, -# ) - -# cell = net.pRNN.rnn.cell - -# # Check values reached the cell -# # assert cell.target_mean == 0.8, f"Expected target_mean=0.8, got {cell.target_mean}" -# assert cell.divnorm.k_div.item() == 2.0, ( -# f"Expected k_div=2.0, got {cell.divnorm.k_div.item()}" -# ) -# assert cell.divnorm.sigma.item() == 1.5, ( -# f"Expected sigma=1.5, got {cell.divnorm.sigma.item()}" -# ) - -# # Check they're trainable -# import torch.nn as nn - -# assert isinstance(cell.divnorm.k_div, nn.Parameter), ( -# "k_div should be a Parameter when train_divnorm=True" -# ) -# assert isinstance(cell.divnorm.sigma, nn.Parameter), ( -# "sigma should be a Parameter when train_divnorm=True" -# ) - -# def test_divnorm_parameters_fixed(self, run_trainnet): -# """Test that without --train_divnorm, parameters are fixed""" -# net = run_trainnet( -# "Masked", -# cell="DivNormRNNCell", -# target_mean=0.7, -# k_div=1.0, -# sigma=1.0, -# # train_divnorm=False is default -# ) - -# cell = net.pRNN.rnn.cell - -# # Check they're NOT trainable (buffers, not Parameters) -# import torch.nn as nn - -# assert not isinstance(cell.divnorm.k_div, nn.Parameter), ( -# "k_div should be a buffer when train_divnorm=False" -# ) -# assert not isinstance(cell.divnorm.sigma, nn.Parameter), ( -# "sigma should be a buffer when train_divnorm=False" -# ) - - -# ============================================================================ -# TEST DIVNORM PARAMETERS -# ============================================================================ - - -# class TestDivNormViaCLI: -# """Test that DivNorm parameters flow through CLI correctly""" - -# def test_divnorm_parameters_trainable(self, run_trainnet): -# """Test --train_divnorm makes parameters trainable""" -# net = run_trainnet( -# "Masked", -# cell="DivNormRNNCell", -# target_mean=0.8, -# k_div=2.0, -# sigma=1.5, -# train_divnorm=True, -# ) - -# cell = net.pRNN.rnn.cell - -# # Check values reached the cell -# # assert cell.target_mean == 0.8, f"Expected target_mean=0.8, got {cell.target_mean}" -# assert cell.divnorm.k_div.item() == 2.0, ( -# f"Expected k_div=2.0, got {cell.divnorm.k_div.item()}" -# ) -# assert cell.divnorm.sigma.item() == 1.5, ( -# f"Expected sigma=1.5, got {cell.divnorm.sigma.item()}" -# ) - -# # Check they're trainable -# import torch.nn as nn - -# assert isinstance(cell.divnorm.k_div, nn.Parameter), ( -# "k_div should be a Parameter when train_divnorm=True" -# ) -# assert isinstance(cell.divnorm.sigma, nn.Parameter), ( -# "sigma should be a Parameter when train_divnorm=True" -# ) - -# def test_divnorm_parameters_fixed(self, run_trainnet): -# """Test that without --train_divnorm, parameters are fixed""" -# net = run_trainnet( -# "Masked", -# cell="DivNormRNNCell", -# target_mean=0.7, -# k_div=1.0, -# sigma=1.0, -# # train_divnorm=False is default -# ) - -# cell = net.pRNN.rnn.cell - -# # Check they're NOT trainable (buffers, not Parameters) -# import torch.nn as nn - -# assert not isinstance(cell.divnorm.k_div, nn.Parameter), ( -# "k_div should be a buffer when train_divnorm=False" -# ) -# assert not isinstance(cell.divnorm.sigma, nn.Parameter), ( -# "sigma should be a buffer when train_divnorm=False" -# ) - # ============================================================================ # TEST LR OPTIMIZER CONFIGURATION # ============================================================================ @@ -468,14 +341,11 @@ def test_eg_lr_configured(self, run_trainnet): # Check that EG parameter groups exist has_eg = False W_group = None - # W_out = None W_in = None for group in net.optimizer.param_groups: if group.get("update_alg") == "eg": if group["name"] == "RecurrentWeights": W_group = group - # if group["name"] == "OutputWeights": - # W_out = group if group["name"] == "InputWeights": W_in = group has_eg = True @@ -489,7 +359,6 @@ def test_eg_lr_configured(self, run_trainnet): assert W_group is not None, "Recurring weights not properly updated to EG" assert W_in is not None, "Input weights not properly updated to EG" - # assert W_out is not None, "Output weights not properly updated to EG" assert has_eg, "No parameter groups configured for EG" def test_bias_lr_uses_gd_not_eg(self, run_trainnet): From 5a71aefdc58f004ce532aed2d701b34e7a2156cc Mon Sep 17 00:00:00 2001 From: Meghan Date: Thu, 26 Feb 2026 12:45:59 -0500 Subject: [PATCH 6/7] fixed and tested sparsity, moved test_imports into tests folder and updated --- setup.py | 9 +++++---- test/test_imports.py | 16 ---------------- tests/test_imports.py | 14 ++++++++++++++ tests/test_predictivenet.py | 30 ++++++++++++++++++++++++++++++ tests/test_trainnet.py | 12 ++++++++++++ 5 files changed, 61 insertions(+), 20 deletions(-) delete mode 100644 test/test_imports.py create mode 100644 tests/test_imports.py diff --git a/setup.py b/setup.py index 54be33bf..e918c2a6 100644 --- a/setup.py +++ b/setup.py @@ -28,14 +28,15 @@ "importlib_metadata", "ruamel.yaml", "gymnasium==0.29.1", + "pytest==8.4.2", ] setup( author="Daniel Levenstein", - author_email='daniel.levenstein@mila.quebec', - python_requires='>=3.9', - name='prnn', - version='v0.1', + author_email="daniel.levenstein@mila.quebec", + python_requires=">=3.9", + name="prnn", + version="v0.1", packages=find_packages(), install_requires=dependencies, description="Python Library for Predictive RNNs Modeling Hippocampal Representation and Replay", diff --git a/test/test_imports.py b/test/test_imports.py deleted file mode 100644 index 959827c3..00000000 --- a/test/test_imports.py +++ /dev/null @@ -1,16 +0,0 @@ -import importlib.util - -def test_import(package_name): - try: - importlib.util.find_spec(package_name) - print(f"Import of {package_name} successful!") - except ImportError: - print(f"Failed to import {package_name}") - -if __name__ == "__main__": - # Specify the name of your package here - package_name = "prnn.utils.predictiveNet" - - # Test the import - test_import("prnn.utils.predictiveNet") - test_import("prnn.analysis.trajectoryAnalysis") diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 00000000..bb02d463 --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,14 @@ +import importlib.util +import pytest + + +@pytest.mark.parametrize( + "package_name", + [ + "prnn.utils.predictiveNet", + "prnn.analysis.trajectoryAnalysis", + ], +) +def test_import(package_name): + spec = importlib.util.find_spec(package_name) + assert spec is not None, f"Failed to import {package_name}" diff --git a/tests/test_predictivenet.py b/tests/test_predictivenet.py index 3d08055b..36968145 100644 --- a/tests/test_predictivenet.py +++ b/tests/test_predictivenet.py @@ -193,6 +193,36 @@ def test_lognRNN_init_and_sparsity(self, mock_env, base_architecture_kwargs): f"Expected ~{expected_sparsity}% zeros (sparsity=0.05), but got {sparsity:.2f}% zeros" ) + def test_default_sparsity(self, mock_env, base_architecture_kwargs): + """Test that default sparsity is 0 for non-lognRNN architectures""" + test_cases = [ + ("thRNN_0win"), + ("thRNN_5win"), + ] + for pRNNtype in test_cases: + net = PredictiveNet( + mock_env, pRNNtype=pRNNtype, hidden_size=100, **base_architecture_kwargs + ) + sparsity = (net.pRNN.W == 0).float().mean().item() * 100 + expected_sparsity = 0.0 + tolerance = 2.0 + + assert abs(sparsity - expected_sparsity) < tolerance, ( + f"Expected ~{expected_sparsity}% zeros (sparsity=1), but got {sparsity:.2f}% zeros" + ) + + for pRNNtype in test_cases: + net = PredictiveNet( + mock_env, pRNNtype=pRNNtype, hidden_size=100, **base_architecture_kwargs + ) + sparsity = (net.pRNN.W == 0).float().mean().item() * 100 + expected_sparsity = 0.0 + tolerance = 1.0 + + assert abs(sparsity - expected_sparsity) < tolerance, ( + f"{pRNNtype}: Expected ~{expected_sparsity}% zeros (default sparsity), but got {sparsity:.2f}% zeros" + ) + def test_autoencoder_use_FF(self, mock_env, base_architecture_kwargs): """Test that Autoencoder variants have correct use_FF setting""" test_cases = [ diff --git a/tests/test_trainnet.py b/tests/test_trainnet.py index fcd0a4c3..4717469e 100644 --- a/tests/test_trainnet.py +++ b/tests/test_trainnet.py @@ -289,6 +289,18 @@ def test_lognRNNs(self, run_trainnet): ) assert net.pRNN.actionTheta, f"Expected actionTheta=True, got {net.pRNN.actionTheta}" + def test_default_sparsity(self, run_trainnet): + """Test that default sparsity is 0 for non-lognRNN architectures""" + + net = run_trainnet("thRNN_3win") + sparsity = (net.pRNN.W == 0).float().mean().item() * 100 + expected_sparsity = 0.0 + tolerance = 2.0 + + assert abs(sparsity - expected_sparsity) < tolerance, ( + f"Expected ~{expected_sparsity}% zeros (sparsity=1), but got {sparsity:.2f}% zeros" + ) + # ============================================================================ # TEST CLI OVERRIDES WORK From 0d3f21d53215f8c9b2fb18aa1fde7f100d9cfedb Mon Sep 17 00:00:00 2001 From: Meghan Date: Tue, 3 Mar 2026 13:14:01 -0500 Subject: [PATCH 7/7] deleted sparsity line about f and reran tests --- prnn/utils/Architectures.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/prnn/utils/Architectures.py b/prnn/utils/Architectures.py index fa4b09eb..a5a1631b 100644 --- a/prnn/utils/Architectures.py +++ b/prnn/utils/Architectures.py @@ -127,9 +127,6 @@ def __init__( self.W_hs = self.rnn.cell.weight_hs self.neuralTimescale = neuralTimescale - self.sparsity = ( - cell_kwargs["sparsity"] if "sparsity" in cell_kwargs else f - ) # backwards compatibility with torch.no_grad(): self.W.add_(torch.eye(hidden_size).mul_(1 - 1 / self.neuralTimescale).to_sparse())