1 | | -constexpr auto modelHeaderSuffix = "_FromONNX.hxx"; |
2 | | -constexpr auto modelDataSuffix = "_FromONNX.dat"; |
| 1 | +constexpr auto modelHeaderSuffix = "_FromONNX_unoptimized.hxx"; |
| 2 | +constexpr auto modelDataSuffix = "_FromONNX_unoptimized.dat"; |
3 | 3 | #include "test_helpers.h" |
4 | 4 |
5 | 5 | #include "input_models/references/Linear_16.ref.hxx" |
@@ -76,28 +76,46 @@ float Linear_16_wrapper_num_diff(TMVA_SOFIE_Linear_16::Session const &session, f |
76 | 76 | .c_str()); |
77 | 77 |
78 | 78 | // If you want to see the gradient code: |
79 | | - // gInterpreter->ProcessLine("static_cast<void (*)(TMVA_SOFIE_Linear_16::Session const &, float const *, float |
80 | | - // *)>(Linear_16_outer_wrapper_grad_1)"); gInterpreter->ProcessLine("Linear_16_wrapper_pullback"); |
| 79 | + // clang-format off |
| 80 | + // gInterpreter->ProcessLine("static_cast<void (*)(TMVA_SOFIE_Linear_16::Session const &, float const *, float *)>(Linear_16_outer_wrapper_grad_1)"); |
| 81 | + // gInterpreter->ProcessLine("Linear_16_wrapper_pullback"); |
81 | 82 | // gInterpreter->ProcessLine("TMVA_SOFIE_Linear_16::doInfer_reverse_forw"); |
82 | 83 | // gInterpreter->ProcessLine("TMVA_SOFIE_Linear_16::doInfer_pullback"); |
| 84 | + // clang-format on |
83 | 85 |
84 | | - auto retVal = gInterpreter->ProcessLine((R"( |
85 | | - double maxDiff = 0; |
| 86 | + gInterpreter->ProcessLine((R"( |
| 87 | + float numeric_output[1600]{}; |
86 | 88 | for (std::size_t i = 0; i < std::size(grad_output); ++i) { |
87 | | - double val = grad_output[i]; |
88 | | - double ref = Linear_16_wrapper_num_diff(session_linear_16, )" + |
89 | | - inputInterp + R"(, i); |
90 | | - if (val != ref) { |
91 | | - maxDiff = std::max(std::abs(val - ref), maxDiff); |
92 | | - } |
| 89 | + numeric_output[i] = Linear_16_wrapper_num_diff(session_linear_16, )" + |
| 90 | + inputInterp + R"(, i); |
93 | 91 | } |
94 | | - double tol = 0.0025; |
95 | | - // the "return" value |
96 | | - (maxDiff < tol); |
97 | 92 | )") |
98 | | - .c_str()); |
| 93 | + .c_str()); |
| 94 | + |
| 95 | + double tol = 0.0025; |
| 96 | + |
| 97 | + auto arr_size = static_cast<std::size_t>(gInterpreter->ProcessLine("std::size(grad_output);")); |
| 98 | + auto grad_arr = reinterpret_cast<float *>(gInterpreter->ProcessLine("grad_output;")); |
| 99 | + auto numeric_arr = reinterpret_cast<float *>(gInterpreter->ProcessLine("numeric_output;")); |
| 100 | + |
| 101 | + constexpr std::size_t kMaxPrint = 10; |
| 102 | + std::size_t mismatchCount = 0; |
99 | 103 |
100 | | - EXPECT_EQ(retVal, 1) << "The gradient from Clad and the numeric gradient didn't match within tolerance."; |
| 104 | + for (std::size_t i = 0; i < arr_size; ++i) { |
| 105 | + double diff = std::abs(grad_arr[i] - numeric_arr[i]); |
| 106 | + |
| 107 | + if (diff > tol) { |
| 108 | + if (mismatchCount < kMaxPrint) { |
| 109 | + ADD_FAILURE() << "Mismatch at index " << i << " analytic=" << grad_arr[i] << " numeric=" << numeric_arr[i] |
| 110 | + << " diff=" << diff; |
| 111 | + } |
| 112 | + ++mismatchCount; |
| 113 | + } |
| 114 | + } |
| 115 | + |
| 116 | + if (mismatchCount > kMaxPrint) { |
| 117 | + ADD_FAILURE() << "Further mismatches suppressed (total mismatches: " << mismatchCount << ")"; |
| 118 | + } |
101 | 119 |
102 | 120 | // Checking output size |
103 | 121 | EXPECT_EQ(output.size(), sizeof(Linear_16_ExpectedOutput::all_ones) / sizeof(float)); |
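
Note on the interpreter round-trip used above (a sketch, not part of the commit): TInterpreter::ProcessLine returns the value of the last evaluated expression as a Longptr_t, so the test can evaluate an array name in the interpreter and cast the result back to a pointer in compiled code. A minimal illustration of the pattern, with `values` as a hypothetical stand-in for the test's grad_output/numeric_output arrays:

    #include "TInterpreter.h"

    #include <cstddef>

    void readInterpretedArray()
    {
       // Define an array inside the interpreter.
       gInterpreter->ProcessLine("float values[4]{1.f, 2.f, 3.f, 4.f};");

       // ProcessLine returns the last expression's value as a Longptr_t;
       // the array expression decays to a pointer into interpreter-owned storage.
       auto size = static_cast<std::size_t>(gInterpreter->ProcessLine("std::size(values);"));
       auto *data = reinterpret_cast<float *>(gInterpreter->ProcessLine("values;"));

       for (std::size_t i = 0; i < size; ++i) {
          // data[i] reads the interpreter-side array from compiled code.
          (void)data[i];
       }
    }

The pointers stay valid only while the interpreter keeps the declarations alive, which appears to be the case in this test since both arrays are declared through earlier ProcessLine calls.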
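For context on the reference values: Linear_16_wrapper_num_diff (defined outside this hunk) supplies the numeric derivative that the Clad-generated gradient is checked against. A generic central-difference sketch of that idea, where every name and the step size are illustrative assumptions rather than the test's actual helper:

    #include <cstddef>
    #include <vector>

    // Central-difference approximation of df/dx[i] for a callable f that
    // maps a vector of inputs to a scalar. Purely illustrative.
    template <typename Fn>
    double centralDifference(Fn f, std::vector<float> x, std::size_t i, float eps = 1e-3f)
    {
       x[i] += eps;              // evaluate at x[i] + eps
       const double up = f(x);
       x[i] -= 2 * eps;          // evaluate at x[i] - eps
       const double down = f(x);
       return (up - down) / (2.0 * eps);
    }

A derivative obtained this way carries O(eps^2) truncation error plus floating-point noise, which is why the test compares against a tolerance (tol = 0.0025) rather than demanding exact equality.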