3232from .base import current_report , find_base_report_node , single_step_state
3333
3434
35+ _seen_warnings = set ()
36+
37+
3538@contextlib .contextmanager
3639def register_hooker (model ):
3740 marker = model .marker
@@ -266,27 +269,63 @@ def replace_forward_output(node, current_name=None):
266269
267270 def inner (input_ ):
268271 nonlocal cur_idx
269- if isinstance (input_ , (paddle .Tensor , torch .Tensor )):
270- if cur_idx >= len (numpy_file_list ):
271- raise RuntimeError (
272- f"\n ⚠️ Single-step alignment FAILED: the { cur_idx + 1 } st output is requested, "
273- f"but only { len (numpy_file_list )} pre-saved numpy files are available."
272+ if not isinstance (input_ , (paddle .Tensor , torch .Tensor )):
273+ return input_
274+
275+ if cur_idx >= len (numpy_file_list ):
276+ warning_key = ("single-step: output_count_mismatch" , current_name )
277+ if warning_key not in _seen_warnings :
278+ logger .warning (
279+ f"\n ⚠️ Single-step alignment SKIPPED: the { cur_idx + 1 } st output is requested, "
280+ f"but only { len (numpy_file_list )} pre-saved from base model, skip the current output."
281+ "\n ⚠️ This warning will not repeat for this layer."
274282 f"\n 📌 Layer Name: { current_name } (raw)"
275283 "\n 💡 Possible Causes and Solutions:"
276- "\n - The number of outputs from the current layer in the raw model does not match "
284+ "\n - The number of outputs from the current layer in the raw model is bigger than "
277285 "that of its corresponding layer in the base model."
278286 "\n - Verify that both models have identical architectures for this layer."
279287 "\n - If the corresponding relationship of the current layer is correct, "
280288 "please disable single step mode, or add the layer to blacklist to skip the check of this layer."
289+ "\n - Or when you are sure that the extra output does not need to be compared, "
290+ "you can swap the execution order of the base model and the raw model."
281291 )
282- value = np .load (numpy_file_list [cur_idx ]["path" ])
283- cur_idx += 1
284- if isinstance (input_ , paddle .Tensor ):
285- return paddle .to_tensor (value , dtype = input_ .dtype )
286- else :
287- return torch .as_tensor (value , dtype = input_ .dtype , device = input_ .device )
288- else :
292+ _seen_warnings .add (warning_key )
293+ return input_
294+
295+ value = np .load (numpy_file_list [cur_idx ]["path" ])
296+ cur_idx += 1
297+ base_shape = tuple (value .shape )
298+ raw_shape = tuple (input_ .shape )
299+
300+ if base_shape == raw_shape :
301+ pass
302+
303+ elif np .prod (base_shape ) != np .prod (raw_shape ):
304+ warning_key = ("single-step: shape_mismatch" , current_name )
305+ if warning_key not in _seen_warnings :
306+ logger .warning (
307+ f"\n ⚠️ Single-step alignment SKIPPED: shape mismatch."
308+ "\n ⚠️ This warning will not repeat for this layer."
309+ f"\n 📌 Layer Name: { current_name } (raw)"
310+ f"\n 📌 Shape: { base_shape } (base) vs { raw_shape } (raw)"
311+ )
312+ _seen_warnings .add (warning_key )
289313 return input_
314+ else :
315+ value = value .reshape (input_ .shape )
316+ debug_key = ("single-step: reshape_used" , current_name )
317+ if debug_key not in _seen_warnings :
318+ logger .debug (
319+ f"\n ⚠️ Try to reshape loaded value to input's shape of layer { current_name } (raw). "
320+ "This may lead to numerical errors even if reshape succeeds."
321+ "\n ⚠️ This warning will not repeat for this layer."
322+ )
323+ _seen_warnings .add (debug_key )
324+
325+ if isinstance (input_ , paddle .Tensor ):
326+ return paddle .to_tensor (value , dtype = input_ .dtype )
327+ else :
328+ return torch .as_tensor (value , dtype = input_ .dtype , device = input_ .device )
290329
291330 return inner
292331
@@ -296,14 +335,17 @@ def single_step_check(report, net_id, step_idx, current_name, node_type, bwd_ite
296335 try :
297336 base_report_node = find_base_report_node (net_id , step_idx )
298337 if base_report_node ["name" ] != current_name :
299- warning_msg = (
300- f"\n ⚠️ Single-step alignment WARNING: { node_type } with net_id={ net_id } mismatch!\n "
301- f" 📌 Mismatch { node_type .capitalize ()} : { base_report_node ['name' ]} (base) vs { current_name } (raw)\n "
302- " 💡 Suggestion: Models have different architectures or initialization order. "
303- "Please check the model implementation or decrease 'align_depth' to reduce the alignment "
304- "granularity, or add layers that do not require alignment to the blacklist."
305- )
306- logger .warning (warning_msg )
338+ warning_key = ("single-step: name_mismatch" , current_name )
339+ if warning_key not in _seen_warnings :
340+ logger .warning (
341+ f"\n ⚠️ Single-step alignment WARNING: { node_type } with net_id={ net_id } mismatch!"
342+ "\n ⚠️ This warning will not repeat for this layer."
343+ f"\n 📌 Mismatch { node_type .capitalize ()} : { base_report_node ['name' ]} (base) vs { current_name } (raw)"
344+ "\n 💡 Suggestion: Models have different architectures or class name or initialization order. "
345+ "Please check the model implementation or decrease 'align_depth' to reduce the alignment "
346+ "granularity, or add layers that do not require alignment to the blacklist."
347+ )
348+ _seen_warnings .add (warning_key )
307349 else :
308350 logger .debug (f"Single Step: { current_name } (net_id={ net_id } )" )
309351
0 commit comments