Commit b7655c7

[tmva][sofie] Improve AD-friendliness of emitted code for Clad
This commit refactors SOFIE-generated inference code to enable correct and efficient reverse-mode automatic differentiation with Clad. Key changes:

* Introduce explicit primitive operations (`Copy`, `Fill`, `Relu`) in SOFIE_common.hxx and provide corresponding custom pullbacks in CladDerivator.h. This replaces previously inlined loops and allows Clad to generate efficient gradient code without relying on tapes or loop-level differentiation.
* Update Gemm code generation to emit Copy/Fill instead of manually expanding bias-initialization loops. This better exposes the intent and improves AD performance and correctness.
* Replace manual ReLU loops with a dedicated Relu() call, enabling a custom pullback that avoids tape-based condition tracking.
* Generate an additional "unoptimized" model variant in the SOFIE test suite (`OptimizationLevel::kBasic`) and use it for the AD tests. This disables memory reuse of intermediate tensors: opaque memory reuse is safe for inference but breaks source-transformation AD.
* Improve gradient-test diagnostics in the SOFIE Clad tests by reporting mismatched indices instead of only checking a global maximum difference.

With these changes, Clad-generated gradients for SOFIE models are both correct and significantly faster, reaching performance comparable to frameworks such as PyTorch and JAX on the CPU for the tested cases (fully connected neural networks with multiple layers).
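As background on the mechanism the commit relies on: Clad substitutes a hand-written reverse-mode rule for a function `f` whenever it finds a matching `f_pullback` in the `clad::custom_derivatives` namespace, mirrored to the function's own namespace (which is exactly what the CladDerivator.h hunk below does for `TMVA::Experimental::SOFIE`). A minimal sketch of that convention for a toy scalar function, assuming the pointer-style adjoint signature this commit uses; the `ops` namespace and `square` are illustrative, not part of the commit:

```cpp
#include "clad/Differentiator/Differentiator.h"

namespace ops {
inline double square(double x) { return x * x; }
} // namespace ops

namespace clad::custom_derivatives::ops {
// Hand-written reverse-mode rule: accumulate dy/dx * d_y = 2*x*d_y into the
// adjoint of x. Clad calls this instead of differentiating square's body.
inline void square_pullback(double x, double _d_y, double *_d_x)
{
   *_d_x += 2 * x * _d_y;
}
} // namespace clad::custom_derivatives::ops

double use(double x) { return ops::square(x) + x; }

int main()
{
   auto g = clad::gradient(use); // reverse-mode AD of use(x)
   double x = 3.0, dx = 0.0;
   g.execute(x, &dx);            // dx == 2*x + 1 == 7
   return 0;
}
```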
1 parent a22fe0e commit b7655c7

5 files changed

Lines changed: 139 additions & 47 deletions

math/mathcore/inc/Math/CladDerivator.h

Lines changed: 29 additions & 0 deletions
@@ -1169,6 +1169,35 @@ inline void Gemm_Call_pullback(float *output, bool transa, bool transb, int m, i
    }
 }
 
+inline void Copy_pullback(float *output, const float *input, int size, float *_d_output, float *_d_input, int *)
+{
+   for (int i = 0; i < size; i++) {
+      output[i] = input[i];
+      _d_input[i] += _d_output[i];
+      _d_output[i] = 0.F;
+   }
+}
+
+inline void Fill_pullback(float *output, float value, int size, float *_d_output, float *_d_value, int *)
+{
+   for (int i = 0; i < size; i++) {
+      output[i] = value;
+      *_d_value += _d_output[i];
+      _d_output[i] = 0.F;
+   }
+}
+
+inline void Relu_pullback(float *output, const float *input, int size, float *_d_output, float *_d_input, int *)
+{
+   for (int i = 0; i < size; i++) {
+      output[i] = input[i] > 0.F ? input[i] : 0.F;
+      float _r_d0 = _d_output[i];
+      _d_output[i] = 0.F;
+      if (input[i] > 0.F)
+         _d_input[i] += _r_d0;
+   }
+}
+
 } // namespace TMVA::Experimental::SOFIE
 
 } // namespace clad::custom_derivatives
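The pullback contract is visible in the signatures above: the original arguments come first, followed by the adjoint buffers in the same order (the unnamed trailing `int *` is the unused adjoint of `size`). Because `Relu_pullback` recomputes the branch condition from the saved `input`, no tape entry is needed. A standalone numeric check of this behavior, with the body copied from the hunk above so it compiles without ROOT:

```cpp
#include <cstdio>

// Body copied verbatim from the CladDerivator.h hunk above.
inline void Relu_pullback(float *output, const float *input, int size,
                          float *_d_output, float *_d_input, int *)
{
   for (int i = 0; i < size; i++) {
      output[i] = input[i] > 0.F ? input[i] : 0.F;
      float _r_d0 = _d_output[i];
      _d_output[i] = 0.F;
      if (input[i] > 0.F)
         _d_input[i] += _r_d0;
   }
}

int main()
{
   float in[4] = {-1.F, 2.F, -3.F, 4.F};
   float out[4] = {};
   float d_out[4] = {1.F, 1.F, 1.F, 1.F}; // incoming adjoint of the output
   float d_in[4] = {};                    // adjoint of the input, accumulated into
   Relu_pullback(out, in, 4, d_out, d_in, nullptr);
   for (int i = 0; i < 4; ++i)
      std::printf("d_in[%d] = %g\n", i, d_in[i]); // prints 0 1 0 1
   return 0;
}
```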

tmva/sofie/inc/TMVA/ROperator_Gemm.hxx

Lines changed: 16 additions & 18 deletions
@@ -397,29 +397,27 @@ namespace SOFIE{
    else
       out << "j;\n";
 
-   out << SP2 << SP << "for (size_t k = 0; k < " << sY[1] << "; k++) { \n";
-   std::string bias_index;
-   if (sC.size() != 2)
+   std::string prefix = SP2 + SP + "TMVA::Experimental::SOFIE::";
+   std::string target = "tensor_" + fNY;
+   if (sC.size() != 2) {
       throw std::runtime_error("TMVA SOFIE Gemm Op - invalid rank for bias tensor " + ConvertDimShapeToString(fDimShapeC) + ConvertDimShapeToString(sC));
-   if (sC[0].GetVal() == "1" && sC[1].GetVal() == sY[1].GetVal())
-      bias_index = "k";
-   else if (sC[1].GetVal() == "1" && sC[0].GetVal() == sY[0].GetVal())
-      bias_index = "j";
-   else if (sC[0].GetVal() == "1" && sC[1].GetVal() == "1") // scalar case
-      bias_index = "0";
-   else {
+   } if (sC[0].GetVal() == "1" && sC[1].GetVal() == sY[1].GetVal()) {
+      out << prefix << "Copy(" << target << " + y_index, tensor_" << fNC << ", " << sY[1] << ");\n";
+   } else if (sC[1].GetVal() == "1" && sC[0].GetVal() == sY[0].GetVal()) {
+      out << prefix << "Fill(" << target << " + y_index, tensor_" << fNC << "[j], " << sY[1] << ");\n";
+   } else if (sC[0].GetVal() == "1" && sC[1].GetVal() == "1") {
+      // scalar case
+      out << prefix << "Fill(" << target << " + y_index, tensor_" << fNC << "[0], " << sY[1] << ");\n";
+   } else {
       throw std::runtime_error("TMVA SOFIE Gemm Op - invalid shape for bias tensor " + ConvertDimShapeToString(fDimShapeC));
    }
 
-   out << SP2 << SP << SP << "tensor_" << fNY << "[y_index + k] = " << "tensor_" << fNC << "[" << bias_index << "];\n";
-   out << SP2 << SP << "}\n";
    out << SP2 << "}\n";
    }
 
    if (fType == "float"){
 
-   out << SP2 << "TMVA::Experimental::SOFIE::Gemm_Call("
-       << "tensor_" << fNY;
+   out << SP2 << "TMVA::Experimental::SOFIE::Gemm_Call(" << "tensor_" << fNY;
    if (doStackMul) out << " + " << opName << "_y_offset";
    out << ", "
        << (fAttrTransB ? "true, " : "false, ")
@@ -461,15 +459,15 @@ namespace SOFIE{
    // fuse with Relu
    if(fActivation == EActivationType::RELU){
       out << SP << "//--- applying RELU to output\n";
-      out << SP << "for (int id = 0; id < " << ConvertDimShapeToLength(fShapeY) << " ; id++){\n";
-      out << SP << SP << "tensor_" << fNY << "[id] = ((tensor_" << fNY << "[id] > 0 )? tensor_" << fNY << "[id] : 0);\n";
-      out << SP << "}\n";
+      std::string tnsr = "tensor_" + fNY;
+      std::string reluSize = ConvertDimShapeToLength(fShapeY);
+      out << SP << "TMVA::Experimental::SOFIE::Relu(" << tnsr << ", " << tnsr << ", " << reluSize << ");\n";
    }
 
    return out.str();
    }
 
-   std::vector<std::string> GetBlasRoutines() override { return { std::string("Gemm"), std::string("Gemv") }; }
+   std::vector<std::string> GetBlasRoutines() override { return {"Gemm", "Gemv"}; }
 
    };

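To make the generator change concrete, here is an illustrative before/after of the emitted bias initialization for a bias of shape (1, N) broadcast across the output rows. The tensor names and the `bias_init_*` wrappers are placeholders for this sketch, not the output of a real generated model:

```cpp
#include <algorithm>
#include <cstddef>

// Minimal stand-in for the primitive added to SOFIE_common.hxx in this commit.
namespace TMVA::Experimental::SOFIE {
inline void Copy(float *output, float const *input, int size) { std::copy(input, input + size, output); }
}

// Shape of the code the generator used to emit: a per-element loop that Clad
// had to differentiate element by element.
void bias_init_before(float *tensor_Y, const float *tensor_C, std::size_t y_index, int N)
{
   for (int k = 0; k < N; k++)
      tensor_Y[y_index + k] = tensor_C[k];
}

// Shape of the code emitted after this commit: a single primitive call, for
// which CladDerivator.h supplies Copy_pullback.
void bias_init_after(float *tensor_Y, const float *tensor_C, std::size_t y_index, int N)
{
   TMVA::Experimental::SOFIE::Copy(tensor_Y + y_index, tensor_C, N);
}
```

The (N, 1) and scalar bias cases go through `Fill` instead, matching the two `Fill(...)` branches in the hunk above.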
tmva/sofie/inc/TMVA/SOFIE_common.hxx

Lines changed: 17 additions & 0 deletions
@@ -772,6 +772,23 @@ inline void Gemm_Call(float *output, bool transa, bool transb, int m, int n, int
               &beta, output, ldc);
 }
 
+inline void Fill(float *output, float value, int size)
+{
+   std::fill(output, output + size, value);
+}
+
+inline void Copy(float *output, float const *input, int size)
+{
+   std::copy(input, input + size, output);
+}
+
+inline void Relu(float *output, float const *input, int size)
+{
+   for (int i = 0; i < size; i++) {
+      output[i] = (input[i] > 0.0f) ? input[i] : 0.0f;
+   }
+}
+
 template <class T>
 void ReadTensorFromStream(std::istream &is, T &target, std::string const &expectedName, std::size_t expectedLength)
 {
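One detail worth noting: the fused ReLU emitted by ROperator_Gemm.hxx passes the same buffer as both `output` and `input`, so `Relu` must tolerate full aliasing, which this element-wise form does. A quick standalone check of the three primitives (bodies copied from the hunk above):

```cpp
#include <algorithm>
#include <cstdio>

// Copied verbatim from the SOFIE_common.hxx hunk above.
inline void Fill(float *output, float value, int size)
{
   std::fill(output, output + size, value);
}

inline void Copy(float *output, float const *input, int size)
{
   std::copy(input, input + size, output);
}

inline void Relu(float *output, float const *input, int size)
{
   for (int i = 0; i < size; i++) {
      output[i] = (input[i] > 0.0f) ? input[i] : 0.0f;
   }
}

int main()
{
   float t[4];
   float src[4] = {-1.0f, 2.0f, -3.0f, 4.0f};
   Fill(t, 0.5f, 4); // t = {0.5, 0.5, 0.5, 0.5}
   Copy(t, src, 4);  // t = {-1, 2, -3, 4}
   Relu(t, t, 4);    // in-place, as in the Gemm ReLU fusion: t = {0, 2, 0, 4}
   std::printf("%g %g %g %g\n", t[0], t[1], t[2], t[3]);
   return 0;
}
```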

tmva/sofie/test/EmitFromONNX.cxx.in

Lines changed: 42 additions & 12 deletions
@@ -9,22 +9,52 @@
 #include "TMVA/RModel.hxx"
 #include "TMVA/RModelParser_ONNX.hxx"
 
-using namespace TMVA::Experimental::SOFIE;
-
-int EmitModel(std::string filename, std::string outname) {
+int EmitModel(std::string filename, std::string outname)
+{
+   using namespace TMVA::Experimental::SOFIE;
 
    std::cout << "parsing file ..." << filename << std::endl;
-   RModelParser_ONNX parser;
-   RModel model = parser.Parse(filename);
-   model.Generate();
-   model.OutputGenerated(outname+"_FromONNX.hxx");
+   {
+      // The generated code with all optimizations. Used for most SOFIE tests.
+      RModelParser_ONNX parser;
+      RModel model = parser.Parse(filename);
+      model.Generate();
+      model.OutputGenerated(outname + "_FromONNX.hxx");
+   }
+   {
+      // Generate code without memory re-use for intermediate tensors.
+      //
+      // IMPORTANT:
+      // When memory re-use is enabled, SOFIE may assign multiple intermediate
+      // tensors to the same memory buffer. This means that values produced earlier
+      // in the forward pass can be overwritten by later operations.
+      //
+      // This is safe for inference, but it breaks source-transformation automatic
+      // differentiation (e.g. with Clad). In reverse-mode AD, the backward pass
+      // needs access to the original intermediate values from the forward pass
+      // (e.g. inputs to activations like ReLU). If those values have been
+      // overwritten, the generated gradient code will read incorrect data and
+      // produce wrong results.
+      //
+      // Since Clad operates on the generated source code and is not aware of these
+      // aliasing/reuse optimizations, it cannot reconstruct or recompute the lost
+      // values. Therefore we disable memory re-use here to ensure correctness of
+      // the differentiated code.
+      //
+      // Note: this increases memory usage but is required for AD correctness.
+      RModelParser_ONNX parser;
+      RModel model = parser.Parse(filename);
+      model.SetOptimizationLevel(OptimizationLevel::kBasic);
+      model.Generate();
+      model.OutputGenerated(outname + "_FromONNX_unoptimized.hxx");
+   }
 
    return 0;
 }
 
-int main(int argc, char *argv[]){
-
-   @EMIT_CAPTURES@ ;
-
+int main(int argc, char *argv[])
+{
+   // clang-format off
+   @EMIT_CAPTURES@;
+   // clang-format on
 }
-
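A tiny illustration of the hazard the comment above describes, with hypothetical buffers rather than SOFIE-generated code: once two intermediates share one buffer, the value the ReLU pullback needs in the backward pass is already gone.

```cpp
#include <cstdio>

int main()
{
   float buf; // one slot reused for two intermediate values

   // Forward pass with memory reuse:
   buf = -1.0f;                              // intermediate A: input of a ReLU
   float relu_out = buf > 0.0f ? buf : 0.0f; // ReLU(A) = 0
   buf = 5.0f;                               // intermediate B reuses A's slot

   // Backward pass: the ReLU pullback must branch on the *original* A (-1),
   // but the reused slot now holds B (5), so the adjoint is routed wrongly.
   float d_out = 1.0f, d_in = 0.0f;
   if (buf > 0.0f) // reads B instead of A
      d_in += d_out;

   std::printf("d_in = %g (correct value: 0)\n", d_in);
   (void)relu_out;
   return 0;
}
```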
tmva/sofie/test/TestCladAutodiff.cxx

Lines changed: 35 additions & 17 deletions
@@ -1,5 +1,5 @@
-constexpr auto modelHeaderSuffix = "_FromONNX.hxx";
-constexpr auto modelDataSuffix = "_FromONNX.dat";
+constexpr auto modelHeaderSuffix = "_FromONNX_unoptimized.hxx";
+constexpr auto modelDataSuffix = "_FromONNX_unoptimized.dat";
 #include "test_helpers.h"
 
 #include "input_models/references/Linear_16.ref.hxx"
@@ -76,28 +76,46 @@ float Linear_16_wrapper_num_diff(TMVA_SOFIE_Linear_16::Session const &session, f
       .c_str());
 
    // If you want to see the gradient code:
-   // gInterpreter->ProcessLine("static_cast<void (*)(TMVA_SOFIE_Linear_16::Session const &, float const *, float
-   // *)>(Linear_16_outer_wrapper_grad_1)"); gInterpreter->ProcessLine("Linear_16_wrapper_pullback");
+   // clang-format off
+   // gInterpreter->ProcessLine("static_cast<void (*)(TMVA_SOFIE_Linear_16::Session const &, float const *, float *)>(Linear_16_outer_wrapper_grad_1)");
+   // gInterpreter->ProcessLine("Linear_16_wrapper_pullback");
    // gInterpreter->ProcessLine("TMVA_SOFIE_Linear_16::doInfer_reverse_forw");
    // gInterpreter->ProcessLine("TMVA_SOFIE_Linear_16::doInfer_pullback");
+   // clang-format on
 
-   auto retVal = gInterpreter->ProcessLine((R"(
-   double maxDiff = 0;
+   gInterpreter->ProcessLine((R"(
+   float numeric_output[1600]{};
    for (std::size_t i = 0; i < std::size(grad_output); ++i) {
-      double val = grad_output[i];
-      double ref = Linear_16_wrapper_num_diff(session_linear_16, )" +
-                                          inputInterp + R"(, i);
-      if (val != ref) {
-         maxDiff = std::max(std::abs(val - ref), maxDiff);
-      }
+      numeric_output[i] = Linear_16_wrapper_num_diff(session_linear_16, )" +
+                             inputInterp + R"(, i);
    }
-   double tol = 0.0025;
-   // the "return" value
-   (maxDiff < tol);
    )")
-                     .c_str());
+                                 .c_str());
+
+   double tol = 0.0025;
+
+   auto arr_size = static_cast<std::size_t>(gInterpreter->ProcessLine("std::size(grad_output);"));
+   auto grad_arr = reinterpret_cast<float *>(gInterpreter->ProcessLine("grad_output;"));
+   auto numeric_arr = reinterpret_cast<float *>(gInterpreter->ProcessLine("numeric_output;"));
+
+   constexpr std::size_t kMaxPrint = 10;
+   std::size_t mismatchCount = 0;
 
-   EXPECT_EQ(retVal, 1) << "The gradient from Clad and the numeric gradient didn't match within tolerance.";
+   for (std::size_t i = 0; i < arr_size; ++i) {
+      double diff = std::abs(grad_arr[i] - numeric_arr[i]);
+
+      if (diff > tol) {
+         if (mismatchCount < kMaxPrint) {
+            ADD_FAILURE() << "Mismatch at index " << i << " analytic=" << grad_arr[i] << " numeric=" << numeric_arr[i]
+                          << " diff=" << diff;
+         }
+         ++mismatchCount;
+      }
+   }
+
+   if (mismatchCount > kMaxPrint) {
+      ADD_FAILURE() << "Further mismatches suppressed (total mismatches: " << mismatchCount << ")";
+   }
 
    // Checking output size
    EXPECT_EQ(output.size(), sizeof(Linear_16_ExpectedOutput::all_ones) / sizeof(float));
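For reference, the numeric baseline the test compares against is a per-component finite difference over the input. A generic central-difference sketch of that idea follows; the helper name, `f`, and step size are illustrative, not the test's actual `Linear_16_wrapper_num_diff`:

```cpp
#include <cstdio>
#include <functional>
#include <vector>

// Estimate d f / d x[i] by central differences around the given input.
float numeric_diff(const std::function<float(const std::vector<float> &)> &f,
                   std::vector<float> x, std::size_t i, float h = 1e-3f)
{
   x[i] += h;
   const float up = f(x);
   x[i] -= 2 * h;
   const float down = f(x);
   return (up - down) / (2 * h);
}

int main()
{
   auto f = [](const std::vector<float> &v) { return v[0] * v[0] + 3 * v[1]; };
   std::printf("%g %g\n",
               numeric_diff(f, {2.0f, 5.0f}, 0),  // ~ d/dx0 = 2*x0 = 4
               numeric_diff(f, {2.0f, 5.0f}, 1)); // ~ d/dx1 = 3
   return 0;
}
```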
