Skip to content

Commit 2c4e51c

Browse files
committed
[ML] Add quantized model ops to pytorch_inference allowlist
Add aten::mul_ and quantized::linear_dynamic to the allowed operations list, fixing validation failures for dynamically quantized models such as ELSER v2 when imported via Eland with torch.quantization.quantize_dynamic. Also update the model extraction tooling to support a "quantize" flag in reference_models.json so that quantized variants are traced with dynamic quantization applied before graph extraction, mirroring the Eland import pipeline. Made-with: Cursor
1 parent c40b317 commit 2c4e51c

File tree

7 files changed

+146
-26
lines changed

7 files changed

+146
-26
lines changed

bin/pytorch_inference/CSupportedOperations.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::FORBIDDEN_OPERA
3939
// elastic/test-elser-v2.
4040
// Additional ops from Elasticsearch integration test models
4141
// (PyTorchModelIT, TextExpansionQueryIT, TextEmbeddingQueryIT).
42+
// Quantized operations from dynamically quantized variants of the above
43+
// models (torch.quantization.quantize_dynamic on nn.Linear layers).
4244
const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATIONS = {
4345
// aten operations (core tensor computations)
4446
"aten::Int"sv,
@@ -79,6 +81,7 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
7981
"aten::mean"sv,
8082
"aten::min"sv,
8183
"aten::mul"sv,
84+
"aten::mul_"sv,
8285
"aten::ne"sv,
8386
"aten::neg"sv,
8487
"aten::new_ones"sv,
@@ -124,6 +127,8 @@ const CSupportedOperations::TStringViewSet CSupportedOperations::ALLOWED_OPERATI
124127
"prim::dtype"sv,
125128
"prim::max"sv,
126129
"prim::min"sv,
130+
// quantized operations (dynamically quantized models, e.g. ELSER v2)
131+
"quantized::linear_dynamic"sv,
127132
};
128133
}
129134
}

