Skip to content

Commit 7ec711d

Browse files
committed
module:设计
1 parent 039deae commit 7ec711d

1 file changed

Lines changed: 4 additions & 21 deletions

File tree

front/py/deepx/nn/modules/module.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ def __init__(self, name: Optional[str] = None):
1010
self._parent: Optional[Module] = None
1111
self._modules: OrderedDict[str, Module] = OrderedDict()
1212
self._parameters: OrderedDict[str, Tensor] = OrderedDict()
13-
self._buffers: OrderedDict[str, Tensor] = OrderedDict()
14-
self.training = True
15-
13+
1614
def _generate_default_name(self) -> str:
1715
class_name = self.__class__.__name__
1816
base_name = re.sub(r'(?<!^)(?=[A-Z])', '_', class_name).lower()
@@ -26,10 +24,7 @@ def __setattr__(self, name: str, value: Any) -> None:
2624
if isinstance(value, Module):
2725
self.register_module(name, value)
2826
elif isinstance(value, Tensor):
29-
if value.requires_grad:
30-
self.register_parameter(name, value)
31-
else:
32-
self.register_buffer(name, value)
27+
self.register_parameter(name, value)
3328
super().__setattr__(name, value)
3429

3530
def register_module(self, name: str, module: Optional['Module']) -> None:
@@ -46,14 +41,7 @@ def register_parameter(self, name: str, param: Optional[Tensor]) -> None:
4641
else:
4742
self._parameters[name] = param
4843
param.name = self._full_name + '.' + name
49-
50-
def register_buffer(self, name: str, tensor: Optional[Tensor]) -> None:
51-
if tensor is None:
52-
self._buffers.pop(name, None)
53-
else:
54-
self._buffers[name] = tensor
55-
tensor.name = self._full_name + '.' + name
56-
44+
5745
@property
5846
def _full_name(self) -> str:
5947
names = []
@@ -123,19 +111,14 @@ def state_dict(self) -> Dict[str, Tensor]:
123111
state = {}
124112
for name, param in self.named_parameters():
125113
state[name] = param.detach().clone()
126-
for name, buf in self.named_buffers():
127-
state[name] = buf.detach().clone()
128114
return state
129115

130116
def load_state_dict(self, state_dict: Dict[str, Tensor]) -> None:
131117
"""加载模型状态"""
132118
for name, param in self.named_parameters():
133119
if name in state_dict:
134120
param.data.copy_(state_dict[name])
135-
for name, buf in self.named_buffers():
136-
if name in state_dict:
137-
buf.data.copy_(state_dict[name])
138-
121+
139122
def __call__(self, *args, **kwargs) -> Any:
140123
"""允许模块像函数一样调用"""
141124
return self.forward(*args, **kwargs)

0 commit comments

Comments
 (0)