Skip to content

Commit 7b887b3

Browse files
author
cloudforge1
committed
[CI]【Hackathon 10th Spring No.32】add coverage tests for load_weight_utils
- Add test_is_layers_grouped: test layers_are_grouped() with grouped, interleaved, and no-layer keys
- Add test_save_model_bf16_cache: exercise the save_model decorator with is_checkpoint_bf16=True
- Add test_composite_checkpoint_ep: test the load_composite_checkpoint use_ep=True branch
- Add test_composite_checkpoint_rank_mismatch: test the ValueError raised when tp_size != number of rank dirs
- Add test_composite_checkpoint_kv_quant: test the float8_e4m3fn kv_cache path
- Add a __main__ block for direct execution
- Branch coverage: 72% -> 80%
1 parent b9f96a0 commit 7b887b3

1 file changed

Lines changed: 67 additions & 0 deletions

File tree

tests/model_executor/test_load_weight_utils.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ def test_natural_key(self):
5858
"layer.10.weight",
5959
]
6060

61+
def test_is_layers_grouped(self):
    """layers_are_grouped() reports True only when per-layer keys are contiguous.

    Covers three shapes of key lists: fully grouped per layer, interleaved
    across layers, and a list with no layer-indexed keys at all.
    """
    grouped_keys = ["layers.0.w", "layers.0.b", "layers.1.w", "layers.1.b"]
    interleaved_keys = ["layers.0.w", "layers.1.w", "layers.0.b"]
    no_layer_keys = ["embed.weight"]
    assert lwu.layers_are_grouped(grouped_keys) is True
    assert lwu.layers_are_grouped(interleaved_keys) is False
    assert lwu.layers_are_grouped(no_layer_keys) is True
6166
def test_measure_time(self):
6267
@lwu.measure_time("T")
6368
def dummy():
@@ -161,6 +166,25 @@ def dummy_load(model, fd_config):
161166
monkeypatch.setenv("FD_ENABLE_MODEL_LOAD_CACHE", "1")
162167
assert dummy_load(mock_model, cfg) == {"loaded": True}
163168

169+
def test_save_model_bf16_cache(self, tmp_path, monkeypatch):
    """save_model() persists a cache via paddle.save when bf16 caching is enabled.

    Enables FD_ENABLE_MODEL_LOAD_CACHE, marks the checkpoint as bf16, and stubs
    out paddle.save so the test only checks that a save was attempted — it does
    not touch the real serialization path.
    """
    monkeypatch.setenv("FD_ENABLE_MODEL_LOAD_CACHE", "1")
    cfg = _cfg()
    cfg.model_config.model = str(tmp_path)
    cfg.quant_config.is_checkpoint_bf16 = True
    cfg.parallel_config.tensor_parallel_rank = 0

    # Capture the destination path instead of actually writing a file.
    captured = {}
    monkeypatch.setattr("paddle.save", lambda sd, p: captured.update({"path": p}))

    @lwu.save_model()
    def dummy_load(model, fd_config):
        return {"loaded": True}

    mock_model = SimpleNamespace(state_dict=lambda: {"w": 1})
    assert dummy_load(mock_model, cfg) == {"loaded": True}
    assert "path" in captured
187+
164188

165189
class TestCompositeLoading:
166190
def test_load_kv_cache_scale(self, tmp_path):
@@ -214,3 +238,46 @@ def test_load_ep_checkpoint(self, tmp_path):
214238
mock_cls = SimpleNamespace(_get_tensor_parallel_mappings=lambda _: {})
215239
result = lwu.load_ep_checkpoint(mock_cls, str(tmp_path), cfg, return_numpy=True)
216240
np.testing.assert_allclose(result["w"], [1.0, 2.0], rtol=1e-6)
241+
242+
def test_composite_checkpoint_ep(self, tmp_path, monkeypatch):
    """load_composite_checkpoint() takes the expert-parallel branch when use_ep is set.

    Builds a minimal single-shard safetensors checkpoint plus an index file,
    then configures a one-expert EP layout and checks the weight is loaded.
    """
    save_file({"w": np.array([1.0], dtype=np.float32)}, str(tmp_path / "s1.safetensors"))
    with open(str(tmp_path / "model.safetensors.index.json"), "w") as f:
        json.dump({"weight_map": {"w": "s1.safetensors"}}, f)

    cfg = _cfg()
    cfg.parallel_config.use_ep = True
    cfg.parallel_config.num_experts_start_offset = 0
    cfg.parallel_config.num_experts_per_rank = 1
    cfg.model_config.moe_num_experts = 1
    cfg.model_config.moe_layer_start_index = 0
    cfg.speculative_config = SimpleNamespace(model_type="main")

    mock_cls = SimpleNamespace(_get_tensor_parallel_mappings=lambda _: {})
    loaded = lwu.load_composite_checkpoint(str(tmp_path), mock_cls, cfg, return_numpy=True)
    assert "w" in loaded
257+
258+
def test_composite_checkpoint_rank_mismatch(self, tmp_path):
    """A tensor_parallel_size that disagrees with the rank-dir count raises ValueError."""
    # Three rank directories on disk, but the config claims tp_size == 2.
    for rank_dir in ("rank0", "rank1", "rank2"):
        (tmp_path / rank_dir).mkdir()
    cfg = _cfg()
    cfg.parallel_config.tensor_parallel_size = 2  # doesn't match 3 rank dirs
    mock_cls = SimpleNamespace(_get_tensor_parallel_mappings=lambda _: {})
    with pytest.raises(ValueError, match="tp3"):
        lwu.load_composite_checkpoint(str(tmp_path), mock_cls, cfg)
267+
268+
def test_composite_checkpoint_kv_quant(self, tmp_path, monkeypatch):
    """load_composite_checkpoint() handles the float8_e4m3fn kv-cache quant path.

    Stubs load_tp_checkpoint so only the kv-quant post-processing branch runs;
    the scale path points at a file that does not exist, exercising the
    missing-scale handling.
    """
    save_file({"w": np.random.randn(4, 4).astype(np.float32)}, str(tmp_path / "model.safetensors"))

    cfg = _cfg()
    cfg.model_config.model = str(tmp_path)
    cfg.quant_config.kv_cache_quant_type = "float8_e4m3fn"
    cfg.model_config.kv_cache_quant_scale_path = str(tmp_path / "nonexistent.json")

    # Bypass the real TP loader; return a fixed weight dict.
    monkeypatch.setattr(
        "fastdeploy.model_executor.load_weight_utils.load_tp_checkpoint", lambda *a, **kw: {"w": np.ones((4, 4))}
    )

    mock_cls = SimpleNamespace(_get_tensor_parallel_mappings=lambda _: {})
    loaded = lwu.load_composite_checkpoint(str(tmp_path), mock_cls, cfg, return_numpy=True)
    assert "w" in loaded
280+
281+
282+
if __name__ == "__main__":
283+
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)