diff --git a/angelslim/utils/config_parser.py b/angelslim/utils/config_parser.py index fb2f7fed..b0244b43 100644 --- a/angelslim/utils/config_parser.py +++ b/angelslim/utils/config_parser.py @@ -755,6 +755,25 @@ def parse_json_compression_config_section(compress_config: dict) -> CompressionC return CompressionConfig(name=comp_names, quantization=quantization, cache=cache) +def _require_json_section(config_data: dict, section_name: str) -> Any: + """ + Fetch a required top-level section from a JSON configuration. + + Raises a descriptive ValueError (instead of a bare KeyError) when the + section is missing, mirroring the validation style of the YAML parser. + + Args: + config_data: The full parsed JSON configuration dictionary. + section_name: Name of the required top-level section. + + Returns: + The section value associated with ``section_name``. + """ + if section_name not in config_data: + raise ValueError(f"Missing required '{section_name}' section in JSON configuration.") + return config_data[section_name] + + def parse_json_full_config(json_file_path: str) -> FullConfig: """ Parses a JSON configuration file into a FullConfig instance @@ -769,10 +788,12 @@ def parse_json_full_config(json_file_path: str) -> FullConfig: config_data = json.load(f) # Parse model configuration section - model_config = ModelConfig(**config_data["model_config"]) + model_config = ModelConfig(**_require_json_section(config_data, "model_config")) # Parse compression configuration section - comp_config = parse_json_compression_config_section(config_data["compression_config"]) + comp_config = parse_json_compression_config_section( + _require_json_section(config_data, "compression_config") + ) # Parse other configuration sections with default fallbacks dataset_config, global_config, infer_config = ( @@ -805,26 +826,9 @@ def parse_json_full_config(json_file_path: str) -> FullConfig: if spin_data is not None: transform_config.spin_config = SpinConfig(**spin_data) - # Parse calibration configuration section (nested under compression) - comp_data = config_data.get("compression_config", {}) - calibrate_data = comp_data.get("calibrate", None) - if not calibrate_data and config_data.get("calibrate_config"): - # Backward compatibility: support top-level calibrate_config - calibrate_data = config_data["calibrate_config"] - if calibrate_data: - comp_config.calibrate = CalibrateConfig(**calibrate_data) - - # Parse transform configuration section - transform_config = None - transform_data = config_data.get("transform_config", {}) - if transform_data: - spin_data = transform_data.pop("spin_config", None) - transform_config = TransformConfig(**transform_data) - if spin_data is not None: - transform_config.spin_config = SpinConfig(**spin_data) - return FullConfig( model_config=model_config, + compression_config=comp_config, dataset_config=dataset_config, global_config=global_config, infer_config=infer_config, diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py new file mode 100644 index 00000000..9fdfb594 --- /dev/null +++ b/tests/test_config_parser.py @@ -0,0 +1,65 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for JSON configuration parsing in ``angelslim.utils.config_parser``. + +These tests are CPU-only and require neither a GPU nor model weights: they +exercise the pure configuration-parsing logic that ``Engine.prepare_model`` +relies on when loading a previously compressed model from +``angelslim_config.json``. +""" + +import json + +import pytest + +from angelslim.utils.config_parser import parse_json_full_config + + +def _write_json(tmp_path, payload): + config_path = tmp_path / "angelslim_config.json" + config_path.write_text(json.dumps(payload)) + return str(config_path) + + +def test_json_roundtrip_preserves_compression_config(tmp_path): + """The compression section must survive a JSON load round-trip. + + ``Engine.prepare_model`` forwards ``slim_config.compression_config`` to + ``from_pretrained``; if it is dropped during parsing the compressed model is + reloaded without any compression metadata. + """ + payload = { + "model_config": {"name": "Qwen", "model_path": "Base Model Path"}, + "compression_config": { + "name": "PTQ", + "quantization": {"name": "fp8_dynamic", "bits": 8}, + }, + "global_config": {"save_path": "Save Model Path"}, + } + + full_config = parse_json_full_config(_write_json(tmp_path, payload)) + + assert full_config.compression_config is not None + assert full_config.compression_config.name == ["PTQ"] + assert full_config.compression_config.quantization is not None + assert full_config.compression_config.quantization.name == "fp8_dynamic" + + +def test_json_missing_required_section_raises_value_error(tmp_path): + """A missing required section should raise a descriptive ValueError.""" + payload = {"model_config": {"name": "Qwen", "model_path": "Base Model Path"}} + + with pytest.raises(ValueError, match="compression_config"): + parse_json_full_config(_write_json(tmp_path, payload))