bin/pytorch_inference/unittest/testfiles/reference_model_ops.json

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@
267267
},
268268
"elastic-eis-elser-v2": {
269269
"model_id": "elastic/eis-elser-v2",
270+
"quantized": false,
270271
"ops": [
271272
"aten::Int",
272273
"aten::ScalarImplicit",
@@ -303,6 +304,7 @@
303304
},
304305
"elastic-elser-v2": {
305306
"model_id": "elastic/elser-v2",
307+
"quantized": false,
306308
"ops": [
307309
"aten::Int",
308310
"aten::ScalarImplicit",
@@ -337,6 +339,44 @@
337339
"prim::NumToTensor"
338340
]
339341
},
342+
"elastic-elser-v2-quantized": {
343+
"model_id": "elastic/elser-v2",
344+
"quantized": true,
345+
"ops": [
346+
"aten::Int",
347+
"aten::ScalarImplicit",
348+
"aten::__and__",
349+
"aten::add",
350+
"aten::arange",
351+
"aten::contiguous",
352+
"aten::dropout",
353+
"aten::embedding",
354+
"aten::expand",
355+
"aten::gather",
356+
"aten::ge",
357+
"aten::gelu",
358+
"aten::index",
359+
"aten::layer_norm",
360+
"aten::mul_",
361+
"aten::new_ones",
362+
"aten::reshape",
363+
"aten::scaled_dot_product_attention",
364+
"aten::select",
365+
"aten::size",
366+
"aten::slice",
367+
"aten::tanh",
368+
"aten::to",
369+
"aten::transpose",
370+
"aten::unsqueeze",
371+
"aten::view",
372+
"prim::Constant",
373+
"prim::DictConstruct",
374+
"prim::GetAttr",
375+
"prim::ListConstruct",
376+
"prim::NumToTensor",
377+
"quantized::linear_dynamic"
378+
]
379+
},
340380
"elastic-hugging-face-elser": {
341381
"model_id": "elastic/hugging-face-elser",
342382
"ops": [

dev-tools/extract_model_ops/extract_model_ops.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,25 +35,25 @@
3535

3636
import torch
3737

38-
from torchscript_utils import collect_inlined_ops, load_and_trace_hf_model
38+
from torchscript_utils import (
39+
collect_inlined_ops,
40+
load_and_trace_hf_model,
41+
load_model_config,
42+
)
3943

4044
SCRIPT_DIR = Path(__file__).resolve().parent
4145
DEFAULT_CONFIG = SCRIPT_DIR / "reference_models.json"
4246

4347

44-
def load_reference_models(config_path: Path) -> dict[str, str]:
45-
"""Load the architecture-to-model mapping from a JSON config file."""
46-
with open(config_path) as f:
47-
return json.load(f)
48-
49-
50-
def extract_ops_for_model(model_name: str) -> set[str] | None:
48+
def extract_ops_for_model(model_name: str,
49+
quantize: bool = False) -> set[str] | None:
5150
"""Trace a HuggingFace model and return its TorchScript op set.
5251
5352
Returns None if the model could not be loaded or traced.
5453
"""
55-
print(f" Loading {model_name}...", file=sys.stderr)
56-
traced = load_and_trace_hf_model(model_name)
54+
label = f"{model_name} (quantized)" if quantize else model_name
55+
print(f" Loading {label}...", file=sys.stderr)
56+
traced = load_and_trace_hf_model(model_name, quantize=quantize)
5757
if traced is None:
5858
return None
5959
return collect_inlined_ops(traced)
@@ -81,7 +81,7 @@ def main():
8181
help="Path to reference_models.json config file")
8282
args = parser.parse_args()
8383

84-
reference_models = load_reference_models(args.config)
84+
reference_models = load_model_config(args.config)
8585

8686
per_model_ops = {}
8787
union_ops = set()
@@ -90,8 +90,9 @@ def main():
9090
file=sys.stderr)
9191

9292
failed = []
93-
for arch, model_name in reference_models.items():
94-
ops = extract_ops_for_model(model_name)
93+
for arch, spec in reference_models.items():
94+
ops = extract_ops_for_model(spec["model_id"],
95+
quantize=spec["quantized"])
9596
if ops is None:
9697
failed.append(arch)
9798
print(f" {arch}: FAILED", file=sys.stderr)
@@ -109,7 +110,8 @@ def main():
109110
"pytorch_version": torch.__version__,
110111
"models": {
111112
arch: {
112-
"model_id": reference_models[arch],
113+
"model_id": reference_models[arch]["model_id"],
114+
"quantized": reference_models[arch]["quantized"],
113115
"ops": sorted(ops),
114116
}
115117
for arch, ops in sorted(per_model_ops.items())
@@ -125,7 +127,11 @@ def main():
125127

126128
if args.per_model:
127129
for arch, ops in sorted(per_model_ops.items()):
128-
print(f"\n=== {arch} ({reference_models[arch]}) ===")
130+
spec = reference_models[arch]
131+
label = spec["model_id"]
132+
if spec["quantized"]:
133+
label += " (quantized)"
134+
print(f"\n=== {arch} ({label}) ===")
129135
for op in sorted(ops):
130136
print(f" {op}")
131137

dev-tools/extract_model_ops/reference_models.json

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,10 @@
1616
"elastic-hugging-face-elser": "elastic/hugging-face-elser",
1717
"elastic-multilingual-e5-small-optimized": "elastic/multilingual-e5-small-optimized",
1818
"elastic-splade-v3": "elastic/splade-v3",
19-
"elastic-test-elser-v2": "elastic/test-elser-v2"
19+
"elastic-test-elser-v2": "elastic/test-elser-v2",
20+
21+
"_comment:quantized": "Quantized variants: Eland applies torch.quantization.quantize_dynamic on nn.Linear layers when importing models. These produce quantized::* ops not present in the standard traced graphs above.",
22+
"elastic-elser-v2-quantized": {"model_id": "elastic/elser-v2", "quantized": true},
23+
"elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantized": true},
24+
"elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true}
2025
}

dev-tools/extract_model_ops/torchscript_utils.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,52 @@
1111
#
1212
"""Shared utilities for extracting and inspecting TorchScript operations."""
1313

14+
import json
1415
import os
1516
import sys
17+
from pathlib import Path
1618

1719
import torch
1820
from transformers import AutoConfig, AutoModel, AutoTokenizer
1921

2022

23+
def load_model_config(config_path: Path) -> dict[str, dict]:
24+
"""Load a model config JSON file and normalise entries.
25+
26+
Each entry is either a plain model-name string or a dict with
27+
``model_id`` (required) and optional ``quantized`` boolean. All
28+
entries are normalised to ``{"model_id": str, "quantized": bool}``.
29+
Keys starting with ``_comment`` are silently skipped.
30+
31+
Raises ``ValueError`` for malformed entries so that config problems
32+
are caught early with an actionable message.
33+
"""
34+
with open(config_path) as f:
35+
raw = json.load(f)
36+
37+
models: dict[str, dict] = {}
38+
for key, value in raw.items():
39+
if key.startswith("_comment"):
40+
continue
41+
if isinstance(value, str):
42+
models[key] = {"model_id": value, "quantized": False}
43+
elif isinstance(value, dict):
44+
if "model_id" not in value:
45+
raise ValueError(
46+
f"Config entry {key!r} is a dict but missing required "
47+
f"'model_id' key: {value!r}")
48+
models[key] = {
49+
"model_id": value["model_id"],
50+
"quantized": value.get("quantized", False),
51+
}
52+
else:
53+
raise ValueError(
54+
f"Config entry {key!r} has unsupported type "
55+
f"{type(value).__name__}: {value!r}. "
56+
f"Expected a model name string or a dict with 'model_id'.")
57+
return models
58+
59+
2160
def collect_graph_ops(graph) -> set[str]:
2261
"""Collect all operation names from a TorchScript graph, including blocks."""
2362
ops = set()
@@ -35,9 +74,13 @@ def collect_inlined_ops(module) -> set[str]:
3574
return collect_graph_ops(graph)
3675

