diff --git a/front/py/deepx/nn/__init__.py b/front/py/deepx/nn/__init__.py index f20124f..fef9753 100644 --- a/front/py/deepx/nn/__init__.py +++ b/front/py/deepx/nn/__init__.py @@ -1,6 +1,6 @@ from .deepxir import * from .modules import __all__ as _modules_all __all__ = [ - "DeepxIR","DeepxIRResp", + "DeepxIR","DeepxIRResp","deepx_op","deepx_subgraph", *_modules_all ] \ No newline at end of file diff --git a/front/py/deepx/nn/deepxir.py b/front/py/deepx/nn/deepxir.py index 4d837f9..ae84dad 100644 --- a/front/py/deepx/nn/deepxir.py +++ b/front/py/deepx/nn/deepxir.py @@ -250,4 +250,42 @@ def FuseFunc(f): f() - return \ No newline at end of file + return + +from functools import wraps +import inspect + +def deepx_op(opname): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + sig = inspect.signature(func) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + params = [Param.tensor(v) for k, v in bound.arguments.items() if k != 'out'] + returns = [Param.tensor(bound.arguments['out'])] + + ir = DeepxIR(opname, params, returns, kwargs.get('author', None)) + send(ir) + return func(*args, **kwargs) + return wrapper + return decorator + +def deepx_subgraph(opname): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + sig = inspect.signature(func) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + + params = [Param.tensor(v) for k, v in bound.arguments.items() if k != 'out'] + returns = [Param.tensor(bound.arguments['out'])] + + ir = DeepxIR(opname, params, returns, kwargs.get('author', None)) + # 修改这里的逻辑 + send(ir) + return func(*args, **kwargs) + return wrapper + return decorator \ No newline at end of file diff --git a/front/py/deepx/nn/functional/rtf_matmul.py b/front/py/deepx/nn/functional/rtf_matmul.py index 93f4d5c..f596e2a 100644 --- a/front/py/deepx/nn/functional/rtf_matmul.py +++ b/front/py/deepx/nn/functional/rtf_matmul.py @@ -2,6 +2,7 @@ from deepx.nn import DeepxIR,Param from deepx.scheduler import send + def rtf_matmul(a:Tensor,b:Tensor,out: Tensor ,author='cublas',bench:int=None): args=[Param.tensor(a),Param.tensor(b)] returns=[Param.tensor(out)] @@ -9,4 +10,8 @@ def rtf_matmul(a:Tensor,b:Tensor,out: Tensor ,author='cublas',bench:int=None): if bench is not None: ir._metadata.openbench(bench) send(ir) + return out + +@deepx_op("matmul") +def rtf_matmul_2(a:Tensor,b:Tensor,out: Tensor ,author='cublas',bench:int=None): return out \ No newline at end of file