Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions odyssnet/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch.optim as optim
import time
import math
import numbers
from typing import Callable
from ..utils.data import prepare_input, to_tensor

Expand Down Expand Up @@ -194,6 +195,12 @@ def train_batch(self, input_features, target_values, thinking_steps, gradient_ac
"""
Runs a single training step on a batch.
"""
if isinstance(gradient_accumulation_steps, bool) or not isinstance(gradient_accumulation_steps, numbers.Integral):
raise ValueError("gradient_accumulation_steps must be an integer >= 1")
gradient_accumulation_steps = int(gradient_accumulation_steps)
if gradient_accumulation_steps < 1:
raise ValueError("gradient_accumulation_steps must be an integer >= 1")

self.model.train()

self._ensure_scaler()
Expand Down Expand Up @@ -386,6 +393,15 @@ def fit(self, input_features, target_values, epochs, batch_size=32, thinking_ste
input_features = to_tensor(input_features, self.device)
target_values = to_tensor(target_values, self.device)

Comment thread
theomgdev marked this conversation as resolved.
if isinstance(batch_size, bool) or not isinstance(batch_size, int):
raise TypeError("batch_size must be an integer")
if batch_size < 1:
raise ValueError("batch_size must be >= 1")
if len(input_features) != len(target_values):
raise ValueError("input_features and target_values must have the same length")
if len(input_features) == 0:
raise ValueError("input_features and target_values must be non-empty")

history = []

# Prepare Data
Expand Down
10 changes: 6 additions & 4 deletions odyssnet/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ def prepare_input(input_features, model_input_ids, num_neurons, device):
x_input = torch.zeros(batch_size, steps, num_neurons, device=device)

num_assigned = min(num_features, len(model_input_ids))
for k in range(num_assigned):
x_input[:, :, model_input_ids[k]] = input_features[:, :, k]
if num_assigned > 0:
target_ids = model_input_ids[:num_assigned]
x_input[:, :, target_ids] = input_features[:, :, :num_assigned]

return x_input, batch_size

Expand All @@ -61,8 +62,9 @@ def prepare_input(input_features, model_input_ids, num_neurons, device):
num_assigned = min(num_features, len(model_input_ids))

# Assign features to neurons
for k in range(num_assigned):
x_input[:, model_input_ids[k]] = input_features[:, k]
if num_assigned > 0:
target_ids = model_input_ids[:num_assigned]
x_input[:, target_ids] = input_features[:, :num_assigned]

return x_input, batch_size

Expand Down
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,3 @@ all = [

[tool.setuptools.packages.find]
include = ["odyssnet*"]

[tool.pytest.ini_options]
testpaths = ["tests"]
34 changes: 34 additions & 0 deletions tests/training/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,16 @@ def test_non_pulse_2d_input_full_sequence(self):
loss = t.train_batch(x, y, thinking_steps=10, full_sequence=True)
assert isinstance(loss, float)

def test_gradient_accumulation_steps_must_be_positive(self):
    """train_batch must reject gradient_accumulation_steps values below 1."""
    trainer = _trainer(_model())
    features = _batch()
    targets = _targets()
    # Both zero and a negative count are invalid accumulation step values.
    for invalid_steps in (0, -1):
        with pytest.raises(ValueError):
            trainer.train_batch(
                features,
                targets,
                thinking_steps=2,
                gradient_accumulation_steps=invalid_steps,
            )


# ===========================================================================
# predict
Expand Down Expand Up @@ -341,6 +351,30 @@ def test_fit_loss_trend_downward_on_simple_data(self):
history = t.fit(x, y, epochs=20, batch_size=n, thinking_steps=5, verbose=False)
assert history[-1] < history[0], "Loss should decrease over training"

def test_fit_empty_dataset_raises(self):
    """fit must raise ValueError when the dataset has zero samples."""
    trainer = _trainer(_model())
    # Zero-row tensors: shapes are otherwise consistent with the model.
    empty_inputs = torch.empty(0, 5)
    empty_targets = torch.empty(0, 2)
    with pytest.raises(ValueError):
        trainer.fit(
            empty_inputs,
            empty_targets,
            epochs=1,
            batch_size=4,
            thinking_steps=2,
            verbose=False,
        )

def test_fit_length_mismatch_raises(self):
    """fit must raise ValueError when inputs and targets differ in length."""
    trainer = _trainer(_model())
    # Three input rows versus two target rows — a sample-count mismatch.
    inputs = torch.randn(3, 5)
    targets = torch.randn(2, 2)
    with pytest.raises(ValueError):
        trainer.fit(
            inputs,
            targets,
            epochs=1,
            batch_size=2,
            thinking_steps=2,
            verbose=False,
        )

def test_fit_invalid_batch_size_raises(self):
    """fit must raise ValueError for a non-positive batch_size."""
    trainer = _trainer(_model())
    inputs = torch.randn(4, 5)
    targets = torch.randn(4, 2)
    # batch_size of zero is below the minimum of 1.
    with pytest.raises(ValueError):
        trainer.fit(
            inputs,
            targets,
            epochs=1,
            batch_size=0,
            thinking_steps=2,
            verbose=False,
        )


# ===========================================================================
# regenerate_synapses
Expand Down
Loading