3776

38-
def load_and_trace_hf_model(model_name: str):
77+
def load_and_trace_hf_model(model_name: str, quantize: bool = False):
3978
"""Load a HuggingFace model, tokenize sample input, and trace to TorchScript.
4079
80+
When *quantize* is True the model is dynamically quantized (nn.Linear
81+
layers converted to quantized::linear_dynamic) before tracing. This
82+
mirrors what Eland does when importing models for Elasticsearch.
83+
4184
Returns the traced module, or None if the model could not be loaded or traced.
4285
"""
4386
token = os.environ.get("HF_TOKEN")
@@ -53,6 +96,16 @@ def load_and_trace_hf_model(model_name: str):
5396
print(f" LOAD ERROR: {exc}", file=sys.stderr)
5497
return None
5598

99+
if quantize:
100+
try:
101+
model = torch.quantization.quantize_dynamic(
102+
model, {torch.nn.Linear}, dtype=torch.qint8)
103+
print(" Applied dynamic quantization (nn.Linear -> qint8)",
104+
file=sys.stderr)
105+
except Exception as exc:
106+
print(f" QUANTIZE ERROR: {exc}", file=sys.stderr)
107+
return None
108+
56109
inputs = tokenizer(
57110
"This is a sample input for graph extraction.",
58111
return_tensors="pt", padding="max_length",

dev-tools/extract_model_ops/validate_allowlist.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
"""
3030

3131
import argparse
32-
import json
3332
import re
3433
import sys
3534
from pathlib import Path
@@ -40,6 +39,7 @@
4039
collect_graph_ops,
4140
collect_inlined_ops,
4241
load_and_trace_hf_model,
42+
load_model_config,
4343
)
4444

4545
SCRIPT_DIR = Path(__file__).resolve().parent
@@ -103,10 +103,12 @@ def check_ops(ops: set[str],
103103
def validate_model(model_name: str,
104104
allowed: set[str],
105105
forbidden: set[str],
106-
verbose: bool) -> bool:
106+
verbose: bool,
107+
quantize: bool = False) -> bool:
107108
"""Validate one HuggingFace model. Returns True if all ops pass."""
108-
print(f" {model_name}...", file=sys.stderr)
109-
traced = load_and_trace_hf_model(model_name)
109+
label = f"{model_name} (quantized)" if quantize else model_name
110+
print(f" {label}...", file=sys.stderr)
111+
traced = load_and_trace_hf_model(model_name, quantize=quantize)
110112
if traced is None:
111113
print(f" FAILED (could not load/trace)", file=sys.stderr)
112114
return False
@@ -151,14 +153,15 @@ def main():
151153

152154
results: dict[str, bool] = {}
153155

154-
with open(args.config) as f:
155-
models = json.load(f)
156+
models = load_model_config(args.config)
157+
156158
print(f"Validating {len(models)} HuggingFace models from "
157159
f"{args.config.name}...", file=sys.stderr)
158160

159-
for arch, model_id in models.items():
161+
for arch, spec in models.items():
160162
results[arch] = validate_model(
161-
model_id, allowed, forbidden, args.verbose)
163+
spec["model_id"], allowed, forbidden, args.verbose,
164+
quantize=spec["quantized"])
162165

163166
if args.pt_dir and args.pt_dir.is_dir():
164167
pt_files = sorted(args.pt_dir.glob("*.pt"))
@@ -178,7 +181,11 @@ def main():
178181
if key.startswith("pt:"):
179182
print(f" {key}: {status}", file=sys.stderr)
180183
else:
181-
print(f" {key} ({models[key]}): {status}", file=sys.stderr)
184+
spec = models[key]
185+
label = spec["model_id"]
186+
if spec["quantized"]:
187+
label += " (quantized)"
188+
print(f" {key} ({label}): {status}", file=sys.stderr)
182189

183190
print("=" * 60, file=sys.stderr)
184191
if all_pass:

dev-tools/extract_model_ops/validation_models.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
"elastic-splade-v3": "elastic/splade-v3",
2020
"elastic-test-elser-v2": "elastic/test-elser-v2",
2121

22+
"elastic-elser-v2-quantized": {"model_id": "elastic/elser-v2", "quantized": true},
23+
"elastic-eis-elser-v2-quantized": {"model_id": "elastic/eis-elser-v2", "quantized": true},
24+
"elastic-test-elser-v2-quantized": {"model_id": "elastic/test-elser-v2", "quantized": true},
25+
2226
"ner-dslim-bert-base": "dslim/bert-base-NER",
2327
"sentiment-distilbert-sst2": "distilbert-base-uncased-finetuned-sst-2-english",
2428

0 commit comments

Comments (0)