|
21 | 21 |
|
22 | 22 | import numpy as np |
23 | 23 | import paddle |
| 24 | +import pytest |
24 | 25 | from safetensors.numpy import save_file |
25 | 26 |
|
26 | 27 | from fastdeploy.model_executor.load_weight_utils import ( |
|
31 | 32 | is_weight_cache_enabled, |
32 | 33 | kv_cache_scale_iterator, |
33 | 34 | load_composite_checkpoint, |
| 35 | + load_ep_checkpoint, |
34 | 36 | load_kv_cache_scale, |
35 | 37 | load_pre_sharded_checkpoint, |
36 | 38 | load_weights_from_cache, |
@@ -458,3 +460,60 @@ def test_composite_pre_sharded(self, monkeypatch): |
458 | 460 | mock_cls = SimpleNamespace(_get_tensor_parallel_mappings=lambda _: {}) |
459 | 461 | result = load_composite_checkpoint(d, mock_cls, cfg) |
460 | 462 | assert "w" in result |
| 463 | + |
| 464 | + # ── load_ep_checkpoint ───────────────────────────────────────────── |
| 465 | + |
| 466 | + def test_load_ep_checkpoint_basic(self): |
| 467 | + with tempfile.TemporaryDirectory() as d: |
| 468 | + save_file( |
| 469 | + {"w": np.array([1.0, 2.0], dtype=np.float32)}, |
| 470 | + os.path.join(d, "s1.safetensors"), |
| 471 | + ) |
| 472 | + index = {"weight_map": {"w": "s1.safetensors"}} |
| 473 | + with open(os.path.join(d, "model.safetensors.index.json"), "w") as f: |
| 474 | + json.dump(index, f) |
| 475 | + cfg = _make_fd_config() |
| 476 | + cfg.parallel_config.num_experts_start_offset = 0 |
| 477 | + cfg.parallel_config.num_experts_per_rank = 1 |
| 478 | + cfg.model_config.moe_num_experts = 2 |
| 479 | + cfg.model_config.moe_layer_start_index = 0 |
| 480 | + cfg.model_config.num_hidden_layers = 1 |
| 481 | + cfg.speculative_config = SimpleNamespace(model_type="main") |
| 482 | + cfg.parallel_config.use_sequence_parallel_moe = False |
| 483 | + mock_cls = SimpleNamespace(_get_tensor_parallel_mappings=lambda _: {}) |
| 484 | + result = load_ep_checkpoint(mock_cls, d, cfg, return_numpy=True) |
| 485 | + assert "w" in result |
| 486 | + np.testing.assert_allclose(result["w"], [1.0, 2.0], rtol=1e-6) |
| 487 | + |
| 488 | + def test_composite_ep_branch(self, monkeypatch): |
| 489 | + cfg = _make_fd_config() |
| 490 | + cfg.parallel_config.use_ep = True |
| 491 | + cfg.quant_config.kv_cache_quant_type = "none" |
| 492 | + monkeypatch.setattr( |
| 493 | + "fastdeploy.model_executor.load_weight_utils.load_ep_checkpoint", |
| 494 | + lambda cls, path, fd_config, return_numpy=True: {"w": np.zeros((2,))}, |
| 495 | + ) |
| 496 | + mock_cls = SimpleNamespace(_get_tensor_parallel_mappings=lambda _: {}) |
| 497 | + with tempfile.TemporaryDirectory() as d: |
| 498 | + result = load_composite_checkpoint(d, mock_cls, cfg) |
| 499 | + assert "w" in result |
| 500 | + |
| 501 | + # ── get_weight_iterator (unordered sharded) ──────────────────────── |
| 502 | + |
| 503 | + def test_get_weight_iterator_unordered(self): |
| 504 | + with tempfile.TemporaryDirectory() as d: |
| 505 | + path = os.path.join(d, "model-001.safetensors") |
| 506 | + save_file( |
| 507 | + {"z_last": np.array([1.0], dtype=np.float32), "a_first": np.array([2.0], dtype=np.float32)}, |
| 508 | + path, |
| 509 | + ) |
| 510 | + # Keys deliberately not naturally sorted → triggers ordered iterator path |
| 511 | + index = {"weight_map": {"z_last": "model-001.safetensors", "a_first": "model-001.safetensors"}} |
| 512 | + with open(os.path.join(d, "model.safetensors.index.json"), "w") as f: |
| 513 | + json.dump(index, f) |
| 514 | + results = dict(get_weight_iterator(d)) |
| 515 | + assert "z_last" in results and "a_first" in results |
| 516 | + |
| 517 | + |
| 518 | +if __name__ == "__main__": |
| 519 | + pytest.main([__file__, "-v"]) |
0 commit comments