Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions angelslim/utils/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,25 @@ def parse_json_compression_config_section(compress_config: dict) -> CompressionC
return CompressionConfig(name=comp_names, quantization=quantization, cache=cache)


def _require_json_section(config_data: dict, section_name: str) -> Any:
    """
    Return a mandatory top-level section of a parsed JSON configuration.

    A missing section is reported with a descriptive ValueError that names
    the section, rather than letting a bare KeyError escape from the lookup.

    Args:
        config_data: The full parsed JSON configuration dictionary.
        section_name: Name of the required top-level section.

    Returns:
        The value stored under ``section_name``.

    Raises:
        ValueError: If ``section_name`` is not present in ``config_data``.
    """
    if section_name in config_data:
        return config_data[section_name]

    message = f"Missing required '{section_name}' section in JSON configuration."
    raise ValueError(message)


def parse_json_full_config(json_file_path: str) -> FullConfig:
"""
Parses a JSON configuration file into a FullConfig instance
Expand All @@ -769,10 +788,12 @@ def parse_json_full_config(json_file_path: str) -> FullConfig:
config_data = json.load(f)

# Parse model configuration section
model_config = ModelConfig(**config_data["model_config"])
model_config = ModelConfig(**_require_json_section(config_data, "model_config"))

# Parse compression configuration section
comp_config = parse_json_compression_config_section(config_data["compression_config"])
comp_config = parse_json_compression_config_section(
_require_json_section(config_data, "compression_config")
)

# Parse other configuration sections with default fallbacks
dataset_config, global_config, infer_config = (
Expand Down Expand Up @@ -805,26 +826,9 @@ def parse_json_full_config(json_file_path: str) -> FullConfig:
if spin_data is not None:
transform_config.spin_config = SpinConfig(**spin_data)

# Parse calibration configuration section (nested under compression)
comp_data = config_data.get("compression_config", {})
calibrate_data = comp_data.get("calibrate", None)
if not calibrate_data and config_data.get("calibrate_config"):
# Backward compatibility: support top-level calibrate_config
calibrate_data = config_data["calibrate_config"]
if calibrate_data:
comp_config.calibrate = CalibrateConfig(**calibrate_data)

# Parse transform configuration section
transform_config = None
transform_data = config_data.get("transform_config", {})
if transform_data:
spin_data = transform_data.pop("spin_config", None)
transform_config = TransformConfig(**transform_data)
if spin_data is not None:
transform_config.spin_config = SpinConfig(**spin_data)

return FullConfig(
model_config=model_config,
compression_config=comp_config,
dataset_config=dataset_config,
global_config=global_config,
infer_config=infer_config,
Expand Down
65 changes: 65 additions & 0 deletions tests/test_config_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright 2025 Tencent Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for JSON configuration parsing in ``angelslim.utils.config_parser``.

These tests are CPU-only and require neither a GPU nor model weights: they
exercise the pure configuration-parsing logic that ``Engine.prepare_model``
relies on when loading a previously compressed model from
``angelslim_config.json``.
"""

import json

import pytest

from angelslim.utils.config_parser import parse_json_full_config


def _write_json(tmp_path, payload):
    """Dump ``payload`` as JSON to ``angelslim_config.json`` under ``tmp_path``.

    Returns the path of the written file as a string.
    """
    destination = tmp_path / "angelslim_config.json"
    serialized = json.dumps(payload)
    destination.write_text(serialized)
    return str(destination)


def test_json_roundtrip_preserves_compression_config(tmp_path):
    """Loading a JSON config must keep its compression section intact.

    ``Engine.prepare_model`` forwards ``slim_config.compression_config`` to
    ``from_pretrained``; if parsing drops it, the compressed model is reloaded
    without any compression metadata.
    """
    config_file = _write_json(
        tmp_path,
        {
            "model_config": {"name": "Qwen", "model_path": "Base Model Path"},
            "compression_config": {
                "name": "PTQ",
                "quantization": {"name": "fp8_dynamic", "bits": 8},
            },
            "global_config": {"save_path": "Save Model Path"},
        },
    )

    parsed = parse_json_full_config(config_file)

    # The compression section itself must be present and carry its name list.
    compression = parsed.compression_config
    assert compression is not None
    assert compression.name == ["PTQ"]

    # The nested quantization settings must also come through the round-trip.
    quantization = compression.quantization
    assert quantization is not None
    assert quantization.name == "fp8_dynamic"


def test_json_missing_required_section_raises_value_error(tmp_path):
    """Omitting a required section must surface as a descriptive ValueError."""
    # Only ``model_config`` is supplied; ``compression_config`` is left out on purpose.
    incomplete_payload = {
        "model_config": {"name": "Qwen", "model_path": "Base Model Path"},
    }
    config_file = _write_json(tmp_path, incomplete_payload)

    with pytest.raises(ValueError, match="compression_config"):
        parse_json_full_config(config_file)
Loading