Skip to content

Commit f9a9f5f

Browse files
committed
[Eager] Add eager to graph fallback API
1 parent 9dbe037 commit f9a9f5f

2 files changed

Lines changed: 41 additions & 3 deletions

File tree

PyTorchSimDevice/torch_openreg/openreg/__init__.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,43 @@ def launch_model(model, *args, stream_index=0, timestamp=0, **kwargs):
243243
from .random import * # noqa: F403
244244
from .amp import *
245245

246+
def eager_to_compile(op_name):
    """
    Register an eager mode operation as a graph-based implementation using torch.compile().

    The operator string is resolved to the corresponding ``torch.ops``
    callable once at registration time, and a single ``torch.compile``-wrapped
    graph is created up front so repeated calls reuse the compiled artifact
    instead of rebuilding (and re-compiling) it on every invocation.

    Args:
        op_name: Operator name (e.g., "aten::mul.Tensor")

    Example:
        torch.npu.eager_to_compile("aten::mul.Tensor")
    """
    # Resolve "aten::mul.Tensor" -> torch.ops.aten.mul.Tensor once, not per call.
    namespace, op_path = op_name.split("::", 1)
    op = torch.ops
    for part in [namespace, *op_path.split(".")]:
        op = getattr(op, part)

    # Defined once so torch.compile's cache is shared across calls; the
    # previous version re-created the decorated graph inside the wrapper,
    # paying graph construction overhead on every dispatch.
    @torch.compile(dynamic=False)
    def dummy_graph(*args, **kwargs):
        return op(*args, **kwargs)

    def wrapper(*args, **kwargs):
        return dummy_graph(*args, **kwargs)

    # Install the compiled graph as the "npu" backend implementation of the op.
    torch.library.impl(op_name, "npu", wrapper)
270+
def register_eager_to_compile(ops):
    """
    Register multiple operators at once using eager_to_compile.

    Convenience wrapper that applies :func:`eager_to_compile` to each
    operator name in turn.

    Args:
        ops: List of operator names (e.g., ["aten::mul.Tensor", "aten::add.Tensor"])

    Example:
        torch.npu.register_eager_to_compile(["aten::mul.Tensor", "aten::add.Tensor"])
    """
    for name in ops:
        eager_to_compile(name)
246283
__all__ = [
247284
"device",
248285
"device_count",
@@ -269,4 +306,6 @@ def launch_model(model, *args, stream_index=0, timestamp=0, **kwargs):
269306
"synchronize",
270307
"get_tog_simulator",
271308
"set_tog_simulator",
309+
"eager_to_compile",
310+
"register_eager_to_compile",
272311
]

tests/test_eager.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import torch

# Route these aten ops through the eager-to-graph fallback provided by the
# openreg backend; must run at import time, before any npu tensor math below.
torch.npu.register_eager_to_compile(["aten::mul.Tensor", "aten::add.Tensor"])

if __name__ == "__main__":
    # Smoke test: exercise both registered ops (mul, then add) on the npu
    # device and bring the result back to CPU for inspection.
    device = torch.device("npu:0")
    x = torch.ones(10, 10).to(device)
    y = torch.ones(10, 10).to(device)
    z = x * y
    z = x + z
    print(z.cpu())

0 commit comments

Comments
 (0)