Skip to content

Commit 192e58d

Browse files
committed
autograd:
1 parent 0b138ff commit 192e58d

File tree

1 file changed

+157
-0
lines changed

1 file changed

+157
-0
lines changed

front/py/autograd/function.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
from typing import Any, Tuple, List, Optional
2+
from ..tensor import Tensor
3+
4+
class Function:
5+
"""
6+
所有自动微分操作的基类
7+
8+
类似于PyTorch的 torch.autograd.Function,用于定义具有自定义前向和反向传播规则的操作。
9+
每个子类都需要实现 forward() 和 backward() 方法。
10+
11+
Example:
12+
class ReLU(Function):
13+
@staticmethod
14+
def forward(ctx, input):
15+
ctx.save_for_backward(input)
16+
return input.clamp(min=0)
17+
18+
@staticmethod
19+
def backward(ctx, grad_output):
20+
input, = ctx.saved_tensors
21+
grad_input = grad_output.clone()
22+
grad_input[input < 0] = 0
23+
return grad_input
24+
"""
25+
26+
@staticmethod
27+
def forward(ctx: 'Context', *args: Any, **kwargs: Any) -> Tensor:
28+
"""
29+
执行操作的前向传播
30+
31+
Args:
32+
ctx: Context对象,用于存储反向传播需要的信息
33+
*args: 输入参数
34+
**kwargs: 关键字参数
35+
36+
Returns:
37+
计算结果张量
38+
"""
39+
raise NotImplementedError
40+
41+
@staticmethod
42+
def backward(ctx: 'Context', *grad_outputs: Tensor) -> Tuple[Optional[Tensor], ...]:
43+
"""
44+
执行操作的反向传播
45+
46+
Args:
47+
ctx: Context对象,包含前向传播保存的信息
48+
grad_outputs: 输出梯度
49+
50+
Returns:
51+
输入梯度的元组
52+
"""
53+
raise NotImplementedError
54+
55+
56+
class Context:
57+
"""
58+
用于在前向和反向传播之间传递信息的上下文对象
59+
"""
60+
def __init__(self):
61+
self.saved_tensors: List[Tensor] = []
62+
self.saved_variables: dict = {}
63+
64+
def save_for_backward(self, *tensors: Tensor) -> None:
65+
"""
66+
保存反向传播需要的张量
67+
68+
Args:
69+
*tensors: 需要保存的张量
70+
"""
71+
self.saved_tensors.extend(tensors)
72+
73+
def save_variables(self, **kwargs: Any) -> None:
74+
"""
75+
保存反向传播需要的变量
76+
77+
Args:
78+
**kwargs: 需要保存的变量
79+
"""
80+
self.saved_variables.update(kwargs)
81+
82+
@property
83+
def saved_values(self) -> dict:
84+
"""获取所有保存的变量"""
85+
return self.saved_variables
86+
87+
88+
class FunctionMeta(type):
89+
"""
90+
Function类的元类,用于管理Function的注册和应用
91+
"""
92+
_function_registry = {}
93+
94+
def __new__(cls, name, bases, attrs):
95+
new_cls = super().__new__(cls, name, bases, attrs)
96+
if 'forward' in attrs:
97+
cls._function_registry[name] = new_cls
98+
return new_cls
99+
100+
@classmethod
101+
def get_function(cls, name: str) -> Optional[type]:
102+
"""
103+
获取已注册的Function
104+
105+
Args:
106+
name: Function的名称
107+
108+
Returns:
109+
对应的Function类
110+
"""
111+
return cls._function_registry.get(name)
112+
113+
114+
def register_function(name: str) -> callable:
115+
"""
116+
注册Function的装饰器
117+
118+
Args:
119+
name: Function的名称
120+
121+
Returns:
122+
装饰器函数
123+
"""
124+
def decorator(cls):
125+
FunctionMeta._function_registry[name] = cls
126+
return cls
127+
return decorator
128+
129+
130+
def apply_function(name: str, *args: Any, **kwargs: Any) -> Tensor:
131+
"""
132+
应用Function到给定的输入
133+
134+
Args:
135+
name: Function的名称
136+
*args: 输入参数
137+
**kwargs: 关键字参数
138+
139+
Returns:
140+
计算结果张量
141+
142+
Raises:
143+
ValueError: 如果Function未注册
144+
"""
145+
function_cls = FunctionMeta.get_function(name)
146+
if function_cls is None:
147+
raise ValueError(f"Function {name} not found")
148+
149+
ctx = Context()
150+
result = function_cls.forward(ctx, *args, **kwargs)
151+
152+
if any(t.requires_grad for t in args if isinstance(t, Tensor)):
153+
result.requires_grad = True
154+
result._ctx = ctx
155+
result._backward_function = function_cls.backward
156+
157+
return result

0 commit comments

Comments
 (0)