From d0b495676fbf8727c8005198332539f5fd151a74 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Fri, 13 Mar 2026 10:03:59 +1300 Subject: [PATCH 1/3] [ML] Harden pytorch_inference with TorchScript model graph validation (#2936) Add a static TorchScript graph validation layer that rejects models containing operations not observed in supported transformer architectures. This reduces the attack surface by ensuring only known-safe operation sets are permitted, complementing the existing Sandbox2/seccomp defenses. (cherry picked from commit 38f66534ed1a64f3d3110fc86f7d1c30ebecdb2e) --- bin/pytorch_inference/CMakeLists.txt | 2 + bin/pytorch_inference/CModelGraphValidator.cc | 115 +++ bin/pytorch_inference/CModelGraphValidator.h | 91 +++ bin/pytorch_inference/CSupportedOperations.cc | 129 ++++ bin/pytorch_inference/CSupportedOperations.h | 68 ++ bin/pytorch_inference/Main.cc | 39 +- .../unittest/CCommandParserTest.cc | 2 +- bin/pytorch_inference/unittest/CMakeLists.txt | 3 + .../unittest/CModelGraphValidatorTest.cc | 483 +++++++++++++ .../unittest/CResultWriterTest.cc | 4 +- .../unittest/CThreadSettingsTest.cc | 2 +- .../malicious_models/malicious_conditional.pt | Bin 0 -> 2205 bytes .../malicious_models/malicious_file_reader.pt | Bin 0 -> 2141 bytes .../malicious_file_reader_in_submodule.pt | Bin 0 -> 2488 bytes .../malicious_models/malicious_heap_leak.pt | Bin 0 -> 4623 bytes .../malicious_hidden_in_submodule.pt | Bin 0 -> 2517 bytes .../malicious_many_unrecognised.pt | Bin 0 -> 2311 bytes .../malicious_mixed_file_reader.pt | Bin 0 -> 2311 bytes .../malicious_models/malicious_rop_exploit.pt | Bin 0 -> 6109 bytes .../testfiles/reference_model_ops.json | 682 ++++++++++++++++++ cmake/run-validation.cmake | 186 +++++ dev-tools/extract_model_ops/.gitignore | 1 + dev-tools/extract_model_ops/README.md | 166 +++++ .../extract_model_ops/es_it_models/README.md | 41 ++ .../supersimple_pytorch_model_it.pt | Bin 0 -> 1630 bytes .../es_it_models/tiny_text_embedding.pt | Bin 0 -> 1694 
bytes .../es_it_models/tiny_text_expansion.pt | Bin 0 -> 2078 bytes .../extract_model_ops/extract_model_ops.py | 142 ++++ .../extract_model_ops/reference_models.json | 20 + dev-tools/extract_model_ops/requirements.txt | 4 + .../extract_model_ops/torchscript_utils.py | 74 ++ .../extract_model_ops/validate_allowlist.py | 194 +++++ .../extract_model_ops/validation_models.json | 29 + dev-tools/generate_malicious_models.py | 274 +++++++ docs/CHANGELOG.asciidoc | 1 + test/CMakeLists.txt | 19 + 36 files changed, 2753 insertions(+), 18 deletions(-) create mode 100644 bin/pytorch_inference/CModelGraphValidator.cc create mode 100644 bin/pytorch_inference/CModelGraphValidator.h create mode 100644 bin/pytorch_inference/CSupportedOperations.cc create mode 100644 bin/pytorch_inference/CSupportedOperations.h create mode 100644 bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_conditional.pt create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader.pt create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader_in_submodule.pt create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_heap_leak.pt create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_hidden_in_submodule.pt create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_many_unrecognised.pt create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_mixed_file_reader.pt create mode 100644 bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_rop_exploit.pt create mode 100644 bin/pytorch_inference/unittest/testfiles/reference_model_ops.json create mode 100644 cmake/run-validation.cmake create mode 100644 dev-tools/extract_model_ops/.gitignore create mode 100644 dev-tools/extract_model_ops/README.md create mode 100644 
dev-tools/extract_model_ops/es_it_models/README.md create mode 100644 dev-tools/extract_model_ops/es_it_models/supersimple_pytorch_model_it.pt create mode 100644 dev-tools/extract_model_ops/es_it_models/tiny_text_embedding.pt create mode 100644 dev-tools/extract_model_ops/es_it_models/tiny_text_expansion.pt create mode 100644 dev-tools/extract_model_ops/extract_model_ops.py create mode 100644 dev-tools/extract_model_ops/reference_models.json create mode 100644 dev-tools/extract_model_ops/requirements.txt create mode 100644 dev-tools/extract_model_ops/torchscript_utils.py create mode 100644 dev-tools/extract_model_ops/validate_allowlist.py create mode 100644 dev-tools/extract_model_ops/validation_models.json create mode 100644 dev-tools/generate_malicious_models.py diff --git a/bin/pytorch_inference/CMakeLists.txt b/bin/pytorch_inference/CMakeLists.txt index 5c6ff6352..5e565caa0 100644 --- a/bin/pytorch_inference/CMakeLists.txt +++ b/bin/pytorch_inference/CMakeLists.txt @@ -35,7 +35,9 @@ ml_add_executable(pytorch_inference CBufferedIStreamAdapter.cc CCmdLineParser.cc CCommandParser.cc + CModelGraphValidator.cc CResultWriter.cc + CSupportedOperations.cc CThreadSettings.cc ) diff --git a/bin/pytorch_inference/CModelGraphValidator.cc b/bin/pytorch_inference/CModelGraphValidator.cc new file mode 100644 index 000000000..01658b440 --- /dev/null +++ b/bin/pytorch_inference/CModelGraphValidator.cc @@ -0,0 +1,115 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the following additional limitation. Functionality enabled by the + * files subject to the Elastic License 2.0 may only be used in production when + * invoked by an Elasticsearch process with a license key installed that permits + * use of machine learning features. You may not use this file except in + * compliance with the Elastic License 2.0 and the foregoing additional + * limitation. 
+ */ + +#include "CModelGraphValidator.h" + +#include "CSupportedOperations.h" + +#include + +#include + +#include + +namespace ml { +namespace torch { + +CModelGraphValidator::SResult CModelGraphValidator::validate(const ::torch::jit::Module& module) { + + TStringSet observedOps; + std::size_t nodeCount{0}; + collectModuleOps(module, observedOps, nodeCount); + + if (nodeCount > MAX_NODE_COUNT) { + LOG_ERROR(<< "Model graph is too large: " << nodeCount + << " nodes exceeds limit of " << MAX_NODE_COUNT); + return {false, {}, {}, nodeCount}; + } + + LOG_DEBUG(<< "Model graph contains " << observedOps.size() + << " distinct operations across " << nodeCount << " nodes"); + for (const auto& op : observedOps) { + LOG_DEBUG(<< " observed op: " << op); + } + + auto result = validate(observedOps, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + result.s_NodeCount = nodeCount; + return result; +} + +CModelGraphValidator::SResult +CModelGraphValidator::validate(const TStringSet& observedOps, + const std::unordered_set& allowedOps, + const std::unordered_set& forbiddenOps) { + + SResult result; + + // Two-pass check: forbidden ops first, then unrecognised. This lets us + // fail fast when a known-dangerous operation is present and avoids the + // cost of scanning for unrecognised ops on a model we will reject anyway. 
+ for (const auto& op : observedOps) { + if (forbiddenOps.contains(op)) { + result.s_IsValid = false; + result.s_ForbiddenOps.push_back(op); + } + } + + if (result.s_ForbiddenOps.empty()) { + for (const auto& op : observedOps) { + if (allowedOps.contains(op) == false) { + result.s_IsValid = false; + result.s_UnrecognisedOps.push_back(op); + } + } + } + + std::sort(result.s_ForbiddenOps.begin(), result.s_ForbiddenOps.end()); + std::sort(result.s_UnrecognisedOps.begin(), result.s_UnrecognisedOps.end()); + + return result; +} + +void CModelGraphValidator::collectBlockOps(const ::torch::jit::Block& block, + TStringSet& ops, + std::size_t& nodeCount) { + for (const auto* node : block.nodes()) { + if (++nodeCount > MAX_NODE_COUNT) { + return; + } + ops.emplace(node->kind().toQualString()); + for (const auto* subBlock : node->blocks()) { + collectBlockOps(*subBlock, ops, nodeCount); + if (nodeCount > MAX_NODE_COUNT) { + return; + } + } + } +} + +void CModelGraphValidator::collectModuleOps(const ::torch::jit::Module& module, + TStringSet& ops, + std::size_t& nodeCount) { + for (const auto& method : module.get_methods()) { + // Inline all method calls so that operations hidden behind + // prim::CallMethod are surfaced. After inlining, any remaining + // prim::CallMethod indicates a call that could not be resolved + // statically and will be flagged as unrecognised. + auto graph = method.graph()->copy(); + ::torch::jit::Inline(*graph); + collectBlockOps(*graph->block(), ops, nodeCount); + if (nodeCount > MAX_NODE_COUNT) { + return; + } + } +} +} +} diff --git a/bin/pytorch_inference/CModelGraphValidator.h b/bin/pytorch_inference/CModelGraphValidator.h new file mode 100644 index 000000000..2c589dab5 --- /dev/null +++ b/bin/pytorch_inference/CModelGraphValidator.h @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0 and the following additional limitation. Functionality enabled by the + * files subject to the Elastic License 2.0 may only be used in production when + * invoked by an Elasticsearch process with a license key installed that permits + * use of machine learning features. You may not use this file except in + * compliance with the Elastic License 2.0 and the foregoing additional + * limitation. + */ + +#ifndef INCLUDED_ml_torch_CModelGraphValidator_h +#define INCLUDED_ml_torch_CModelGraphValidator_h + +#include + +#include +#include +#include +#include + +namespace ml { +namespace torch { + +//! \brief +//! Validates TorchScript model computation graphs against a set of +//! allowed operations. +//! +//! DESCRIPTION:\n +//! Provides defense-in-depth by statically inspecting the TorchScript +//! graph of a loaded model and rejecting any model that contains +//! operations not present in the allowlist derived from supported +//! transformer architectures. +//! +//! IMPLEMENTATION DECISIONS:\n +//! The validation walks all methods of the module and its submodules +//! recursively, collecting every distinct operation. Any operation +//! that appears in the forbidden set causes immediate rejection. +//! Any operation not in the allowed set is collected and reported. +//! This ensures that even operations buried in helper methods or +//! nested submodules are inspected. +//! +class CModelGraphValidator { +public: + using TStringSet = std::unordered_set; + using TStringVec = std::vector; + + //! Upper bound on the number of graph nodes we are willing to inspect. + //! Transformer models typically have O(10k) nodes after inlining; a + //! limit of 1M provides generous headroom while preventing a + //! pathologically large graph from consuming unbounded memory or CPU. + static constexpr std::size_t MAX_NODE_COUNT{1000000}; + + //! Result of validating a model graph. 
+ struct SResult { + bool s_IsValid{true}; + TStringVec s_ForbiddenOps; + TStringVec s_UnrecognisedOps; + std::size_t s_NodeCount{0}; + }; + +public: + //! Validate the computation graph of the given module against the + //! supported operation allowlist. Recursively inspects all methods + //! across all submodules. + static SResult validate(const ::torch::jit::Module& module); + + //! Validate a pre-collected set of operation names. Useful for + //! unit testing the matching logic without requiring a real model. + static SResult validate(const TStringSet& observedOps, + const std::unordered_set& allowedOps, + const std::unordered_set& forbiddenOps); + +private: + //! Collect all operation names from a block, recursing into sub-blocks. + static void collectBlockOps(const ::torch::jit::Block& block, + TStringSet& ops, + std::size_t& nodeCount); + + //! Inline all method calls and collect ops from the flattened graph. + //! After inlining, prim::CallMethod should not appear; if it does, + //! the call could not be resolved statically and is treated as + //! unrecognised. + static void collectModuleOps(const ::torch::jit::Module& module, + TStringSet& ops, + std::size_t& nodeCount); +}; +} +} + +#endif // INCLUDED_ml_torch_CModelGraphValidator_h diff --git a/bin/pytorch_inference/CSupportedOperations.cc b/bin/pytorch_inference/CSupportedOperations.cc new file mode 100644 index 000000000..1776d492e --- /dev/null +++ b/bin/pytorch_inference/CSupportedOperations.cc @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the following additional limitation. Functionality enabled by the + * files subject to the Elastic License 2.0 may only be used in production when + * invoked by an Elasticsearch process with a license key installed that permits + * use of machine learning features. 
You may not use this file except in + * compliance with the Elastic License 2.0 and the foregoing additional + * limitation. + */ + +#include "CSupportedOperations.h" + +namespace ml { +namespace torch { + +using namespace std::string_view_literals; + +const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERATIONS = { + // Arbitrary memory access — enables heap scanning, address leaks, and + // ROP chain construction. + "aten::as_strided"sv, + "aten::from_file"sv, + "aten::save"sv, + // After graph inlining, method and function calls should be resolved. + // Their presence indicates an opaque call that cannot be validated. + "prim::CallFunction"sv, + "prim::CallMethod"sv, +}; + +// Generated by dev-tools/extract_model_ops/extract_model_ops.py against PyTorch 2.7.1. +// Reference models: bert-base-uncased, roberta-base, distilbert-base-uncased, +// google/electra-small-discriminator, microsoft/mpnet-base, +// microsoft/deberta-base, facebook/dpr-ctx_encoder-single-nq-base, +// google/mobilebert-uncased, xlm-roberta-base, elastic/bge-m3, +// elastic/distilbert-base-{cased,uncased}-finetuned-conll03-english, +// elastic/eis-elser-v2, elastic/elser-v2, elastic/hugging-face-elser, +// elastic/multilingual-e5-small-optimized, elastic/splade-v3, +// elastic/test-elser-v2. +// Additional ops from Elasticsearch integration test models +// (PyTorchModelIT, TextExpansionQueryIT, TextEmbeddingQueryIT). 
+const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATIONS = { + // aten operations (core tensor computations) + "aten::Int"sv, + "aten::IntImplicit"sv, + "aten::ScalarImplicit"sv, + "aten::__and__"sv, + "aten::abs"sv, + "aten::add"sv, + "aten::add_"sv, + "aten::arange"sv, + "aten::bitwise_not"sv, + "aten::cat"sv, + "aten::chunk"sv, + "aten::clamp"sv, + "aten::contiguous"sv, + "aten::cumsum"sv, + "aten::div"sv, + "aten::div_"sv, + "aten::dropout"sv, + "aten::embedding"sv, + "aten::expand"sv, + "aten::full_like"sv, + "aten::gather"sv, + "aten::ge"sv, + "aten::gelu"sv, + "aten::hash"sv, + "aten::index"sv, + "aten::index_put_"sv, + "aten::layer_norm"sv, + "aten::len"sv, + "aten::linear"sv, + "aten::log"sv, + "aten::lt"sv, + "aten::manual_seed"sv, + "aten::masked_fill"sv, + "aten::matmul"sv, + "aten::max"sv, + "aten::mean"sv, + "aten::min"sv, + "aten::mul"sv, + "aten::ne"sv, + "aten::neg"sv, + "aten::new_ones"sv, + "aten::ones"sv, + "aten::pad"sv, + "aten::permute"sv, + "aten::pow"sv, + "aten::rand"sv, + "aten::relu"sv, + "aten::repeat"sv, + "aten::reshape"sv, + "aten::rsub"sv, + "aten::scaled_dot_product_attention"sv, + "aten::select"sv, + "aten::size"sv, + "aten::slice"sv, + "aten::softmax"sv, + "aten::sqrt"sv, + "aten::squeeze"sv, + "aten::str"sv, + "aten::sub"sv, + "aten::tanh"sv, + "aten::tensor"sv, + "aten::to"sv, + "aten::transpose"sv, + "aten::type_as"sv, + "aten::unsqueeze"sv, + "aten::view"sv, + "aten::where"sv, + "aten::zeros"sv, + // prim operations (TorchScript graph infrastructure) + "prim::Constant"sv, + "prim::DictConstruct"sv, + "prim::GetAttr"sv, + "prim::If"sv, + "prim::ListConstruct"sv, + "prim::ListUnpack"sv, + "prim::Loop"sv, + "prim::NumToTensor"sv, + "prim::TupleConstruct"sv, + "prim::TupleUnpack"sv, + "prim::device"sv, + "prim::dtype"sv, + "prim::max"sv, + "prim::min"sv, +}; +} +} diff --git a/bin/pytorch_inference/CSupportedOperations.h b/bin/pytorch_inference/CSupportedOperations.h new file mode 100644 index 
000000000..3719bec80 --- /dev/null +++ b/bin/pytorch_inference/CSupportedOperations.h @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the following additional limitation. Functionality enabled by the + * files subject to the Elastic License 2.0 may only be used in production when + * invoked by an Elasticsearch process with a license key installed that permits + * use of machine learning features. You may not use this file except in + * compliance with the Elastic License 2.0 and the foregoing additional + * limitation. + */ + +#ifndef INCLUDED_ml_torch_CSupportedOperations_h +#define INCLUDED_ml_torch_CSupportedOperations_h + +#include +#include + +namespace ml { +namespace torch { + +//! \brief +//! Flat allowlist of TorchScript operations observed across all +//! supported transformer architectures (BERT, RoBERTa, DistilBERT, +//! ELECTRA, MPNet, DeBERTa, BART, DPR, MobileBERT, XLM-RoBERTa). +//! +//! DESCRIPTION:\n +//! Generated by tracing reference HuggingFace models with +//! dev-tools/extract_model_ops/extract_model_ops.py and collecting the union of all +//! operations from the inlined forward() computation graphs. +//! +//! IMPLEMENTATION DECISIONS:\n +//! Stored as a compile-time data structure rather than an external +//! config file to avoid runtime loading failures and to keep the +//! security boundary self-contained. The list should be regenerated +//! whenever the set of supported architectures changes or when +//! upgrading the PyTorch version. +//! +class CSupportedOperations { +public: + using TStringViewSet = std::unordered_set; + + //! Operations explicitly forbidden regardless of the allowlist. + //! + //! The forbidden list is checked separately from (and takes precedence + //! over) the allowed list. This two-tier approach provides: + //! + //! 1. 
Stable, targeted error messages for known-dangerous operations + //! (e.g. "model contains forbidden operation: aten::save") rather + //! than the generic "unrecognised operation" that the allowlist + //! would produce. This helps model authors diagnose rejections. + //! + //! 2. A safety net against accidental allowlist expansion. If a + //! future PyTorch upgrade or new architecture inadvertently adds + //! a dangerous op to the allowed set, the forbidden list still + //! blocks it. The forbidden check is independent of regeneration. + //! + //! 3. Defence-in-depth: two independent mechanisms must both agree + //! before an operation is permitted, reducing the risk of a + //! single-point allowlist error opening an attack vector. + static const TStringViewSet FORBIDDEN_OPERATIONS; + + //! Union of all TorchScript operations observed in supported architectures. + static const TStringViewSet ALLOWED_OPERATIONS; +}; +} +} + +#endif // INCLUDED_ml_torch_CSupportedOperations_h diff --git a/bin/pytorch_inference/Main.cc b/bin/pytorch_inference/Main.cc index 98f303df4..776ee07a8 100644 --- a/bin/pytorch_inference/Main.cc +++ b/bin/pytorch_inference/Main.cc @@ -27,6 +27,7 @@ #include "CBufferedIStreamAdapter.h" #include "CCmdLineParser.h" #include "CCommandParser.h" +#include "CModelGraphValidator.h" #include "CResultWriter.h" #include "CThreadSettings.h" @@ -42,25 +43,35 @@ #include namespace { -// Add more forbidden ops here if needed -const std::unordered_set FORBIDDEN_OPERATIONS = {"aten::from_file", "aten::save"}; - void verifySafeModel(const torch::jit::script::Module& module_) { try { - const auto method = module_.get_method("forward"); - const auto graph = method.graph(); - for (const auto& node : graph->nodes()) { - const std::string opName = node->kind().toQualString(); - if (FORBIDDEN_OPERATIONS.find(opName) != FORBIDDEN_OPERATIONS.end()) { - HANDLE_FATAL(<< "Loading the inference process failed because it contains forbidden operation: " - << opName); - } + 
auto result = ml::torch::CModelGraphValidator::validate(module_); + + if (result.s_ForbiddenOps.empty() == false) { + std::string ops = ml::core::CStringUtils::join(result.s_ForbiddenOps, ", "); + HANDLE_FATAL(<< "Model contains forbidden operations: " << ops); } + + if (result.s_UnrecognisedOps.empty() == false) { + std::string ops = ml::core::CStringUtils::join(result.s_UnrecognisedOps, ", "); + HANDLE_FATAL(<< "Model graph does not match any supported architecture. " + << "Unrecognised operations: " << ops); + } + + if (result.s_NodeCount > ml::torch::CModelGraphValidator::MAX_NODE_COUNT) { + HANDLE_FATAL(<< "Model graph is too large: " << result.s_NodeCount << " nodes exceeds limit of " + << ml::torch::CModelGraphValidator::MAX_NODE_COUNT); + } + + if (result.s_IsValid == false) { + HANDLE_FATAL(<< "Model graph validation failed"); + } + + LOG_DEBUG(<< "Model verified: " << result.s_NodeCount + << " nodes, all operations match supported architectures."); } catch (const c10::Error& e) { - LOG_FATAL(<< "Failed to get forward method: " << e.what()); + HANDLE_FATAL(<< "Model graph validation failed: " << e.what()); } - - LOG_DEBUG(<< "Model verified: no forbidden operations detected."); } } diff --git a/bin/pytorch_inference/unittest/CCommandParserTest.cc b/bin/pytorch_inference/unittest/CCommandParserTest.cc index 7dcf6a7ef..5c7e7e4fd 100644 --- a/bin/pytorch_inference/unittest/CCommandParserTest.cc +++ b/bin/pytorch_inference/unittest/CCommandParserTest.cc @@ -9,7 +9,7 @@ * limitation. 
*/ -#include "../CCommandParser.h" +#include #include diff --git a/bin/pytorch_inference/unittest/CMakeLists.txt b/bin/pytorch_inference/unittest/CMakeLists.txt index dd5394492..fe3c544a5 100644 --- a/bin/pytorch_inference/unittest/CMakeLists.txt +++ b/bin/pytorch_inference/unittest/CMakeLists.txt @@ -14,6 +14,7 @@ project("ML pytorch_inference unit tests") set (SRCS Main.cc CCommandParserTest.cc + CModelGraphValidatorTest.cc CResultWriterTest.cc CThreadSettingsTest.cc ) @@ -33,3 +34,5 @@ set(ML_LINK_LIBRARIES ) ml_add_test_executable(pytorch_inference ${SRCS}) + +target_include_directories(ml_test_pytorch_inference PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..) diff --git a/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc new file mode 100644 index 000000000..7818e88f0 --- /dev/null +++ b/bin/pytorch_inference/unittest/CModelGraphValidatorTest.cc @@ -0,0 +1,483 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the following additional limitation. Functionality enabled by the + * files subject to the Elastic License 2.0 may only be used in production when + * invoked by an Elasticsearch process with a license key installed that permits + * use of machine learning features. You may not use this file except in + * compliance with the Elastic License 2.0 and the foregoing additional + * limitation. + */ + +#include + +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include + +using namespace ml::torch; +using TStringSet = CModelGraphValidator::TStringSet; +using TStringViewSet = std::unordered_set; + +BOOST_AUTO_TEST_SUITE(CModelGraphValidatorTest) + +BOOST_AUTO_TEST_CASE(testAllAllowedOpsPass) { + // A model using only allowed ops should pass validation. 
+ TStringSet observed{"aten::linear", "aten::layer_norm", "aten::gelu", + "aten::embedding", "prim::Constant", "prim::GetAttr"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); +} + +BOOST_AUTO_TEST_CASE(testEmptyGraphPasses) { + TStringSet observed; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); +} + +BOOST_AUTO_TEST_CASE(testForbiddenOpsRejected) { + TStringSet observed{"aten::linear", "aten::from_file", "prim::Constant"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); + BOOST_REQUIRE_EQUAL("aten::from_file", result.s_ForbiddenOps[0]); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); +} + +BOOST_AUTO_TEST_CASE(testMultipleForbiddenOps) { + TStringSet observed{"aten::from_file", "aten::save"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(2, result.s_ForbiddenOps.size()); + BOOST_REQUIRE_EQUAL("aten::from_file", result.s_ForbiddenOps[0]); + BOOST_REQUIRE_EQUAL("aten::save", result.s_ForbiddenOps[1]); +} + +BOOST_AUTO_TEST_CASE(testUnrecognisedOpsRejected) { + TStringSet observed{"aten::linear", "custom::evil_op", "prim::Constant"}; + + auto result = CModelGraphValidator::validate( + observed, 
CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE_EQUAL(1, result.s_UnrecognisedOps.size()); + BOOST_REQUIRE_EQUAL("custom::evil_op", result.s_UnrecognisedOps[0]); +} + +BOOST_AUTO_TEST_CASE(testMixedForbiddenAndUnrecognised) { + // When forbidden ops are present, the validator short-circuits and + // does not report unrecognised ops — we reject immediately. + TStringSet observed{"aten::save", "custom::backdoor", "aten::linear"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); + BOOST_REQUIRE_EQUAL("aten::save", result.s_ForbiddenOps[0]); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); +} + +BOOST_AUTO_TEST_CASE(testResultsSorted) { + TStringSet observed{"zzz::unknown", "aaa::unknown", "mmm::unknown"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(3, result.s_UnrecognisedOps.size()); + BOOST_REQUIRE_EQUAL("aaa::unknown", result.s_UnrecognisedOps[0]); + BOOST_REQUIRE_EQUAL("mmm::unknown", result.s_UnrecognisedOps[1]); + BOOST_REQUIRE_EQUAL("zzz::unknown", result.s_UnrecognisedOps[2]); +} + +BOOST_AUTO_TEST_CASE(testTypicalBertOps) { + // Simulate a realistic BERT-like op set. 
+ TStringSet observed{"aten::Int", + "aten::ScalarImplicit", + "aten::__and__", + "aten::add", + "aten::arange", + "aten::contiguous", + "aten::div", + "aten::dropout", + "aten::embedding", + "aten::expand", + "aten::gelu", + "aten::ge", + "aten::index", + "aten::layer_norm", + "aten::linear", + "aten::masked_fill", + "aten::matmul", + "aten::mul", + "aten::new_ones", + "aten::permute", + "aten::reshape", + "aten::scaled_dot_product_attention", + "aten::size", + "aten::slice", + "aten::softmax", + "aten::tanh", + "aten::to", + "aten::transpose", + "aten::unsqueeze", + "aten::view", + "prim::Constant", + "prim::DictConstruct", + "prim::GetAttr", + "prim::If", + "prim::ListConstruct", + "prim::NumToTensor", + "prim::TupleConstruct"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); +} + +BOOST_AUTO_TEST_CASE(testCustomAllowlistAndForbiddenList) { + // Verify the three-argument overload works with arbitrary lists. 
+ TStringViewSet allowed{"op::a", "op::b", "op::c"}; + TStringViewSet forbidden{"op::bad"}; + TStringSet observed{"op::a", "op::b"}; + + auto result = CModelGraphValidator::validate(observed, allowed, forbidden); + BOOST_REQUIRE(result.s_IsValid); + + observed.emplace("op::bad"); + result = CModelGraphValidator::validate(observed, allowed, forbidden); + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); + + observed.erase("op::bad"); + observed.emplace("op::unknown"); + result = CModelGraphValidator::validate(observed, allowed, forbidden); + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(1, result.s_UnrecognisedOps.size()); +} + +BOOST_AUTO_TEST_CASE(testCallMethodForbiddenAfterInlining) { + // prim::CallMethod must not appear after graph inlining; its presence + // means a method call could not be resolved and the graph cannot be + // fully validated. + TStringSet observed{"aten::linear", "prim::Constant", "prim::CallMethod"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); + BOOST_REQUIRE_EQUAL("prim::CallMethod", result.s_ForbiddenOps[0]); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); +} + +BOOST_AUTO_TEST_CASE(testCallFunctionForbiddenAfterInlining) { + TStringSet observed{"aten::linear", "prim::CallFunction"}; + + auto result = CModelGraphValidator::validate( + observed, CSupportedOperations::ALLOWED_OPERATIONS, + CSupportedOperations::FORBIDDEN_OPERATIONS); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); + BOOST_REQUIRE_EQUAL("prim::CallFunction", result.s_ForbiddenOps[0]); +} + +BOOST_AUTO_TEST_CASE(testMaxNodeCountConstant) { + BOOST_REQUIRE(CModelGraphValidator::MAX_NODE_COUNT > 0); + BOOST_REQUIRE_EQUAL(std::size_t{1000000}, 
CModelGraphValidator::MAX_NODE_COUNT); +} + +BOOST_AUTO_TEST_CASE(testForbiddenOpAlsoInAllowlist) { + // If an op appears in both forbidden and allowed, forbidden takes precedence. + TStringViewSet allowed{"aten::from_file", "aten::linear"}; + TStringViewSet forbidden{"aten::from_file"}; + TStringSet observed{"aten::from_file", "aten::linear"}; + + auto result = CModelGraphValidator::validate(observed, allowed, forbidden); + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE_EQUAL(1, result.s_ForbiddenOps.size()); + BOOST_REQUIRE_EQUAL("aten::from_file", result.s_ForbiddenOps[0]); +} + +// --- Integration tests using real TorchScript modules --- + +BOOST_AUTO_TEST_CASE(testValidModuleWithAllowedOps) { + // A simple module using only aten::add and aten::mul, both of which + // are in the allowed set. + ::torch::jit::Module m("__torch__.ValidModel"); + m.define(R"( + def forward(self, x: Tensor) -> Tensor: + return x + x * x + )"); + + auto result = CModelGraphValidator::validate(m); + + BOOST_REQUIRE(result.s_IsValid); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); + BOOST_REQUIRE(result.s_NodeCount > 0); +} + +BOOST_AUTO_TEST_CASE(testModuleWithUnrecognisedOps) { + // torch.sin is not in the transformer allowlist. 
+ ::torch::jit::Module m("__torch__.UnknownOps"); + m.define(R"( + def forward(self, x: Tensor) -> Tensor: + return torch.sin(x) + )"); + + auto result = CModelGraphValidator::validate(m); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty() == false); + bool foundSin = false; + for (const auto& op : result.s_UnrecognisedOps) { + if (op == "aten::sin") { + foundSin = true; + } + } + BOOST_REQUIRE(foundSin); +} + +BOOST_AUTO_TEST_CASE(testModuleNodeCountPopulated) { + ::torch::jit::Module m("__torch__.NodeCount"); + m.define(R"( + def forward(self, x: Tensor) -> Tensor: + a = x + x + b = a * a + c = b - a + return c + )"); + + auto result = CModelGraphValidator::validate(m); + + BOOST_REQUIRE(result.s_NodeCount > 0); +} + +BOOST_AUTO_TEST_CASE(testModuleWithSubmoduleInlines) { + // Create a parent module with a child submodule. After inlining, + // the child's operations should be visible and validated. + ::torch::jit::Module child("__torch__.Child"); + child.define(R"( + def forward(self, x: Tensor) -> Tensor: + return torch.sin(x) + )"); + + ::torch::jit::Module parent("__torch__.Parent"); + parent.register_module("child", child); + parent.define(R"( + def forward(self, x: Tensor) -> Tensor: + return self.child.forward(x) + x + )"); + + auto result = CModelGraphValidator::validate(parent); + + BOOST_REQUIRE(result.s_IsValid == false); + bool foundSin = false; + for (const auto& op : result.s_UnrecognisedOps) { + if (op == "aten::sin") { + foundSin = true; + } + } + BOOST_REQUIRE(foundSin); +} + +// --- Integration tests with malicious .pt model fixtures --- +// +// These load real TorchScript models that simulate attack vectors. +// The .pt files are generated by testfiles/generate_malicious_models.py. 
+ +namespace { +bool hasForbiddenOp(const CModelGraphValidator::SResult& result, const std::string& op) { + return std::find(result.s_ForbiddenOps.begin(), result.s_ForbiddenOps.end(), + op) != result.s_ForbiddenOps.end(); +} + +bool hasUnrecognisedOp(const CModelGraphValidator::SResult& result, const std::string& op) { + return std::find(result.s_UnrecognisedOps.begin(), result.s_UnrecognisedOps.end(), + op) != result.s_UnrecognisedOps.end(); +} +} + +BOOST_AUTO_TEST_CASE(testMaliciousFileReader) { + // A model that uses aten::from_file to read arbitrary files. + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_file_reader.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(hasForbiddenOp(result, "aten::from_file")); +} + +BOOST_AUTO_TEST_CASE(testMaliciousMixedFileReader) { + // A model that mixes allowed ops (aten::add) with a forbidden + // aten::from_file. The entire model must be rejected. + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_mixed_file_reader.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(hasForbiddenOp(result, "aten::from_file")); + BOOST_REQUIRE(result.s_UnrecognisedOps.empty()); +} + +BOOST_AUTO_TEST_CASE(testMaliciousHiddenInSubmodule) { + // Unrecognised ops buried three levels deep in nested submodules. + // The validator must inline through all submodules to find them. + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_hidden_in_submodule.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin")); +} + +BOOST_AUTO_TEST_CASE(testMaliciousConditionalBranch) { + // An unrecognised op hidden inside a conditional branch. 
The + // validator must recurse into prim::If blocks to detect it. + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_conditional.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin")); +} + +BOOST_AUTO_TEST_CASE(testMaliciousManyUnrecognisedOps) { + // A model using many different unrecognised ops (sin, cos, tan, exp). + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_many_unrecognised.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(result.s_ForbiddenOps.empty()); + BOOST_REQUIRE(result.s_UnrecognisedOps.size() >= 4); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::sin")); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::cos")); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::tan")); + BOOST_REQUIRE(hasUnrecognisedOp(result, "aten::exp")); +} + +BOOST_AUTO_TEST_CASE(testMaliciousFileReaderInSubmodule) { + // The forbidden aten::from_file is hidden inside a submodule. + // After inlining, the validator must still detect it. + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_file_reader_in_submodule.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(hasForbiddenOp(result, "aten::from_file")); +} + +// --- Sandbox2 attack models --- +// +// These reproduce real-world attack vectors that exploit torch.as_strided +// to read out-of-bounds heap memory, leak libtorch addresses, and build +// ROP chains that call mprotect + shellcode to write arbitrary files. +// The graph validator must reject them because aten::as_strided is in +// the forbidden operations list. 
+ +BOOST_AUTO_TEST_CASE(testMaliciousHeapLeak) { + // A model that uses torch.as_strided with a malicious storage offset + // to scan the heap for libtorch pointers and leak their addresses + // via an assertion message. + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_heap_leak.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(hasForbiddenOp(result, "aten::as_strided")); +} + +BOOST_AUTO_TEST_CASE(testMaliciousRopExploit) { + // A model that extends the heap-leak technique to overwrite function + // pointers and build a ROP chain: mprotect a heap page as executable, + // then jump to shellcode that writes files to disk. + auto module = ::torch::jit::load("testfiles/malicious_models/malicious_rop_exploit.pt"); + auto result = CModelGraphValidator::validate(module); + + BOOST_REQUIRE(result.s_IsValid == false); + BOOST_REQUIRE(hasForbiddenOp(result, "aten::as_strided")); +} + +// --- Allowlist drift detection --- +// +// Validates that ALLOWED_OPERATIONS covers every operation observed in +// the reference HuggingFace models. The golden file is generated by +// dev-tools/extract_model_ops/extract_model_ops.py --golden and should +// be regenerated whenever PyTorch is upgraded or the set of supported +// architectures changes. 
+ +BOOST_AUTO_TEST_CASE(testAllowlistCoversReferenceModels) { + std::ifstream file("testfiles/reference_model_ops.json"); + BOOST_REQUIRE_MESSAGE(file.is_open(), + "Could not open testfiles/reference_model_ops.json — " + "regenerate with: python3 dev-tools/extract_model_ops/" + "extract_model_ops.py --golden " + "bin/pytorch_inference/unittest/testfiles/reference_model_ops.json"); + + std::ostringstream buf; + buf << file.rdbuf(); + auto root = boost::json::parse(buf.str()).as_object(); + + auto& models = root.at("models").as_object(); + BOOST_REQUIRE_MESSAGE(models.size() > 0, "Golden file contains no models"); + + const auto& allowed = CSupportedOperations::ALLOWED_OPERATIONS; + const auto& forbidden = CSupportedOperations::FORBIDDEN_OPERATIONS; + + for (const auto & [ arch, entry ] : models) { + const auto& info = entry.as_object(); + const auto& ops = info.at("ops").as_array(); + std::string modelId{info.at("model_id").as_string()}; + + for (const auto& opVal : ops) { + std::string op{opVal.as_string()}; + + BOOST_CHECK_MESSAGE(forbidden.count(op) == 0, + arch << " (" << modelId << "): op " << op << " is in FORBIDDEN_OPERATIONS — a legitimate model " + << "should not use forbidden ops"); + + BOOST_CHECK_MESSAGE(allowed.count(op) == 1, + arch << " (" << modelId << "): op " << op << " is not in ALLOWED_OPERATIONS — update the allowlist " + << "or check if this op was introduced by a PyTorch upgrade"); + } + } +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/bin/pytorch_inference/unittest/CResultWriterTest.cc b/bin/pytorch_inference/unittest/CResultWriterTest.cc index 97b99038a..7803bbc39 100644 --- a/bin/pytorch_inference/unittest/CResultWriterTest.cc +++ b/bin/pytorch_inference/unittest/CResultWriterTest.cc @@ -9,9 +9,9 @@ * limitation. 
*/ -#include "../CResultWriter.h" +#include -#include "../CThreadSettings.h" +#include #include #include diff --git a/bin/pytorch_inference/unittest/CThreadSettingsTest.cc b/bin/pytorch_inference/unittest/CThreadSettingsTest.cc index 8ab8d03d2..759affb02 100644 --- a/bin/pytorch_inference/unittest/CThreadSettingsTest.cc +++ b/bin/pytorch_inference/unittest/CThreadSettingsTest.cc @@ -9,7 +9,7 @@ * limitation. */ -#include "../CThreadSettings.h" +#include #include diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_conditional.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_conditional.pt new file mode 100644 index 0000000000000000000000000000000000000000..114707e6a7fab8d3ab35ec81472020aba354cdc2 GIT binary patch literal 2205 zcmWIWW@cev;NW1u0CEg047rIpnaP>?rN!~d`FSasC7Jnoi8=Zyi6x181=%@nP7DkU zOv&-_CHY0k8S(L4&Im=mFr8e544RF#8WA8HN{SLQ^D^_&3mH2j#DM(x%;Na8(wv<5 zq{QUx^2DN)_>BDg>_R5L(xgIWy#Q}^juZRlpNR&l1mOUnQ$>JI#dNDWg8&YvCg-Q5 z>cbtVS5WEZ#KRDU1akWf`Iy9@HF>?nmYWO_zYE^HTgW1o{ZQQR`Mf^UWBn^O zcw0W@mXPQym^kT1gy7p3%WhxZxIKJP^yBv9zaM4_PLGvWa(Q$#P{?4`w*>Pd9KTB6_KsvQ*fma7g@c z<@#rgVV%w^=Sbh*0}L8=4uLHV(~kjz$%_$RP#H2v6B|@|DXB@N>G9x*7Dov*l<@QI zzTzOzX8n0-+R?=?*7i%z@vRnUDUF=LYCp}0D>sdCN>tA{x|K*wf z2lfu7ob`CVMIIGL6AE>@{zMI>uYP+ak zcAi>EOM5_uRcP?#6??c$*EiR_JRMn-f2GZ7@$9+k{Q2|VzEr;TyJ%v;@tf7XUdB2m zm$lM0jB+d185aM`(f;;%TWb2Aw6{Osot|=CRclRgO6=KQ$J`qQQ)3pS&EMcQRnlks z%m`d~Z5SvC#< z?Z0=8=ghy#aG;sfCy}jK!TBXhKrjGEx)XKo!h%7`!F{s1%xQp z$94c?-+~E}0vdq1E=WBpc9v5XS8W9R>v)A)p5;R`rWhi!y;lXH_Dotc}l1adVQv8z2T| zW~LVAhUS)*Mn*=)<_4BV#>NI_7G{Q~rUoWv#>U2`20+@v5ahyYmfvqQfUX1K0B=T6 zV8Kg%o<5JfX`Nu-Wq&T1eJyIB|oHM*h5iQ5px&}+aXPo$xUA`RXD z$mv%Z#rPY{xQ$23FX#p$CrMcp14Dqx6T5q%nE>4=y^M!{1wx?#wPKnKOJ zFjidd!k#n(yxG`bCde`C!p&lbvS73`IH55x`~aE{0ziFWDQjp&4vapa7%M2>Ft7t5 INIgU?0Cu6J4*&oF literal 0 HcmV?d00001 diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader.pt 
b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader.pt new file mode 100644 index 0000000000000000000000000000000000000000..fb0b26f4691f58d71846fc580f4db24c2022b5fd GIT binary patch literal 2141 zcmWIWW@cev;NW1u0CEg047rIpnaP>?rN!}SnK`NPMX8A?sYUuJi6x181=%@nP7DkU zOv&-_CHY0k8S(L4Za_srU^Tw^DXBSJg$$aFwHgs1!%K=1GxIX@(hC_oBgBCG_)MUg zr8znANr}nX<%vZp@frE~*@aAgrAdX%dI8?-9LFQiGKT?Gf^Y!Pl_Ee_VmecuK@^88 zlYvgqhr3O$pwi7rkRb{k#P&O$K4iet_WobjgUsgISLM1dF3#q9JUE8sFoO&w2z~udd%bZG zc$@Of?hJd(J0lKprJl?8-X7rK+Na^3*drk>l(al4s(MmMIKN%kyI-&W2-Z8qbf{Sw zh`ELaE`K%W^9%KQ$>iO~FP{vXGf(!&@0Xr7widU%AFq9TSoZ9qC9N~RO>(?_$gD`` z&?2+#Gfyfzh4LKQ9Vn6Px;2vfy1DpcjSF=~QF}@pw}nNtKE80xZOY+8S2k~%z_mEE zFYJSw@~*St?20GZW&Yp45y4%e|MPS2fdzACXB`xNQKu9=VZM;mqZz{L$`-47)`+x3 zu0EKa6>Gw~e~M`SJNwk&NYRW9LMNV2-6D4K)4CG3j`+k@waqim8L76ec|PUbD;ZJV z+0}m!znU*-@G=wu}7~-&?EVZaOGe6JGNf&N3N;<*CG3Ek!@8Z1lP$rZf~S{fM{8Jine8W|fK7#kQ_m>C!t7+P8y7@C*^rHnx?G^+A{ zq5*Ur2nTpGf&vR(awC_Pav%v5fWI7vn1v``kqZ=c6mymVdDzXuC~eUVMNYwnD2AQ~ zrfDJ#MU+hF{zp!~$|%NfX2xwiQl>yR5IIT8q8KOwOrF@?3(W-RMj@vxc@(22;WP@K zqR|aQP6RqAhKaG_au@cb8Q{&v1~Wm9Sr={=JCp^Zoxur>fnf*Gd=LQY155peDgZ_w RP>dCnZy4Bt5TqWW765rTe&zrG literal 0 HcmV?d00001 diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader_in_submodule.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_file_reader_in_submodule.pt new file mode 100644 index 0000000000000000000000000000000000000000..4d6f6328b7d69658543f879c26f1b817cbc23e66 GIT binary patch literal 2488 zcmWIWW@cev;NW1u0NM=Z47rIpnaP>?rN!}SnK`NPMX8A?sYUUbdGW=iNxAtcr8%kk zDTyVCdIi}zZcfGwQKW+grsVkelKi6NjQDsiH=rd!U^_hXf?+mr6*6cx)@nq6oLEwn zn3z3g)AV$GZ;HFnEXnU3R!`eO)tQkokK0hV9`{dAs`$83?va?AYq1*2?HlFfs~w| zlBy5)r(Qv&n-eP}wDt!0dLJ=p^ zYMDn%1^2$yJh`WTPvzc+_YTiB`xIHE`10PqGk0Gc$w~P+f&ZEqLx-@s-6@5W%7^&A z-ATH_0g@Y-+a0y=VrQGlDnG2Q?|#xra;KTqP*kBWsxnL@4OK<%+yrx;7BUC z!rx(9r@mWui{M1nB}`8>ET_gvm3sXz?bqsJ=+oWf^;*_&$F9irxo+3;%O~7<^>iy! 
zq8j(jJ(a@RZ;(Wg(-q>p93 zPL7+k{y3g?3i0TF+JVI0PUFP!Zd#K{V0y_D3X()4(6`V_{N zK2!UA^RE~Pv`u~^ma#**?_!bXQ8PXDB}}esGcsqgEm2*uQ$f2_?Z|_gAIV3zEx-3? zd-nF&(xrVm{m0Mmd4I+?WFKGCgD1z8>~&rojV$!47wg{qK0dRsqNGlzc}t1Kghb17 z&BDOs$tKOe;#XQ-)tFVe>6Sm&o39JXCf#jre_p4vxagb0$AuSkLg%lUw7oMtm20(B z=&oCxPnSGdH}TJ<6($Fxgx5+q24)_fdCY6$w34Lhy~;bE*h;QH^I)IGd+CrH#ww@h zNjKjTs5gySpjN-ZJydc@tg6N1!$FtxRs_Dt*)!?#r@xnvzVq4a#w+_(>}Ai4JQb;# zk2hB?`e|BMwf*_FyGEZPxfpkg|9PsmSULSa%dzR+dRg123&zZju=B59o*^dSknps7 z_Y=Q=YB{!z0!kb<^Br=Zn$#;8{?7QU6?NO^y4jsm zH%zUc9(kscmCUeXgTdCH?(u6*I&MiU-r~9MN0U=(;%VR8D$Av%l;dx`5d)c49v_- zEzAwgEiH|VjEv0#dHx?qD+To9CUM$lco-exs!-97oH{1O-4>qmMA8R z0rLS7;ekD&2Y9ow!5k{btP3}o9m;~y^T27Cfnf(I(*XfcA6P1k6T%0kFrXMKSeBiE KodKjCq80$kd=vZt literal 0 HcmV?d00001 diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_heap_leak.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_heap_leak.pt new file mode 100644 index 0000000000000000000000000000000000000000..3458ab76a4d5f7e16a372d7f7f9a4259e5b88c23 GIT binary patch literal 4623 zcmbVP2UJtp)(%BN7my}|4$?!1L=dC|>AfR_9!e0B&;%)>6e%hrO^S@P0TJmS5v(B_Pwg9#zY`{!bmZFMBjvyxeiQ}j zuJ@>xF|jaf9Ger9IQ_OjxTD5{bN$G9vbZqy^D=d|4+T>(+dPt&>e+dR$R|>1C3Klg zm86LjNx)BMJtxyQg6z2@R-B<+&gd_@Tbq(tzdPcQY5voGl{c1eR8FI=sR&q^DzWm2 z<;Fg`m9)(g3n|!GPTV%OJcz*Xy^!vWQ+hjmt>Nv+&7wWo%nBB3v*}a9o&1qNYm?qk zHnk?F2&qIm0oH9&wNZh?SW-jsFsoI4*n4l+Ni{poNPz{i&a#cwShHBc;*j!qMH?zv z2IXht%6Ag4VAij2gKdcQD`3$EeuwpKIXU{L+uj|RdpuVhaJyZmzB{iUg})^#Z!c@; zP>r;k%Im&cefH(W2>sa3A+15vg%aTVhc=~Py1qqSsSqEcw}ptuJDNGy#wa<`I;)?a zA0eua!aKUVSIGM5MBY5nZMrdSqr zEzOSp9DiJWI0u<_g#+VtmvonpRl1$pmk=>RcGWkUz?`vkXYUu8H#?;J!j_C>2Db|z z7#$Lij4x{e8(FJAb%bG6<+^BUyr8SLG+hp3f=BGe=i^u&JTA{R7MPv5ljJKi9+ru& z3VcCRBj zxxi+k<0@aru7w7Z007+-|EbC)0K)&S%E2zKPW}k^ABU5`$@Tu@5fGZ{g#=kLbp-}A zh9Mc2JQQy`hs=%lOw9ZSui)rPe)b~dV9RL6t<@`MQ! 
zU1VWytOU2QjcPj|Ij_4g69YpO0MEnIG^mpzR7%Ha##ykmn$U>(;D4!S?Qkc2pnhoz z^RfQ-A~%NpL?41T3iuN zot|dcQ@nrotdWgU9y-HDZ%K}%Pt)(Q*1 zc#3)8{Cyf(^-CFcsr~qL;O3j68L)@^7F8J;`h$GCFFq@U<*T`6OV{_raiUi7pQ_VD zM`$%fl6IWpwsp`eucT_MUL_#M!O7C~>2~(_drCu{**Xy5W{ zv2q3MP|B{q+SgnZKgxSm``zX;cYe1y$FIekzv$d{zNVPG)7QqVplOogyB{D@YR=4S zDv6x78d6&Xwjca_6Av+lgBtp##hbF7pVO5CfX)a2X~y!neUnbg9$**&y!hA{Iy)ie68SG_IEow*jn2?@+_ zDZj}vF+~He_+b00Io*thDQ3315+zqJrrg>R?S~)9-m2~FMWe(oT-u3|Uq#f__?TwE z4qQjU?@HcTqkW%ZD;FHD9*|F@bn0_1s<~b2Kp)gA)>-vO>sh}G5Pz{fjlA}NO#sZ! z(a6`UJTx=>aYSehT)-JcK_w%Vdaf_8H+zP&sX!buJ@rmDRWh^D#m(al$*sjE}LRH+9GLwDcsfx_!9Kr>Bih8Z0#UPP-eO5qO|EtcEQB$;w9v8svO z;v%_m3pezzMZ*cn4TVU{L*$hr5Q8#wmbj;Mj{7egSZMHSr57^V&#_oU*EI`aufjG~ zHSxGukp6J^GsWPPic%Ws)h$+^d$}^r8!NsAD_hNzc{F~{en(;XK-%|xWk4Zq78(sQzGO*;NX-HtUMPF-hUYus7jl& zsbfe!jL{kds{)BoMM2d&ep~HgTi=Y%jgNtiBD1;Uc~1UoZSfh5y>rX+v&# z)E?|2`9@l9ux|BHO*~u;xoBj1V=tG>9)`q8>hE0wo<=+5Qr-+9F6YYg#IFNzD%p&* zxLgJ@5?AEBZ(B!pYz*{J%M8xw{k8~oJI`mnh!)4PIC7hZEXrhz48v5aER;VpJ4$xr zK^1h)yfbnp)W27B`kX<)ElJJ@fp%J2ODveiScXD^Y}UTrfOZ#{oseNAwL?;rgugiXV8(8wY>2?j zxVc-jr`oz8*QQD?@Or*(oCt({2$`X1lJHcZFx0&mvJ8n)q3_okUxHF$_NZ`BUv$LPecTgSTDmit)Y2qfYTcU2FtePMIA3{2d1+RgHZcCZz|amS%e&l+ z=ccoBhV+JxE6k%nitx2rehXDIE=&c5=~J2mYMd>@9-ZMloXd%FApSko%`d6zhGhMP z>sJ^|p9(x`X7CQ5=j~HW6zu5Wo;yqXg=sX9JaF&Do-&q|ja2RxcfRBh_bYcDv)?G1 z@yW>Owo|P|T%B|SXsa`$0}ABZN1EydQR9mgoFOd7-HcNBulKYFB1Qeo*KX!yUB&a$ zZsrQc&+&>M=Bv2+m;FlwvG5;y$y8hWr58A~Q{XfJlG zKj$C<;5~5%z{j#y0^^GIAV}U|M?b=Dz&%_v)j0plhDgfG%Rv;R6rfOPX=xb+NvO1p zjHDa{A}u8ek&}~@hCrmG%_*d3dZF46I3{iVFW)w`hMh403H(&1l` z{?^Ce_S6F3q2v>0Cn5bqE&mAlZDLD&2N_35w0{NpcWeIRmw%hNJl}y^(fkbjn@af+ z?%Pyk{|@fePvHJc{vUz9%_olUfJ|sl=Ji(|{|NQ%j~Dq4YV0RaC#R#4E*Y8FpYL;$ wi2sEI{PXGm?{a^n761@*`s)JWy+4*p>A%j22sb0$Q^%D801~ca|F6FP2a29Z`2YX_ literal 0 HcmV?d00001 diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_hidden_in_submodule.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_hidden_in_submodule.pt new file mode 100644 index 
0000000000000000000000000000000000000000..39104c647ef007579fe16960f0521e914b29944c GIT binary patch literal 2517 zcmWIWW@cev;NW1u0BQ_247rIpnaP>?rN!|XnJFo$dGVQf@x`S{x%nxjIjQ<7i6x18 z1=%@nP8JMNWP%2!(!7wYh3K=vTYc(Q3&MPTO%*@NoOD|;X zj1U9z<1>rn(@Jx4;*%1Sv&$2UQsOi6^Ro+?{32L^vYB~#sYQj%NQTFIf`qsVSwJRc zFm`4DRRX=vkdvC2R>+F1&<7~ORmg@=SDIAF4#XTl%&8aP&CX%c_5JH6pk5FT0EVCl zFa$A!P@ln&KnN!1r=;q`{jXP0>E@(MT2Stta60R-fq?7xpIwn&4}LN8yxi3|P48B$ zMstp5Z--9nMAy0d*UP;-Eb=NgOz8W$$LA#1a4^dCOgSZc$NP=D{^HO#D{c1NW>kn{ z=?wJkVp4kZ@JHAcyR5yF)P#;ne&>_@a_@Fb3e&3$-ET=2^&FEbKK+~X=lSWBJ;L8+ za2KyH^b0f+Q1?H~x$&s;v`MN{+1_ey-@IVWzt}J86_ceHp7F&^<(_A8=dS+!br#n} zwwY{C-jb7-@XU4n@9z`Dx1`GB~?}xhq7p~4Qf1Wz^LUv5Q=gkG7 zoAVU47K;mQAI% zT!o8~fkBBGU%Xf{@RJ%ZdMT+%rRnkD#LkTzL#2Jb`Ij98+OC%c9bo12@Q`tc(mT<3 ztyNl|8F{e)H$!bbk&- zdAswI7Ms}H``SN?zkKuM^2s)~R(93vJ*Ip$3XI!!UHz7OtLfV4zq>DlrUYE6F0ovy z>2>)P=iU11uXEZ?otKO4<=MEXy~H#AQtFDPH#es(xTv>f3A5I-phbOQlLLB?SErcc2M(qoc)Q`H+?Wp&lRT;1@?@7tGUZ$G`iQE>CNQ|a+l zy9Fl{-TiwqcTV}9KdJ_MA_@Y(T1Or1<-NN1MRiLV8%syl(UVr1n^qj9|y3cmeXZa3~d7s@_`TGOg0mlcEuGblhyqD>T zDcYuY?UZ;u_h*Y631wz8gqO>$Q9rI%DHk()ZPklM^`JB$$D(wq510mmm>_8Yd+Dal zAW1L{I7|1sw_#(FG@)*BHI3(%#xCvRFE55Z68ar z0qqB2+-@*okRT8OdZ21Wzc{rh6Il6FC6)j~BR(_5%}E4nfEbvWnOc||np;{L85tRy z8(10{8ylFL8kiYdSX!D{nwpy%7#W*cSb$vky)XH@2GDgN9N^6e3M_cdhg_nofh14> zarGd?OhjFPTpH=4nCApcpZLwhs2R`=M^5IJD2CqxdW=NF5ycyNxFDx%Z4_(n0n<1k zYmjmfy1~dvRu#qI5Mcd{-(YAaK{paP1*)SMd5NHr@RW~kC~~4PK`}H8n9}jP6MIq) z@MdGvftn=8tP3}j9m;~y)!+om!0-c0N5KxR2EX&To&Hz#mQ40Vb CvmBHF literal 0 HcmV?d00001 diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_many_unrecognised.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_many_unrecognised.pt new file mode 100644 index 0000000000000000000000000000000000000000..68639503af8672b4467a0c38dd34e3a3e0802203 GIT binary patch literal 2311 zcmWIWW@cev;NW1u04fZ247rIpnaP>?rN!~NiFuXrrFliE$@%GdnZ>Co`YDMeiFyUu zIc`o?3{hl*2Bzfr_>%mhSW=XjnU|TDUdY%P 
zAqM2fXBNk&mFDEcCnY9lmnRmb#AoE^XBRU0l_nK3>jikTb4b5)pB4pF3BmzDpNar| zis@Bd25!7QP0mkA)rY%Mub|S+i5=$U+`fZZhYfhz-q&);3ZFQ2uJJ9em;UXQJj%C= zvR_3y7DasfR=nw2zzatOyT9_kVx2-Q68`-?PG<&`}CKie(fAjyv8_F2r_G7^s9os#7?q4h1o$z+c%j({p(Z6_avB|!c-NnYe zZAt!wX)pQCW-oqG`#t9{V~%~p>c{_87aDo5d-skX6lB-;u)jPD46YnTd_iW;z(-n; z>7}G5m8Qpo(+d|Y@PMf%wC{A@We1VA<2#!)t;IGte7Y2P>x@`vgNO>()q@I6+-#gW z=`}vfd~(jKD&1AwID_~7fo0u#zpQFI_qka6tf2ZV9)Aao;zb>nkWoC6#Eq;4?UB%{JV`ceU7j;3&;9@MF#}r@^@MVG|1MFplCW9z}WRO=} zl9*Rg3@IoCks=;iVl*&mf^%xFK4&Yi;6Sz^W2=&F>S(L zL>Y5|eLF8b6zEX~5XNo0DgoP*DoaxHi&9dHNU%RAv!o;^7379i+sBe@K>I-$w;PNY zr0|A-9;jH>FHSAW1Qy*@i6y|$h|f%Ma}vW5AO>b;rWWRg=9ZR5Mn=Zw29`#~#s+2v z24?2Q#-^s0#%7jAmgYuAW*`@uiYPzV0J;u@1H2hQfd#KQkV{$>kOT@KrXGQqh$x|v z3nE<<(_R62_)NqouhET0PTb}wM#};{MUv5oat%F1kW;iKiX|$*)Q#T~q+Ek;>_w!c zs*Gan6avOVGY7hX$mvcM#Xuoe{O*LOcXT6>6N(Xvk<$nmi9PuSc(bwTK+Tb3MkEGy ps2GgS1t&=c1`ALg0s^2uu+&m$oewO}fMTp*S#}0?29SD)S^%&S$kzY> literal 0 HcmV?d00001 diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_mixed_file_reader.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_mixed_file_reader.pt new file mode 100644 index 0000000000000000000000000000000000000000..78b8c47c43c1168d40dac2c9bd03dc71b2c3a148 GIT binary patch literal 2311 zcmWIWW@cev;NW1u04fZ247rIpnaP>?rN!~NnH8xi@oAYksqsari7BZ?`YDMeiFyUu zIc`o?3{hl*2Bzfr_>%mhq7QebUO}as6C2dexqXh=hYfgI-~Sa|w47y6*>0(AkMr&( zya{UyNUAg}U8`hm@^=0F$-NsL^_^_?zJK|3Vo_V{j2pJ?C-QyzO4a+4gP%+6+^WrR zL|9XH=Y}+wik82}7nLpk{(s3Lksy!x2@8cgH*2PRN$$B3SZ^`eO2&on>%S^3N1>yC zbLK=Im)rB=O{S51{_AF?Tl^ifqgj1Z8PvGn&)Jm0UG8tcI#zFbRMtcDCo#s7Gj6%= z|Mx2HX~({0X0LUp*Ue@xDvFb?n0SOu?%jUIr$$-7CYKr(Y@a{RB)x4Vj^-s!<_o2k=~pE@Z_G7WpSCnFEK*a&?Kh8Z`m$5q&r0gOdT%^@ z^mB7_K<_1X+m&v+uWj8v??UW~V=F)O)S9l~J$%Suv8D4&Y3&^c@8lsMN!^tCW1HT+K4bEDGtYhE$17&IvMG?XYO&#)3b ziYsF-kk2p9OAiHl zngN7y+pbE$_N2;^)chh~*+_!@IhiFTIjJBwwAwzFWCPj{!noaF#2|$?1oS|~vVL)D zQ6{kHu1YKchDLm5ikp)djsP(*Gc&a?H#E1jG%_+WHaD;|GB!3aHn%V^ur#(bGdHj> zH!`p^Ffj$WaHgc8vU=05SCl#6&~^ja&%nqL}s?$irtM 
zMsba9G;%^WM=@Fv=qZwnMwDsjA%dKuHBl^)0H$vImLTOIbYqc|sxpePQwSIf%^c_k zBBwi56a$4=@w*eA-qDRjPAEnwMouGOB=+PR;LXOS12spE8Ic&+p<*yP7n~#+7%V_} d2nc}sz*0-0RX(t61B$VNW!V|n89?eGY5`^-$|3*& literal 0 HcmV?d00001 diff --git a/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_rop_exploit.pt b/bin/pytorch_inference/unittest/testfiles/malicious_models/malicious_rop_exploit.pt new file mode 100644 index 0000000000000000000000000000000000000000..08beafc14cef417a5abf4b2585832b5d09dc427b GIT binary patch literal 6109 zcmb7|2T+sQ7RM<{Qv~T9fzUfaAXKFn=^&jzsDTt9^tOtj_f8;yNRuu|N0cHUNE4(Z z2uMdME=tu0>)U;~T2b%P3Bih@rCrX z2#J_vcnWTx?L41)XDX@ARQ^>aHxBhrRN9?M5dK4(kei>X;#C~;A4j&S7IIdC{OPd? zIY{FUB!a4>2ai|p{(6bZ-D?e@H(#_^W91SM;CyZ29myqolU@=HXDW>4p zd3?t?fiBbW^E0KdI_C%Ujd5j%Qb~M#(CSlZR)-G=#`X=8$n3w`%*p5Ia&M0fr#6r8 zQ%)uJ6qV?yf8O(~tThHB~i%Ig% zL=?TUlL0-c6|ori0;b#}`7(0)sG*PSs{M14n|jKWR>_s+xkhb^QnfE^u(m6=Qa&&5 z(nYYp^u_MV#&bOA4A&5ty1(Y{c1V8Jq!Nvv?YWs2tvg`uv z;#if!42PHX#`~fP8p@Ri2b@?cHtM~)^1)K@>zn2w?wDJOVpcN*b-f2R6tn_isxxBv zg$gfAO}sNY^}0X#;6CfgrJC49`P3WC=Xe$FEMr}wE`;~$SDph`$ZA!|KG}S1$D|3E8IXSq?x3fFS^yb6Z*4^RgQ$Z>TWqIK)9BW{?F-mu>*p+A}K9CJz$OI(@Mkx*ekxLwlq zsl?VXfqhg)=h!3~KJ*XIo7|p)4(s+cmmgWPyv$?Tq-BaE=MC2s}m(WMXDo!5Hk}YK8U(O)mGiuM~0epftCOf?1u!vDgc*M!cq$)(39+b-!v{ zZkt|Yt|Vyr$e&1E*4lj-TKa|nv3&d3#G`tVa`OP7J0CTO#eh_Mn&zEFB%-QCV`8T! z4TNok5K`f^9HSJkUXXTT*CJ6AsD5V2CjRj*JR;h7q}}tGMT3kS_0bkr-fTck+3>Y$ zK>^-}xqBTB*F#zp;Wt;!7pD}_BB?|%jWlV32xTH)$*zonKy8~>eJ2QJOy20OgS<8o zY-u^>LOo)-IgU2i;l1^2ayG-DRV>A*m%S1tvN>ZC_4qTAX|>)UAbc-&;=(4#yMc4H zbt_zlb)iVq94|y2l_O#w<$D9SiB=?TDdZJk0Bysun)z_-%c*)vf87nK&v=jU2|uI= z2ko0*yFaW?Q+m?+_+XL7dclK)cZB*T1qAaF0pzyymd)+3k=_$@$}=Q3pnA3;V66w? z=Ct}+u8VJUU|`5$a8)rDl`a=O>p)~fxYwYUMNLb4SjquekDg6bUxyC@hy220_=gr! 
z9y}|5tCCsW0sTmX^?#>cbTW~ruPycY5-V2tt;Uw*_Uo-ZBTFPiwL?KJW;9Z`uE{J8 zpb~DP-LNyYw0P7{X{VL}vt{Y1mW0vuX;$HBF%8i&-K)HDVAnJm?K&jqL%_(J#|@UpCnkJZ(qAuUr(tl=6nNW26$w zYyd&eS(tRR*9&;J3#QlP(vwa7?PoiU;uzi5?3mv(6HdErjM2R}JX~B@)r1r75hq1o ziLuYDb|Kj$25MxtjmByl(lzF*Ca)8j<&@U@AHln~UbNc=l+$H+IR&WA;Anr==-RUYa#slM_YREQD_ZLBcjVZ8$V_@r zVb(p3GL0Rpv1S!^Y)y8Y$eYa=KdzWk-TbqnsZe)+^V})O^o-A*{Mk;W&w5$Z99-#3mK*FYd0Hd z-?9;hLmpr3RxF${3mju0hQ0Q^N{|5;Uz+NMag953<(tV|eVM5bo0Up|MHG(45>k|? zVpp2ZsF>cEcui=UMAWs>P)VQhpiQK_7SN>1@Mq57b9rQCUQOoME|{9TPS zi<8^g)%|t-vD?Z^?>(I%Q_hUm4!4#*O+8QL>S~?lx%o&a8dGq^+hV`_%D!x_r?M1x z3u?pO4Eaf28XA^t7c|}yBaO649+kzakIZ=vvoOiT?Hl`t&y&>Tx5<@ad+7&Mc{&>D zd-ms@VY$8n(z~xWZM#5W$w3yup>CwCmY|+36CD-8l{z{am#qh?dHEf_M0oDpE_sCg^?kn8l%LNyV z=fij;XQP_#LUD@)k~FcQXvLf4{+KB+6rV=<3*A~#B9x>le%=KrfprBMa=ef+E0%`l z(~ZA#+qNo!o zljdZ*y7bi+S?Cd23?RvU)&m|ji1Z>n!Y^u7ql&&f}rz zdT>EOD-@g(>7mcYr;_@R<~k%ZL&z}#14+veT#0Ccu+(P?(?sCyS@7p1j1#z%ZH*QT z0}5)}Nf%w*jY#4zL_e3RscR3!nIJ!k_0=V*LNhY$#!<-=gAIG%mkx|7%iodqG~rn= z;q#Bq!y9@_@NmqAt&g2JQ!Ugw1nIovn$krPcv*;JpV#OJw;<+b0NbmCmuQaMZ@h|eoe(pz+hJ;kQ^Vc@AI>qT*s{0vVb%gyn8&7eMR z&C%}y7_c|Z?g3T2epML+^o}wZ`4?nxEr`~w8)D^(YiOm8OZu!Aw|oGr0UjE<>3Saw zGsh3;T=%M*B~jB$s5FE21F_9^vVOZTboQ)AI#OiF8zniBIQ!wQNSLHyhfmY8yUsR8 z(+vKx-eA%Wptm&+8%w?TYtQVb0{+kWSdInpUPf6d!NQDaR9?cqh4~9-`ra&AH^x3c z3L2`K-O6Q8wuUjQL+zB0qCrVR0dK|oE>a-a59>KxIcK}qnE_f)q-YCGc@~ZVFIz=u zm)VIG*g!pr&1gfc{16-rIP@W-dSEHApesB78FUKl{UyMyUAS6P zrOyVuAr=KKLFWS)pEOm2xfVo_&-@ul#)i#CZxie0nCFJ{zmWHJk9Bq^4$ z@{{W`rl`E;_d2m`SDn{!W>R0LYUxl@r)o7vmx|vZ2>n=1BaB&XpCVXwFWlOf8F~+` zYw1VfK<;$8cvMR-Eb!44I3s5{plO!yW(pfIGbha7%VW9={yF_oAOxEUm^iQz?BgQt zeBvHv*n$>Zx~fBl8p;PK?1Y5X&_^qI5d)XI3u<9IUjONGuoEe|)|nGNbBS{e}X>wCXFNRgpMhe8sEyCBud0@N>` zdsHiyQ{n1!-v*h4kI|v)uCBN|BFbNX&PdWkOr=Qv##6($I8d*!*pGX=v^s}v+=8ej zp`r_=>2Q%6X6;H%^3g~{rm9aKykM^{%%H3%9+lusG4Md8HcFyewKl$$lqr3j%RTiV zifR)uTPWRF1rgExVA6}P%yKnevzledI~A)3XEEsIO?`4dX(IAO>BKM!L5zI(P8;xc z%tmmq{39UneoeeomSM@GY0YjLV?Jhq6&PA;qmVNmwF!&9P5A5JfP`yU=UFAdx}P){An9n 
zd-eml`!2cpT7vF)T(VZ1?FKX~$`BO`yU&=YTbH$`>m!nPOjtgKV}oz)X)(-37`$a# zDdc|$F)D-9Tbs$-g{ROH+#27pm?*?lw`(I(}VkiG2q$43#)_M z1M7vepZJnCu_ZMY9kBq`(YK(205=t*Y8TnIb>bvl=#D_&rRps}uDdAIQ(wolgM;wP&{0nON)~)uo&o5k3a-9#J3U*mP88muKg@>(V zKNKdFS8CQ(47je9cW+b&(K4R;8cAmo#eYfPI?9JtM?1+O2fNiXiZLw9EOaxu&X~W1 zozvR0b$TI`yd%|1K=y_AMpOh^!R=_?MdImWm)8nbucca?!q=#*Anzv;hYxEdF7d`` zRD(N^t^|_NtdTv15t_bP&*?x``Ai0KqaN>cc(15O=onTBD6>$JXN%F`C%iTY-TiT?Y(;rIUG4j%7sZWmmb*yPNJJM-NBNeMHK=ueb*J9wa=NQA1Q z;QyxqNOB&BS2<3`cxNY{&vUQs@8iJw)6*e-o(@P4dk2sI`+OMG(-Y=!t|3%ziJsu> z{PX#94FWjVei8wpb6Zjv<=_E5GcEmXJcRuyBaLvJ8$WZb$~YL?3RSUO}aslP!{C;I1N`QQBvieb_*t_4}VL<;gtj3NJey ze9(7ymg-F>fv#FF)=7SImWX}79v+dK*`9N${Qcv-Cs#>4oWT6);EsgAJ(uR_njL0y zIJF?c(&t2=m&C)8xRbL@reKE=8lZ70b5*b_Ilp@o0;t` za^S~=$tJ}z4?66G!q@zmCHqUI-2BI3-hV$1OYf-uSI3v_@-5Y<^#-fiy4&gHORB_P z@6em&FS5Tl>c#A3Ro{=U7W^b}^yCkfl>1vYSZ_C2uKM7|m8MYcrzY>bmj%?A3-?`J zCBJd5(>*b534K|sN2xV|W=Xk;J~L}IZ`Uw_Lf@A~qV^Lo=sw?a|B5ExdnFQomxts*D?Fj26ok-kbKIv!p@ndf>kjy*oYZ zo%1u|&ogB{T=(tLm&#xJ_83bn_tg9C+Ouv_jn7K)wL-ss&)D9f+@F49{qO#@ z+86iNRr$V(SHCe$Rk*}*A-AKN@0}2T)9w6MzFeAjS^Kgq`NKldaq=JFH%(T zl6O{`{=nt$oypho-f{?i2zl!h-*D>gws(G^?@pikea+5ck(AY|=MIaG_s!cLdR}%< zKxjC7W7yD6B##|upT%4C4 z$^^6-gaf=8K@_}9LM|ypKoTebTL}d*0a2nLr*>Hs-7P>KvI!XF3%XIr;jWHi6cfCTEu?7Nx{zGWnGz6*B7uc(ZeC2)VE!0;m#%1Ay)n0lE{@p|T9JDDF%K zx>g_VCcT16Hz!FX2g6+sVYK!eW*;^XX#M`DOIp{%d6U$owaK?tXB)QX%*{%SOfM12 z{a34do5f2r_I_>u_t^JaIqt;lxc@?XuaB(c<3sLWPpMT0ePv`QycXy`MT1KzW?|;_ zEvK*j>j=7{r<1y3Ve&WDq9(KI{<-gt`^)n3oq1eTKY#M$1+~EkE}5nrEYk5&eUiJE zN2Tv?L&dMu@)#}#y)%{zje`#NHWl1?bVQle=ljfUx0AD)Z@rdgzja~e(W#I3Kd@5u zaDAX3bnx-TyK#DQ5ROrL5AG+`=<;NrbIjfY;hy56$Bn#251|%hGe-G3WeR zs~Todh&KOU)AuIHP_x1{d8}YUcbwGFjDtis~1n4la>5(r)h6D z)^GWD!LjI~%JjB9huNLPPlis8k^aOP&!O#+H|xo^n{U@S zmR#hm1Lg;IjvXxB1{T1W{0ogu?D<#{BR2DjOA_-+iXoYq!uV@o(gf$#Tz$@PU_rtF z!l1m0-7Ufl)+i~fEVZaOGe6JG$(-C&WXuH$iHq~nLz#d!gK&U1BZz{RX~?CB2uK12 zU@PGuCLl^1!T^zWZzz;pEdZ_OH(&P9pk)sD07RlS~n`i{URqa!>Q 
zTUb<=S>N`LcaE^SyXsXV^SAf5_SSPl4G*2%ab5WCkI%V5dQ$H@ZGFTS|3O?c-EXRXcqCpp*6)%f}?kTa|Q z{+`7xes;5#{){h>6`gp@`0wO14-YTb_1Mcg@yF8@ch=~*Da~DaYQduku6#H5Im~;$ znu)*2>*<{Iqbj+2X2MJFA6dNX?n|x}W-qyCo<0!xrSrGv!Q*RIavA0dD?MjjlOoP> zDmKG-%aeD_OI~dH_)cZLqwN1QE{l%Lsrd?a`_#8`upqO|1*~9^RmBHx(2>mv;AVw@tZL(?-a|G z-O|o3pL_1or&qa)W^6ma75Fx@@P*tv$-OGKi;d4Ov3PTqX_EK_-jA$>RtbI&F1jvm zPM(#*6rQvCfX>UuHxk)3&)L4#ClT{8sScNP0-j^l+OakxplS-h4RA${bm8!+eMwCC^G z;{GcfSsPs;-tH;&aK_eaVe(rtBwW_cyL~)-@sYy|r^-zT?wMn@aMR@TpG*>_O=h$H zE71JbQJ~6CV6(vjo(aZj*N?2>6RdnHt3x|BW#b=h;Qkm)P@ z;-z7-bkS+2zTz*xBI1NLow~T&s8&DK*R@np-QrikGI{UT{j2{K?2>=_;C$e>)gr%j z=Pv6qYnk&T!%_ItzAl%?Nkt{G+83pR^j2qtdK|d8c+<}4Wqvd5^e*S>-rryJ$;|kX zz?x$h7jD`1Ir_DX?d3m;)r))kieiGLq-9LCoR6Kb%=^m2`(#h^*O_VhA@CIZu5j0tWnG>ODzHw({4`Y z#N`&`oC9GSbAfzwab9{T&@&7m9N^6eqTqE3aw#qXl0X62Y8Z$Kh|(UpbdW{Sy#mNX zHUXpTM>h&N*ws;tngnz-UZW6&3%U=H!(I}_&=yv7Ly_VS-5BIR6h<-T3Ysy{XbbRW zW7B~ulw;O~YXPbT0~nnJELlMyhz&vm^|6Ckd%z?E14uf+n-wg}3QTbz^$@iHM7!YN literal 0 HcmV?d00001 diff --git a/dev-tools/extract_model_ops/extract_model_ops.py b/dev-tools/extract_model_ops/extract_model_ops.py new file mode 100644 index 000000000..676a7ef4b --- /dev/null +++ b/dev-tools/extract_model_ops/extract_model_ops.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0 and the following additional limitation. Functionality enabled by the +# files subject to the Elastic License 2.0 may only be used in production when +# invoked by an Elasticsearch process with a license key installed that permits +# use of machine learning features. You may not use this file except in +# compliance with the Elastic License 2.0 and the foregoing additional +# limitation. +# +"""Extract TorchScript operation sets from supported HuggingFace transformer architectures. 
+ +This developer tool traces/scripts reference models and collects the set of +TorchScript operations that appear in their forward() computation graphs. +The output is a sorted, de-duplicated union of all operations which can be +used to build the C++ allowlist in CSupportedOperations.h. + +Usage: + python3 extract_model_ops.py [--per-model] [--cpp] [--golden OUTPUT] [--config CONFIG] + +Flags: + --per-model Print the op set for each model individually. + --cpp Print the union as a C++ initializer list. + --golden OUTPUT Write per-model op sets as a JSON golden file for the + C++ allowlist drift test. + --config CONFIG Path to the reference models JSON config file. + Defaults to reference_models.json in the same directory. +""" + +import argparse +import json +import sys +from pathlib import Path + +import torch + +from torchscript_utils import collect_inlined_ops, load_and_trace_hf_model + +SCRIPT_DIR = Path(__file__).resolve().parent +DEFAULT_CONFIG = SCRIPT_DIR / "reference_models.json" + + +def load_reference_models(config_path: Path) -> dict[str, str]: + """Load the architecture-to-model mapping from a JSON config file.""" + with open(config_path) as f: + return json.load(f) + + +def extract_ops_for_model(model_name: str) -> set[str] | None: + """Trace a HuggingFace model and return its TorchScript op set. + + Returns None if the model could not be loaded or traced. 
+ """ + print(f" Loading {model_name}...", file=sys.stderr) + traced = load_and_trace_hf_model(model_name) + if traced is None: + return None + return collect_inlined_ops(traced) + + +def format_cpp_initializer(ops: set[str]) -> str: + """Format the op set as a C++ initializer list for std::unordered_set.""" + sorted_ops = sorted(ops) + lines = [] + for op in sorted_ops: + lines.append(f' "{op}"sv,') + return "{\n" + "\n".join(lines) + "\n}" + + +def main(): + parser = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--per-model", action="store_true", + help="Print per-model op sets") + parser.add_argument("--cpp", action="store_true", + help="Print union as C++ initializer") + parser.add_argument("--golden", type=Path, default=None, metavar="OUTPUT", + help="Write per-model op sets as a JSON golden file") + parser.add_argument("--config", type=Path, default=DEFAULT_CONFIG, + help="Path to reference_models.json config file") + args = parser.parse_args() + + reference_models = load_reference_models(args.config) + + per_model_ops = {} + union_ops = set() + + print("Extracting TorchScript ops from supported architectures...", + file=sys.stderr) + + failed = [] + for arch, model_name in reference_models.items(): + ops = extract_ops_for_model(model_name) + if ops is None: + failed.append(arch) + print(f" {arch}: FAILED", file=sys.stderr) + continue + per_model_ops[arch] = ops + union_ops.update(ops) + print(f" {arch}: {len(ops)} ops", file=sys.stderr) + + print(f"\nTotal union: {len(union_ops)} unique ops", file=sys.stderr) + if failed: + print(f"Failed models: {', '.join(failed)}", file=sys.stderr) + + if args.golden: + golden = { + "pytorch_version": torch.__version__, + "models": { + arch: { + "model_id": reference_models[arch], + "ops": sorted(ops), + } + for arch, ops in sorted(per_model_ops.items()) + }, + } + args.golden.parent.mkdir(parents=True, exist_ok=True) + with open(args.golden, 
"w") as f: + json.dump(golden, f, indent=2) + f.write("\n") + print(f"Wrote golden file to {args.golden} " + f"({len(per_model_ops)} models, " + f"{len(union_ops)} unique ops)", file=sys.stderr) + + if args.per_model: + for arch, ops in sorted(per_model_ops.items()): + print(f"\n=== {arch} ({reference_models[arch]}) ===") + for op in sorted(ops): + print(f" {op}") + + if args.cpp: + print("\n// C++ initializer for SUPPORTED_OPERATIONS:") + print(format_cpp_initializer(union_ops)) + elif not args.golden: + print("\n// Sorted union of all operations:") + for op in sorted(union_ops): + print(op) + + +if __name__ == "__main__": + main() diff --git a/dev-tools/extract_model_ops/reference_models.json b/dev-tools/extract_model_ops/reference_models.json new file mode 100644 index 000000000..255762721 --- /dev/null +++ b/dev-tools/extract_model_ops/reference_models.json @@ -0,0 +1,20 @@ +{ + "bert": "bert-base-uncased", + "roberta": "roberta-base", + "distilbert": "distilbert-base-uncased", + "electra": "google/electra-small-discriminator", + "mpnet": "microsoft/mpnet-base", + "deberta": "microsoft/deberta-base", + "dpr": "facebook/dpr-ctx_encoder-single-nq-base", + "mobilebert": "google/mobilebert-uncased", + "xlm-roberta": "xlm-roberta-base", + "elastic-bge-m3": "elastic/bge-m3", + "elastic-distilbert-cased-ner": "elastic/distilbert-base-cased-finetuned-conll03-english", + "elastic-distilbert-uncased-ner": "elastic/distilbert-base-uncased-finetuned-conll03-english", + "elastic-eis-elser-v2": "elastic/eis-elser-v2", + "elastic-elser-v2": "elastic/elser-v2", + "elastic-hugging-face-elser": "elastic/hugging-face-elser", + "elastic-multilingual-e5-small-optimized": "elastic/multilingual-e5-small-optimized", + "elastic-splade-v3": "elastic/splade-v3", + "elastic-test-elser-v2": "elastic/test-elser-v2" +} diff --git a/dev-tools/extract_model_ops/requirements.txt b/dev-tools/extract_model_ops/requirements.txt new file mode 100644 index 000000000..70d0ebb78 --- /dev/null +++ 
b/dev-tools/extract_model_ops/requirements.txt @@ -0,0 +1,4 @@ +torch==2.7.1 +transformers>=4.40.0 +sentencepiece>=0.2.0 +protobuf>=5.0.0 diff --git a/dev-tools/extract_model_ops/torchscript_utils.py b/dev-tools/extract_model_ops/torchscript_utils.py new file mode 100644 index 000000000..7ad860b58 --- /dev/null +++ b/dev-tools/extract_model_ops/torchscript_utils.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0 and the following additional limitation. Functionality enabled by the +# files subject to the Elastic License 2.0 may only be used in production when +# invoked by an Elasticsearch process with a license key installed that permits +# use of machine learning features. You may not use this file except in +# compliance with the Elastic License 2.0 and the foregoing additional +# limitation. +# +"""Shared utilities for extracting and inspecting TorchScript operations.""" + +import os +import sys + +import torch +from transformers import AutoConfig, AutoModel, AutoTokenizer + + +def collect_graph_ops(graph) -> set[str]: + """Collect all operation names from a TorchScript graph, including blocks.""" + ops = set() + for node in graph.nodes(): + ops.add(node.kind()) + for block in node.blocks(): + ops.update(collect_graph_ops(block)) + return ops + + +def collect_inlined_ops(module) -> set[str]: + """Clone the forward graph, inline all calls, and return the op set.""" + graph = module.forward.graph.copy() + torch._C._jit_pass_inline(graph) + return collect_graph_ops(graph) + + +def load_and_trace_hf_model(model_name: str): + """Load a HuggingFace model, tokenize sample input, and trace to TorchScript. + + Returns the traced module, or None if the model could not be loaded or traced. 
+ """ + token = os.environ.get("HF_TOKEN") + + try: + tokenizer = AutoTokenizer.from_pretrained(model_name, token=token) + config = AutoConfig.from_pretrained( + model_name, torchscript=True, token=token) + model = AutoModel.from_pretrained( + model_name, config=config, token=token) + model.eval() + except Exception as exc: + print(f" LOAD ERROR: {exc}", file=sys.stderr) + return None + + inputs = tokenizer( + "This is a sample input for graph extraction.", + return_tensors="pt", padding="max_length", + max_length=32, truncation=True) + + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + try: + return torch.jit.trace( + model, (input_ids, attention_mask), strict=False) + except Exception as exc: + print(f" TRACE WARNING: {exc}", file=sys.stderr) + print(" Falling back to torch.jit.script...", file=sys.stderr) + try: + return torch.jit.script(model) + except Exception as exc2: + print(f" SCRIPT ERROR: {exc2}", file=sys.stderr) + return None diff --git a/dev-tools/extract_model_ops/validate_allowlist.py b/dev-tools/extract_model_ops/validate_allowlist.py new file mode 100644 index 000000000..5d31d44bf --- /dev/null +++ b/dev-tools/extract_model_ops/validate_allowlist.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0 and the following additional limitation. Functionality enabled by the +# files subject to the Elastic License 2.0 may only be used in production when +# invoked by an Elasticsearch process with a license key installed that permits +# use of machine learning features. You may not use this file except in +# compliance with the Elastic License 2.0 and the foregoing additional +# limitation. +# +"""Validate that the C++ operation allowlist accepts all supported model architectures. 
+ +Traces each model listed in a JSON config file, extracts its TorchScript +operations (using the same inlining approach as the C++ validator), and +checks every operation against the ALLOWED_OPERATIONS and FORBIDDEN_OPERATIONS +sets parsed from CSupportedOperations.cc. + +This is the Python-side equivalent of the C++ CModelGraphValidator and is +intended as an integration test: if any legitimate model produces an +operation that the C++ code would reject, this script exits non-zero. + +Exit codes: + 0 All models pass (no false positives). + 1 At least one model was rejected or a model failed to load/trace. + +Usage: + python3 validate_allowlist.py [--config CONFIG] [--verbose] +""" + +import argparse +import json +import re +import sys +from pathlib import Path + +import torch + +from torchscript_utils import ( + collect_graph_ops, + collect_inlined_ops, + load_and_trace_hf_model, +) + +SCRIPT_DIR = Path(__file__).resolve().parent +REPO_ROOT = SCRIPT_DIR.parents[1] +DEFAULT_CONFIG = SCRIPT_DIR / "validation_models.json" +SUPPORTED_OPS_CC = REPO_ROOT / "bin" / "pytorch_inference" / "CSupportedOperations.cc" + + +def parse_string_set_from_cc(path: Path, variable_name: str) -> set[str]: + """Extract a set of string literals from a C++ TStringViewSet definition.""" + text = path.read_text() + pattern = rf'{re.escape(variable_name)}\s*=\s*\{{(.*?)\}};' + match = re.search(pattern, text, re.DOTALL) + if not match: + raise RuntimeError(f"Could not find {variable_name} in {path}") + block = match.group(1) + return set(re.findall(r'"([^"]+)"', block)) + + +def load_cpp_sets() -> tuple[set[str], set[str]]: + """Parse ALLOWED_OPERATIONS and FORBIDDEN_OPERATIONS from the C++ source.""" + allowed = parse_string_set_from_cc(SUPPORTED_OPS_CC, "ALLOWED_OPERATIONS") + forbidden = parse_string_set_from_cc(SUPPORTED_OPS_CC, "FORBIDDEN_OPERATIONS") + return allowed, forbidden + + +def load_pt_and_collect_ops(pt_path: str) -> set[str] | None: + """Load a saved TorchScript .pt file, 
inline, and return its op set.""" + try: + module = torch.jit.load(pt_path) + return collect_inlined_ops(module) + except Exception as exc: + print(f" LOAD ERROR: {exc}", file=sys.stderr) + return None + + +def check_ops(ops: set[str], + allowed: set[str], + forbidden: set[str], + verbose: bool) -> bool: + """Check an op set against allowed/forbidden lists. Returns True if all pass.""" + forbidden_found = sorted(ops & forbidden) + unrecognised = sorted(ops - allowed - forbidden) + + if verbose: + print(f" {len(ops)} distinct ops", file=sys.stderr) + + if forbidden_found: + print(f" FORBIDDEN: {forbidden_found}", file=sys.stderr) + if unrecognised: + print(f" UNRECOGNISED: {unrecognised}", file=sys.stderr) + + if not forbidden_found and not unrecognised: + print(f" PASS", file=sys.stderr) + return True + + print(f" FAIL", file=sys.stderr) + return False + + +def validate_model(model_name: str, + allowed: set[str], + forbidden: set[str], + verbose: bool) -> bool: + """Validate one HuggingFace model. Returns True if all ops pass.""" + print(f" {model_name}...", file=sys.stderr) + traced = load_and_trace_hf_model(model_name) + if traced is None: + print(f" FAILED (could not load/trace)", file=sys.stderr) + return False + ops = collect_inlined_ops(traced) + return check_ops(ops, allowed, forbidden, verbose) + + +def validate_pt_file(name: str, + pt_path: str, + allowed: set[str], + forbidden: set[str], + verbose: bool) -> bool: + """Validate a local TorchScript .pt file. 
 Returns True if all ops pass.""" + print(f" {name} ({pt_path})...", file=sys.stderr) + ops = load_pt_and_collect_ops(pt_path) + if ops is None: + print(f" FAILED (could not load)", file=sys.stderr) + return False + return check_ops(ops, allowed, forbidden, verbose) + + +def main(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument( + "--config", type=Path, default=DEFAULT_CONFIG, + help="Path to validation_models.json (default: %(default)s)") + parser.add_argument( + "--pt-dir", type=Path, default=None, + help="Directory of pre-saved .pt TorchScript files to validate") + parser.add_argument( + "--verbose", action="store_true", + help="Print per-model op counts") + args = parser.parse_args() + + print(f"PyTorch version: {torch.__version__}", file=sys.stderr) + + allowed, forbidden = load_cpp_sets() + print(f"Parsed {len(allowed)} allowed ops and {len(forbidden)} " + f"forbidden ops from {SUPPORTED_OPS_CC.name}", file=sys.stderr) + + results: dict[str, bool] = {} + + with open(args.config) as f: + models = json.load(f) + print(f"Validating {len(models)} HuggingFace models from " + f"{args.config.name}...", file=sys.stderr) + + for arch, model_id in models.items(): + results[arch] = validate_model( + model_id, allowed, forbidden, args.verbose) + + if args.pt_dir and args.pt_dir.is_dir(): + pt_files = sorted(args.pt_dir.glob("*.pt")) + if pt_files: + print(f"Validating {len(pt_files)} local .pt files from " + f"{args.pt_dir}...", file=sys.stderr) + for pt_path in pt_files: + name = pt_path.stem + results[f"pt:{name}"] = validate_pt_file( + name, str(pt_path), allowed, forbidden, args.verbose) + + print(file=sys.stderr) + print("=" * 60, file=sys.stderr) + all_pass = all(results.values()) + for key, passed in results.items(): + status = "PASS" if passed else "FAIL" + if key.startswith("pt:"): + print(f" {key}: {status}", file=sys.stderr) + else: + print(f" {key} ({models[key]}): 
{status}", file=sys.stderr) + + print("=" * 60, file=sys.stderr) + if all_pass: + print("All models PASS - no false positives.", file=sys.stderr) + else: + failed = [a for a, p in results.items() if not p] + print(f"FAILED models: {', '.join(failed)}", file=sys.stderr) + + sys.exit(0 if all_pass else 1) + + +if __name__ == "__main__": + main() diff --git a/dev-tools/extract_model_ops/validation_models.json b/dev-tools/extract_model_ops/validation_models.json new file mode 100644 index 000000000..5c23eb907 --- /dev/null +++ b/dev-tools/extract_model_ops/validation_models.json @@ -0,0 +1,29 @@ +{ + "bert": "bert-base-uncased", + "roberta": "roberta-base", + "distilbert": "distilbert-base-uncased", + "electra": "google/electra-small-discriminator", + "mpnet": "microsoft/mpnet-base", + "deberta": "microsoft/deberta-base", + "dpr": "facebook/dpr-ctx_encoder-single-nq-base", + "mobilebert": "google/mobilebert-uncased", + "xlm-roberta": "xlm-roberta-base", + + "elastic-bge-m3": "elastic/bge-m3", + "elastic-distilbert-cased-ner": "elastic/distilbert-base-cased-finetuned-conll03-english", + "elastic-distilbert-uncased-ner": "elastic/distilbert-base-uncased-finetuned-conll03-english", + "elastic-eis-elser-v2": "elastic/eis-elser-v2", + "elastic-elser-v2": "elastic/elser-v2", + "elastic-hugging-face-elser": "elastic/hugging-face-elser", + "elastic-multilingual-e5-small-optimized": "elastic/multilingual-e5-small-optimized", + "elastic-splade-v3": "elastic/splade-v3", + "elastic-test-elser-v2": "elastic/test-elser-v2", + + "ner-dslim-bert-base": "dslim/bert-base-NER", + "sentiment-distilbert-sst2": "distilbert-base-uncased-finetuned-sst-2-english", + + "es-multilingual-e5-small": "intfloat/multilingual-e5-small", + "es-all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2", + "es-cross-encoder-ms-marco": "cross-encoder/ms-marco-MiniLM-L-6-v2", + "es-dpr-question-encoder": "facebook/dpr-question_encoder-single-nq-base" +} diff --git a/dev-tools/generate_malicious_models.py 
b/dev-tools/generate_malicious_models.py new file mode 100644 index 000000000..21afe1110 --- /dev/null +++ b/dev-tools/generate_malicious_models.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License +# 2.0 and the following additional limitation. Functionality enabled by the +# files subject to the Elastic License 2.0 may only be used in production when +# invoked by an Elasticsearch process with a license key installed that permits +# use of machine learning features. You may not use this file except in +# compliance with the Elastic License 2.0 and the foregoing additional +# limitation. +# +"""Generate malicious TorchScript model fixtures for validator integration tests. + +Each model is designed to exercise a specific attack vector that the +CModelGraphValidator must detect and reject. + +Usage: + python3 generate_malicious_models.py [output_dir] + +The output directory defaults to the same directory as this script. 
+""" + +import os +import sys +from pathlib import Path + +import torch +from torch import Tensor +from typing import Optional + + +# --- Malicious model definitions --- + + +class FileReaderModel(torch.nn.Module): + """Uses aten::from_file to read arbitrary files from disk.""" + def forward(self, x: Tensor) -> Tensor: + stolen = torch.from_file("/etc/passwd", size=100) + return stolen + + +class MixedFileReaderModel(torch.nn.Module): + """Mixes allowed ops with a forbidden aten::from_file call.""" + def forward(self, x: Tensor) -> Tensor: + y = x + x + z = torch.from_file("/etc/shadow", size=10) + return y + z + + +class HiddenInSubmodule(torch.nn.Module): + """Hides aten::sin (unrecognised) three levels deep in submodules.""" + def __init__(self): + super().__init__() + self.inner = _Inner() + + def forward(self, x: Tensor) -> Tensor: + y = x * x + return self.inner(y) + + +class _Inner(torch.nn.Module): + def __init__(self): + super().__init__() + self.leaf = _Leaf() + + def forward(self, x: Tensor) -> Tensor: + return self.leaf(x) + x + + +class _Leaf(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + return torch.sin(x) + + +class ConditionalMalicious(torch.nn.Module): + """Hides an unrecognised op (aten::sin) inside one branch of a conditional.""" + def forward(self, x: Tensor) -> Tensor: + if x.sum() > 0: + return torch.sin(x) + else: + return x + x + + +class ManyUnrecognisedOps(torch.nn.Module): + """Uses several different unrecognised ops to simulate an unexpected arch.""" + def forward(self, x: Tensor) -> Tensor: + a = torch.sin(x) + b = torch.cos(x) + c = torch.tan(x) + d = torch.exp(x) + return a + b + c + d + + +class FileReaderInSubmodule(torch.nn.Module): + """Hides the forbidden aten::from_file inside a submodule.""" + def __init__(self): + super().__init__() + self.reader = _FileReaderChild() + + def forward(self, x: Tensor) -> Tensor: + return x + self.reader(x) + + +class _FileReaderChild(torch.nn.Module): + def forward(self, x: 
Tensor) -> Tensor: + return torch.from_file("/tmp/secret", size=10) + + +# --- Sandbox2 attack models (PR #2873) --- +# +# These reproduce real-world attack vectors that exploit torch.as_strided +# to leak heap addresses and build ROP chains. The graph validator must +# reject them because as_strided (and several helper ops) are not in the +# transformer-architecture allowlist. + + +class HeapLeakModel(torch.nn.Module): + """Leaks heap addresses via torch.as_strided with a malicious storage offset. + + The attack scans the heap for libtorch pointers by reading out-of-bounds + memory through a strided view. Key unrecognised ops: aten::as_strided, + aten::item, aten::__rshift__. + """ + def find_heap_overlap(self, tensors: Tensor) -> tuple[list[int], Tensor]: + evil_length = 0x900 * 8 + leak = torch.as_strided( + tensors, [evil_length], [1], 2305843009213693952 - evil_length) + index: list[int] = [-1] + for i in range(evil_length - 1 - 5, -1, -1): + if (leak[i + 3].item() == leak[i + 5].item() + and leak[i + 3].item() >> 40 == 127): + if (leak[i + 2].item() >= 0x100000000 + and leak[i + 2].item() <= 0x100000005): + index.append(i) + return (index, leak) + + def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: + tmp: list[Tensor] = [] + heap_len = 200 + array_len = 0x900 + for i in range(heap_len): + tmp_values = [0xdead00000 + i] + [ + 0xdead00000 + j for j in range(2, array_len)] + tmp.append(torch.tensor(tmp_values)) + for i in range(heap_len): + indexs, leak = self.find_heap_overlap(tmp[i]) + indexs = indexs[1:] + if indexs == []: + continue + for index in indexs: + heap_addr = leak[index + 3].item() + libtorch_python_addr = leak[index + 1].item() + libtorch_cpu_addr = libtorch_python_addr - 0xc878b50 + assert 1 == 0, "leak: " + hex(libtorch_cpu_addr) + for i in range(heap_len): + print(tmp[i].shape) + return torch.zeros(0) + + +class RopExploitModel(torch.nn.Module): + """Builds a ROP chain to mprotect heap memory and execute shellcode. 
+ + Extends the heap-leak technique to overwrite function pointers and + chain gadgets that call mprotect followed by attacker-controlled + shellcode that writes files to disk. Key unrecognised ops: + aten::as_strided, aten::copy_, aten::Delete. + """ + def find_heap_overlap(self, tensors: Tensor) -> tuple[list[int], Tensor]: + evil_length = 0x900 * 8 + leak = torch.as_strided( + tensors, [evil_length], [1], 2305843009213693952 - evil_length) + index: list[int] = [-1] + for i in range(evil_length - 1 - 5, -1, -1): + if (leak[i + 3].item() == leak[i + 5].item() + and leak[i + 3].item() >> 40 == 127): + if (leak[i + 2].item() >= 0x100000000 + and leak[i + 2].item() <= 0x100000005): + index.append(i) + return (index, leak) + + def get_address_data(self, leak: Tensor, index: int, + tmp: list[Tensor], target_addr: int) -> int: + leak[index] = target_addr + return int(tmp[0].item()) + + def forward(self, a: Tensor, b: Tensor, c: Tensor, d: Tensor) -> Tensor: + tmp: list[Tensor] = [] + heap_len = 200 + array_len = 0x900 + for i in range(heap_len): + tmp_values = [0xdead00000 + i] + [ + 0xdead00000 + j for j in range(2, array_len)] + tmp.append(torch.tensor(tmp_values)) + for i in range(heap_len): + indexs, leak = self.find_heap_overlap(tmp[i]) + indexs = indexs[1:] + if indexs == []: + continue + for index in indexs: + heap_addr = leak[index + 3].item() + libtorch_python_addr = leak[index + 1].item() + libtorch_cpu_addr = libtorch_python_addr - 0xc878b50 + leak[index + 3] = int(heap_addr - 0x8) + for j in range(heap_len): + if tmp[j][2].item() == 0xDEAD00002 and i != j: + mprotect_got_addr = 0xc9785f0 + libtorch_cpu_addr + self.get_address_data( + leak, index + 3, tmp[j:j + 1], + int(mprotect_got_addr)) + del tmp[j] + assert 1 == 0, "exploit" + return torch.zeros(0) + for i in range(heap_len): + print(tmp[i].shape) + return torch.zeros(0) + + +# --- Generation logic --- + + +MODELS = { + "malicious_file_reader.pt": FileReaderModel, + "malicious_mixed_file_reader.pt": 
 MixedFileReaderModel, + "malicious_hidden_in_submodule.pt": HiddenInSubmodule, + "malicious_conditional.pt": ConditionalMalicious, + "malicious_many_unrecognised.pt": ManyUnrecognisedOps, + "malicious_file_reader_in_submodule.pt": FileReaderInSubmodule, + "malicious_heap_leak.pt": HeapLeakModel, + "malicious_rop_exploit.pt": RopExploitModel, +} + + +def generate(output_dir: Path): + output_dir.mkdir(parents=True, exist_ok=True) + succeeded = [] + failed = [] + + for filename, cls in MODELS.items(): + print(f" {filename}...", end=" ") + try: + model = cls() + model.eval() + scripted = torch.jit.script(model) + path = output_dir / filename + torch.jit.save(scripted, str(path)) + size = path.stat().st_size + print(f"OK ({size} bytes)") + + # Show ops for verification + graph = scripted.forward.graph.copy() + torch._C._jit_pass_inline(graph) + ops = sorted(set(n.kind() for n in graph.nodes())) + print(f" ops: {ops}") + + succeeded.append(filename) + except Exception as exc: + print(f"FAILED: {exc}") + failed.append((filename, str(exc))) + + print(f"\nGenerated {len(succeeded)}/{len(MODELS)} models") + if failed: + print("Failed:") + for name, err in failed: + print(f" {name}: {err}") + return len(failed) == 0 + + +if __name__ == "__main__": + out_dir = (Path(sys.argv[1]) if len(sys.argv) > 1 + else Path(__file__).resolve().parent.parent + / "bin" / "pytorch_inference" / "unittest" / "testfiles" / "malicious_models") + print(f"Generating malicious model fixtures in {out_dir}") + success = generate(out_dir) + sys.exit(0 if success else 1) diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index 5958a2533..a5bfd240d 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -32,6 +32,7 @@ === Enhancements +* Harden pytorch_inference with TorchScript model graph validation. (See {ml-pull}2936[#2936].) * Update the PyTorch library to version 2.7.1. (See {ml-pull}2882[#2882], {ml-pull}2883[#2883].) 
== {es} version 8.19.0 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1877e64b5..41229c335 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -36,3 +36,22 @@ add_custom_target(test COMMAND ${CMAKE_COMMAND} -DTEST_DIR=${CMAKE_BINARY_DIR} -P ${CMAKE_SOURCE_DIR}/cmake/test-check-success.cmake WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} ) +) + +# Common arguments for the pytorch_inference allowlist validation script. +set(_validation_args + -DSOURCE_DIR=${CMAKE_SOURCE_DIR} + -DVALIDATE_CONFIG=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/validation_models.json + -DVALIDATE_PT_DIR=${CMAKE_SOURCE_DIR}/dev-tools/extract_model_ops/es_it_models + -DVALIDATE_VERBOSE=TRUE +) + +# Standalone target for the pytorch_inference allowlist validation. +# See dev-tools/extract_model_ops/README.md for details. +add_custom_target(validate_pytorch_inference_models + COMMAND ${CMAKE_COMMAND} + ${_validation_args} + -P ${CMAKE_SOURCE_DIR}/cmake/run-validation.cmake + COMMENT "Validating pytorch_inference allowlist against HuggingFace models and ES integration test models" + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} +) From a6e0e57789b93335b2d79435f584caa381260c15 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Fri, 13 Mar 2026 10:28:27 +1300 Subject: [PATCH 2/3] [ML] Fix stray closing paren in test/CMakeLists.txt Made-with: Cursor --- test/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 41229c335..0bcc5eb5c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -36,7 +36,6 @@ add_custom_target(test COMMAND ${CMAKE_COMMAND} -DTEST_DIR=${CMAKE_BINARY_DIR} -P ${CMAKE_SOURCE_DIR}/cmake/test-check-success.cmake WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} ) -) # Common arguments for the pytorch_inference allowlist validation script. 
set(_validation_args From bc903a10293ed99598cc30eec9e50505000dc68a Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Fri, 13 Mar 2026 10:50:18 +1300 Subject: [PATCH 3/3] [ML] Replace C++20 contains() with count() for C++17 compatibility The 8.19 branch uses C++17; std::unordered_set::contains() requires C++20. Made-with: Cursor --- bin/pytorch_inference/CModelGraphValidator.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bin/pytorch_inference/CModelGraphValidator.cc b/bin/pytorch_inference/CModelGraphValidator.cc index 01658b440..7469ea838 100644 --- a/bin/pytorch_inference/CModelGraphValidator.cc +++ b/bin/pytorch_inference/CModelGraphValidator.cc @@ -57,7 +57,7 @@ CModelGraphValidator::validate(const TStringSet& observedOps, // fail fast when a known-dangerous operation is present and avoids the // cost of scanning for unrecognised ops on a model we will reject anyway. for (const auto& op : observedOps) { - if (forbiddenOps.contains(op)) { + if (forbiddenOps.count(op) > 0) { result.s_IsValid = false; result.s_ForbiddenOps.push_back(op); } @@ -65,7 +65,7 @@ CModelGraphValidator::validate(const TStringSet& observedOps, if (result.s_ForbiddenOps.empty()) { for (const auto& op : observedOps) { - if (allowedOps.contains(op) == false) { + if (allowedOps.count(op) == 0) { result.s_IsValid = false; result.s_UnrecognisedOps.push_back(op); }