Skip to content

Commit e20cb86

Browse files
committed
decorator
1 parent 08068fd commit e20cb86

3 files changed

Lines changed: 45 additions & 2 deletions

File tree

front/py/deepx/nn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .deepxir import *
22
from .modules import __all__ as _modules_all
33
__all__ = [
4-
"DeepxIR","DeepxIRResp",
4+
"DeepxIR","DeepxIRResp","deepx_op"
55
*_modules_all
66
]

front/py/deepx/nn/deepxir.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,4 +250,42 @@ def FuseFunc(f):
250250

251251
f()
252252

253-
return
253+
return
254+
255+
from functools import wraps
256+
import inspect
257+
258+
def deepx_op(opname):
259+
def decorator(func):
260+
@wraps(func)
261+
def wrapper(*args, **kwargs):
262+
sig = inspect.signature(func)
263+
bound = sig.bind(*args, **kwargs)
264+
bound.apply_defaults()
265+
266+
params = [Param.tensor(v) for k, v in bound.arguments.items() if k != 'out']
267+
returns = [Param.tensor(bound.arguments['out'])]
268+
269+
ir = DeepxIR(opname, params, returns, kwargs.get('author', None))
270+
send(ir)
271+
return func(*args, **kwargs)
272+
return wrapper
273+
return decorator
274+
275+
def deepx_subgraph(opname):
276+
def decorator(func):
277+
@wraps(func)
278+
def wrapper(*args, **kwargs):
279+
sig = inspect.signature(func)
280+
bound = sig.bind(*args, **kwargs)
281+
bound.apply_defaults()
282+
283+
params = [Param.tensor(v) for k, v in bound.arguments.items() if k != 'out']
284+
returns = [Param.tensor(bound.arguments['out'])]
285+
286+
ir = DeepxIR(opname, params, returns, kwargs.get('author', None))
287+
# 修改这里的逻辑
288+
send(ir)
289+
return func(*args, **kwargs)
290+
return wrapper
291+
return decorator

front/py/deepx/nn/functional/rtf_matmul.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,16 @@
22
from deepx.nn import DeepxIR,Param
33
from deepx.scheduler import send
44

5+
56
def rtf_matmul(a:Tensor,b:Tensor,out: Tensor ,author='cublas',bench:int=None):
67
args=[Param.tensor(a),Param.tensor(b)]
78
returns=[Param.tensor(out)]
89
ir=DeepxIR("matmul", args, returns, author)
910
if bench is not None:
1011
ir._metadata.openbench(bench)
1112
send(ir)
13+
return out
14+
15+
@deepx_op("matmul")
16+
def rtf_matmul(a:Tensor,b:Tensor,out: Tensor ,author='cublas',bench:int=None):
1217
return out

0 commit comments

Comments
 (0)