@@ -58,6 +58,11 @@ def test_natural_key(self):
5858 "layer.10.weight" ,
5959 ]
6060
61+ def test_is_layers_grouped (self ):
62+ assert lwu .layers_are_grouped (["layers.0.w" , "layers.0.b" , "layers.1.w" , "layers.1.b" ]) is True
63+ assert lwu .layers_are_grouped (["layers.0.w" , "layers.1.w" , "layers.0.b" ]) is False
64+ assert lwu .layers_are_grouped (["embed.weight" ]) is True
65+
6166 def test_measure_time (self ):
6267 @lwu .measure_time ("T" )
6368 def dummy ():
@@ -161,6 +166,25 @@ def dummy_load(model, fd_config):
161166 monkeypatch .setenv ("FD_ENABLE_MODEL_LOAD_CACHE" , "1" )
162167 assert dummy_load (mock_model , cfg ) == {"loaded" : True }
163168
169+ def test_save_model_bf16_cache (self , tmp_path , monkeypatch ):
170+ monkeypatch .setenv ("FD_ENABLE_MODEL_LOAD_CACHE" , "1" )
171+ cfg = _cfg ()
172+ cfg .model_config .model = str (tmp_path )
173+ cfg .quant_config .is_checkpoint_bf16 = True
174+ cfg .parallel_config .tensor_parallel_rank = 0
175+
176+ saved = {}
177+ monkeypatch .setattr ("paddle.save" , lambda sd , p : saved .update ({"path" : p }))
178+
179+ @lwu .save_model ()
180+ def dummy_load (model , fd_config ):
181+ return {"loaded" : True }
182+
183+ mock_model = SimpleNamespace (state_dict = lambda : {"w" : 1 })
184+ result = dummy_load (mock_model , cfg )
185+ assert result == {"loaded" : True }
186+ assert "path" in saved
187+
164188
165189class TestCompositeLoading :
166190 def test_load_kv_cache_scale (self , tmp_path ):
@@ -214,3 +238,46 @@ def test_load_ep_checkpoint(self, tmp_path):
214238 mock_cls = SimpleNamespace (_get_tensor_parallel_mappings = lambda _ : {})
215239 result = lwu .load_ep_checkpoint (mock_cls , str (tmp_path ), cfg , return_numpy = True )
216240 np .testing .assert_allclose (result ["w" ], [1.0 , 2.0 ], rtol = 1e-6 )
241+
242+ def test_composite_checkpoint_ep (self , tmp_path , monkeypatch ):
243+ save_file ({"w" : np .array ([1.0 ], dtype = np .float32 )}, str (tmp_path / "s1.safetensors" ))
244+ index = {"weight_map" : {"w" : "s1.safetensors" }}
245+ with open (str (tmp_path / "model.safetensors.index.json" ), "w" ) as f :
246+ json .dump (index , f )
247+ cfg = _cfg ()
248+ cfg .parallel_config .use_ep = True
249+ cfg .parallel_config .num_experts_start_offset = 0
250+ cfg .parallel_config .num_experts_per_rank = 1
251+ cfg .model_config .moe_num_experts = 1
252+ cfg .model_config .moe_layer_start_index = 0
253+ cfg .speculative_config = SimpleNamespace (model_type = "main" )
254+ mock_cls = SimpleNamespace (_get_tensor_parallel_mappings = lambda _ : {})
255+ result = lwu .load_composite_checkpoint (str (tmp_path ), mock_cls , cfg , return_numpy = True )
256+ assert "w" in result
257+
258+ def test_composite_checkpoint_rank_mismatch (self , tmp_path ):
259+ (tmp_path / "rank0" ).mkdir ()
260+ (tmp_path / "rank1" ).mkdir ()
261+ (tmp_path / "rank2" ).mkdir ()
262+ cfg = _cfg ()
263+ cfg .parallel_config .tensor_parallel_size = 2 # doesn't match 3 rank dirs
264+ mock_cls = SimpleNamespace (_get_tensor_parallel_mappings = lambda _ : {})
265+ with pytest .raises (ValueError , match = "tp3" ):
266+ lwu .load_composite_checkpoint (str (tmp_path ), mock_cls , cfg )
267+
268+ def test_composite_checkpoint_kv_quant (self , tmp_path , monkeypatch ):
269+ save_file ({"w" : np .random .randn (4 , 4 ).astype (np .float32 )}, str (tmp_path / "model.safetensors" ))
270+ cfg = _cfg ()
271+ cfg .model_config .model = str (tmp_path )
272+ cfg .quant_config .kv_cache_quant_type = "float8_e4m3fn"
273+ cfg .model_config .kv_cache_quant_scale_path = str (tmp_path / "nonexistent.json" )
274+ monkeypatch .setattr (
275+ "fastdeploy.model_executor.load_weight_utils.load_tp_checkpoint" , lambda * a , ** kw : {"w" : np .ones ((4 , 4 ))}
276+ )
277+ mock_cls = SimpleNamespace (_get_tensor_parallel_mappings = lambda _ : {})
278+ result = lwu .load_composite_checkpoint (str (tmp_path ), mock_cls , cfg , return_numpy = True )
279+ assert "w" in result
280+
281+
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly fails CI
    # on test failures (the bare pytest.main(...) call discarded the code).
    raise SystemExit(pytest.main([__file__, "-v"]))
0 commit comments