Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
PyTorch TorchScript backend for MLPerf inference.

Loads TorchScript models (.pt) exported via torch.jit.trace or torch.jit.script.
Unlike backend_pytorch_native which expects raw state dicts, this backend works
directly with serialized TorchScript modules.
"""

# pylint: disable=unused-argument,missing-docstring
import torch
import backend


class BackendPytorchTorchScript(backend.Backend):
def __init__(self):
super(BackendPytorchTorchScript, self).__init__()
self.model = None
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

def version(self):
return torch.__version__

def name(self):
return "pytorch-torchscript"

def image_format(self):
return "NCHW"

def load(self, model_path, inputs=None, outputs=None):
self.model = torch.jit.load(model_path, map_location=self.device)
self.model.eval()
self.inputs = inputs or ["image"]
self.outputs = outputs or ["output"]
return self

def predict(self, feed):
key = [key for key in feed.keys()][0]
feed[key] = torch.tensor(feed[key]).float().to(self.device)
with torch.no_grad():
output = self.model(feed[key])
if isinstance(output, torch.Tensor):
return [output.cpu().numpy()]
# handle tuple/list outputs
return [o.cpu().numpy() for o in output]
6 changes: 5 additions & 1 deletion vision/classification_and_detection/python/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@
"inputs": "image",
"outputs": "ArgMax:0",
"dataset": "imagenet_pytorch",
"backend": "tensorflow",
"backend": "pytorch-torchscript",
"model-name": "resnet50",
},
"resnet50-onnxruntime": {
Expand Down Expand Up @@ -408,6 +408,10 @@ def get_backend(backend):
from backend_pytorch_native import BackendPytorchNative

backend = BackendPytorchNative()
elif backend == "pytorch-torchscript":
from backend_pytorch_torchscript import BackendPytorchTorchScript

backend = BackendPytorchTorchScript()
elif backend == "tflite":
from backend_tflite import BackendTflite

Expand Down
Loading