From d41144102372746ed7bb4c87664e6e40eac19813 Mon Sep 17 00:00:00 2001
From: Arjun
Date: Tue, 21 Apr 2026 04:07:41 +0530
Subject: [PATCH] Add PyTorch TorchScript backend for vision benchmarks

Add a new backend (pytorch-torchscript) that loads TorchScript models
exported via torch.jit.trace or torch.jit.script. This enables running
inference with pre-traced PyTorch models (.pt files) without requiring
ONNX conversion.

Changes:
- Add backend_pytorch_torchscript.py with BackendPytorchTorchScript class
- Register the new backend in main.py get_backend()
- Fix resnet50-pytorch profile: change backend from 'tensorflow'
  (incorrect) to 'pytorch-torchscript'
---
 .../python/backend_pytorch_torchscript.py     | 72 +++++++++++++++++++
 .../python/main.py                            |  6 ++-
 2 files changed, 77 insertions(+), 1 deletion(-)
 create mode 100644 vision/classification_and_detection/python/backend_pytorch_torchscript.py

diff --git a/vision/classification_and_detection/python/backend_pytorch_torchscript.py b/vision/classification_and_detection/python/backend_pytorch_torchscript.py
new file mode 100644
index 0000000000..2a8de2dd1a
--- /dev/null
+++ b/vision/classification_and_detection/python/backend_pytorch_torchscript.py
@@ -0,0 +1,72 @@
+"""
+PyTorch TorchScript backend for MLPerf inference.
+
+Loads TorchScript models (.pt) exported via torch.jit.trace or torch.jit.script.
+Unlike backend_pytorch_native which expects raw state dicts, this backend works
+directly with serialized TorchScript modules.
+"""
+
+# pylint: disable=unused-argument,missing-docstring
+import torch
+import backend
+
+
+class BackendPytorchTorchScript(backend.Backend):
+    """Backend that runs inference on a serialized TorchScript module."""
+
+    def __init__(self):
+        super().__init__()
+        # Populated by load(); predict() refuses to run while this is None.
+        self.model = None
+        # Use the first CUDA device when one is available, otherwise the CPU.
+        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+    def version(self):
+        """Return the version string of the installed torch package."""
+        return torch.__version__
+
+    def name(self):
+        """Return the backend identifier, "pytorch-torchscript"."""
+        return "pytorch-torchscript"
+
+    def image_format(self):
+        """Return the image layout string this backend reports, "NCHW"."""
+        return "NCHW"
+
+    def load(self, model_path, inputs=None, outputs=None):
+        """Load a TorchScript module from model_path and put it in eval mode.
+
+        inputs and outputs are stored on the instance; a falsy value is
+        replaced by ["image"] and ["output"] respectively. Returns self.
+        """
+        self.model = torch.jit.load(model_path, map_location=self.device)
+        self.model.eval()
+        self.inputs = inputs or ["image"]
+        self.outputs = outputs or ["output"]
+        return self
+
+    def predict(self, feed):
+        """Run the model on the first entry of feed.
+
+        Only the first value of the feed mapping is passed to the model. It
+        is converted to a float32 tensor on self.device. The feed mapping
+        itself is left untouched.
+
+        Returns a list of numpy arrays: one element for a single tensor
+        output, or one element per item of a tuple/list output.
+
+        Raises RuntimeError if load() has not been called and ValueError if
+        feed is empty.
+        """
+        if self.model is None:
+            raise RuntimeError("load() must be called before predict()")
+        if not feed:
+            raise ValueError("feed must contain at least one input")
+        # Keep the converted tensor local so the caller's dict is not mutated.
+        batch = torch.tensor(next(iter(feed.values()))).float().to(self.device)
+        with torch.no_grad():
+            output = self.model(batch)
+        if isinstance(output, torch.Tensor):
+            return [output.cpu().numpy()]
+        # handle tuple/list outputs
+        return [o.cpu().numpy() for o in output]
diff --git a/vision/classification_and_detection/python/main.py b/vision/classification_and_detection/python/main.py
index 5f1ef39429..d75fc49409 100755
--- a/vision/classification_and_detection/python/main.py
+++ b/vision/classification_and_detection/python/main.py
@@ -141,7 +141,7 @@
         "inputs": "image",
         "outputs": "ArgMax:0",
         "dataset": "imagenet_pytorch",
-        "backend": "tensorflow",
+        "backend": "pytorch-torchscript",
         "model-name": "resnet50",
     },
     "resnet50-onnxruntime": {
@@ -408,6 +408,10 @@ def get_backend(backend):
         from backend_pytorch_native import BackendPytorchNative
 
         backend = BackendPytorchNative()
+    elif backend == "pytorch-torchscript":
+        from backend_pytorch_torchscript import BackendPytorchTorchScript
+
+        backend = BackendPytorchTorchScript()
     elif backend == "tflite":
         from backend_tflite import BackendTflite
 