diff --git a/odyssnet/training/trainer.py b/odyssnet/training/trainer.py
index c6697c7..1654ba2 100644
--- a/odyssnet/training/trainer.py
+++ b/odyssnet/training/trainer.py
@@ -3,6 +3,7 @@
 import torch.optim as optim
 import time
 import math
+import numbers
 from typing import Callable
 from ..utils.data import prepare_input, to_tensor
 
@@ -194,6 +195,12 @@ def train_batch(self, input_features, target_values, thinking_steps, gradient_ac
         """
         Runs a single training step on a batch.
         """
+        if isinstance(gradient_accumulation_steps, bool) or not isinstance(gradient_accumulation_steps, numbers.Integral):
+            raise ValueError("gradient_accumulation_steps must be an integer >= 1")
+        gradient_accumulation_steps = int(gradient_accumulation_steps)
+        if gradient_accumulation_steps < 1:
+            raise ValueError("gradient_accumulation_steps must be an integer >= 1")
+
         self.model.train()
         self._ensure_scaler()
 
@@ -386,6 +393,15 @@ def fit(self, input_features, target_values, epochs, batch_size=32, thinking_ste
         input_features = to_tensor(input_features, self.device)
         target_values = to_tensor(target_values, self.device)
 
+        if isinstance(batch_size, bool) or not isinstance(batch_size, int):
+            raise TypeError("batch_size must be an integer")
+        if batch_size < 1:
+            raise ValueError("batch_size must be >= 1")
+        if len(input_features) != len(target_values):
+            raise ValueError("input_features and target_values must have the same length")
+        if len(input_features) == 0:
+            raise ValueError("input_features and target_values must be non-empty")
+
         history = []
 
         # Prepare Data
diff --git a/odyssnet/utils/data.py b/odyssnet/utils/data.py
index f4d85b5..f7a5c5e 100644
--- a/odyssnet/utils/data.py
+++ b/odyssnet/utils/data.py
@@ -48,8 +48,9 @@ def prepare_input(input_features, model_input_ids, num_neurons, device):
         x_input = torch.zeros(batch_size, steps, num_neurons, device=device)
 
         num_assigned = min(num_features, len(model_input_ids))
-        for k in range(num_assigned):
-            x_input[:, :, model_input_ids[k]] = input_features[:, :, k]
+        if num_assigned > 0:
+            target_ids = model_input_ids[:num_assigned]
+            x_input[:, :, target_ids] = input_features[:, :, :num_assigned]
 
         return x_input, batch_size
 
@@ -61,8 +62,9 @@ def prepare_input(input_features, model_input_ids, num_neurons, device):
         num_assigned = min(num_features, len(model_input_ids))
 
         # Assign features to neurons
-        for k in range(num_assigned):
-            x_input[:, model_input_ids[k]] = input_features[:, k]
+        if num_assigned > 0:
+            target_ids = model_input_ids[:num_assigned]
+            x_input[:, target_ids] = input_features[:, :num_assigned]
 
         return x_input, batch_size
 
diff --git a/pyproject.toml b/pyproject.toml
index 92e652f..92fa947 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,6 +50,3 @@ all = [
 
 [tool.setuptools.packages.find]
 include = ["odyssnet*"]
-
-[tool.pytest.ini_options]
-testpaths = ["tests"]
diff --git a/tests/training/test_trainer.py b/tests/training/test_trainer.py
index 34ed31c..e59c7fa 100644
--- a/tests/training/test_trainer.py
+++ b/tests/training/test_trainer.py
@@ -242,6 +242,16 @@ def test_non_pulse_2d_input_full_sequence(self):
         loss = t.train_batch(x, y, thinking_steps=10, full_sequence=True)
         assert isinstance(loss, float)
 
+    def test_gradient_accumulation_steps_must_be_positive(self):
+        model = _model()
+        t = _trainer(model)
+        x = _batch()
+        y = _targets()
+        with pytest.raises(ValueError):
+            t.train_batch(x, y, thinking_steps=2, gradient_accumulation_steps=0)
+        with pytest.raises(ValueError):
+            t.train_batch(x, y, thinking_steps=2, gradient_accumulation_steps=-1)
+
 
 # ===========================================================================
 # predict
@@ -341,6 +351,30 @@ def test_fit_loss_trend_downward_on_simple_data(self):
         history = t.fit(x, y, epochs=20, batch_size=n, thinking_steps=5, verbose=False)
         assert history[-1] < history[0], "Loss should decrease over training"
 
+    def test_fit_empty_dataset_raises(self):
+        model = _model()
+        t = _trainer(model)
+        x = torch.empty(0, 5)
+        y = torch.empty(0, 2)
+        with pytest.raises(ValueError):
+            t.fit(x, y, epochs=1, batch_size=4, thinking_steps=2, verbose=False)
+
+    def test_fit_length_mismatch_raises(self):
+        model = _model()
+        t = _trainer(model)
+        x = torch.randn(3, 5)
+        y = torch.randn(2, 2)
+        with pytest.raises(ValueError):
+            t.fit(x, y, epochs=1, batch_size=2, thinking_steps=2, verbose=False)
+
+    def test_fit_invalid_batch_size_raises(self):
+        model = _model()
+        t = _trainer(model)
+        x = torch.randn(4, 5)
+        y = torch.randn(4, 2)
+        with pytest.raises(ValueError):
+            t.fit(x, y, epochs=1, batch_size=0, thinking_steps=2, verbose=False)
+
 
 # ===========================================================================
 # regenerate_synapses