diff --git a/src/csrc/training.cpp b/src/csrc/training.cpp index 58469f2..f50d3dd 100644 --- a/src/csrc/training.cpp +++ b/src/csrc/training.cpp @@ -178,6 +178,8 @@ void train_multiarg(const char* name, torchfort_tensor_list_t inputs_in, torchfo inputs->reset(); labels->reset(); + if (extra_loss_args) + extra_loss_args->reset(); torchfort::nvtx::rangePop(); } diff --git a/tests/supervised/scripts/setup_tests.py b/tests/supervised/scripts/setup_tests.py index 6f2c4b8..693bca1 100644 --- a/tests/supervised/scripts/setup_tests.py +++ b/tests/supervised/scripts/setup_tests.py @@ -36,14 +36,14 @@ def __init__(self): super(Loss1, self).__init__() def forward(self, prediction, label): - return (torch.sum(prediction) + torch.sum(label)) / (2 * prediction.numel()) + return (torch.sum(prediction) + torch.sum(label)) class Loss2(torch.nn.Module): def __init__(self): super(Loss2, self).__init__() def forward(self, prediction1, prediction2, label1, label2): - return (torch.sum(prediction1) + torch.sum(prediction2) + torch.sum(label1) + torch.sum(label2)) / (4 * prediction1.numel()) + return (torch.sum(prediction1) + torch.sum(prediction2) + torch.sum(label1) + torch.sum(label2)) class Loss2Extra(torch.nn.Module): def __init__(self): @@ -51,7 +51,7 @@ def __init__(self): def forward(self, prediction1, prediction2, label1, label2, extra_args1, extra_args2): return (torch.sum(prediction1) + torch.sum(prediction2) + torch.sum(label1) + torch.sum(label2) + - torch.sum(extra_args1) + torch.sum(extra_args2)) / (6 * prediction1.numel()) + torch.sum(extra_args1) + torch.sum(extra_args2)) def main(): model1 = Net1() diff --git a/tests/supervised/test_training.cpp b/tests/supervised/test_training.cpp index 1d5bd52..709b22e 100644 --- a/tests/supervised/test_training.cpp +++ b/tests/supervised/test_training.cpp @@ -220,26 +220,31 @@ void training_test_multiarg(const std::string& model_config, int dev_model, int } // Check that external data changes reflect in tensor list + float expected_loss_val = 0.0; for (int i = 0; i < 2; ++i) { - auto tmp = generate_random(shape); + auto tmp = generate_constant(shape, 1); inputs[i].assign(tmp.begin(), tmp.end()); + labels[i].assign(tmp.begin(), tmp.end()); + expected_loss_val += 2 * 1.0f * tmp.size(); + if (use_extra_args) { + extra_args[i].assign(tmp.begin(), tmp.end()); + expected_loss_val += 1.0f * tmp.size(); + } #ifdef ENABLE_GPU if (dev_input != TORCHFORT_DEVICE_CPU) { copy_from_host_vector(input_ptrs[i], inputs[i]); + copy_from_host_vector(label_ptrs[i], labels[i]); + if (use_extra_args) { + copy_from_host_vector(extra_args_ptrs[i], extra_args[i]); + } } #endif } - CHECK_TORCHFORT(torchfort_inference_multiarg(model_name.c_str(), inputs_tl, outputs_tl, 0)); + CHECK_TORCHFORT(torchfort_train_multiarg(model_name.c_str(), inputs_tl, labels_tl, &loss_val, + (use_extra_args) ? extra_args_tl : nullptr, 0)); - for (int i = 0; i < 2; ++i) { -#ifdef ENABLE_GPU - if (dev_input != TORCHFORT_DEVICE_CPU) { - copy_to_host_vector(outputs[i], output_ptrs[i]); - } -#endif - EXPECT_EQ(inputs[i], outputs[i]); - } + EXPECT_EQ(loss_val, expected_loss_val); } for (int i = 0; i < 2; ++i) { diff --git a/tests/test_utils.h b/tests/test_utils.h index 2bb2030..b2366c5 100644 --- a/tests/test_utils.h +++ b/tests/test_utils.h @@ -45,6 +45,15 @@ template std::vector generate_random(const std::vector& return data; } +// Generate constant vector data for testing +template std::vector generate_constant(const std::vector& shape, T value) { + + int64_t num_values = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + std::vector data(num_values, value); + + return data; +} + // Generate random names to use as model keys to avoid conflicts between tests std::string generate_random_name(int length) {