@@ -10,9 +10,7 @@ def __init__(self, name: Optional[str] = None):
1010 self ._parent : Optional [Module ] = None
1111 self ._modules : OrderedDict [str , Module ] = OrderedDict ()
1212 self ._parameters : OrderedDict [str , Tensor ] = OrderedDict ()
13- self ._buffers : OrderedDict [str , Tensor ] = OrderedDict ()
14- self .training = True
15-
13+
1614 def _generate_default_name (self ) -> str :
1715 class_name = self .__class__ .__name__
1816 base_name = re .sub (r'(?<!^)(?=[A-Z])' , '_' , class_name ).lower ()
@@ -26,10 +24,7 @@ def __setattr__(self, name: str, value: Any) -> None:
2624 if isinstance (value , Module ):
2725 self .register_module (name , value )
2826 elif isinstance (value , Tensor ):
29- if value .requires_grad :
30- self .register_parameter (name , value )
31- else :
32- self .register_buffer (name , value )
27+ self .register_parameter (name , value )
3328 super ().__setattr__ (name , value )
3429
3530 def register_module (self , name : str , module : Optional ['Module' ]) -> None :
@@ -46,14 +41,7 @@ def register_parameter(self, name: str, param: Optional[Tensor]) -> None:
4641 else :
4742 self ._parameters [name ] = param
4843 param .name = self ._full_name + '.' + name
49-
50- def register_buffer (self , name : str , tensor : Optional [Tensor ]) -> None :
51- if tensor is None :
52- self ._buffers .pop (name , None )
53- else :
54- self ._buffers [name ] = tensor
55- tensor .name = self ._full_name + '.' + name
56-
44+
5745 @property
5846 def _full_name (self ) -> str :
5947 names = []
@@ -123,19 +111,14 @@ def state_dict(self) -> Dict[str, Tensor]:
123111 state = {}
124112 for name , param in self .named_parameters ():
125113 state [name ] = param .detach ().clone ()
126- for name , buf in self .named_buffers ():
127- state [name ] = buf .detach ().clone ()
128114 return state
129115
130116 def load_state_dict (self , state_dict : Dict [str , Tensor ]) -> None :
131117 """加载模型状态"""
132118 for name , param in self .named_parameters ():
133119 if name in state_dict :
134120 param .data .copy_ (state_dict [name ])
135- for name , buf in self .named_buffers ():
136- if name in state_dict :
137- buf .data .copy_ (state_dict [name ])
138-
121+
139122 def __call__ (self , * args , ** kwargs ) -> Any :
140123 """允许模块像函数一样调用"""
141124 return self .forward (* args , ** kwargs )
0 commit comments