Conversation
Pull request overview
This PR adds support for directly loading GPT-OSS models quantized in the MXFP4 format: the loader detects MXFP4 quantization automatically and dequantizes the weights during model loading.
Changes:
- Updated model references in test files from local/unsloth paths to official OpenAI model identifiers
- Added MXFP4 quantization detection and automatic dequantization support in model loading utilities
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| test/test_cuda/models/test_moe_model.py | Updated GPT-OSS model reference from local path to OpenAI identifier |
| test/test_cpu/models/test_moe_model.py | Updated GPT-OSS model reference from unsloth path to OpenAI identifier |
| auto_round/utils/model.py | Added MXFP4 detection function and integrated dequantization config into model loading |
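For context, the detection side of the change can be sketched roughly like this (the `config.json` layout and the helper name `detect_quant_method` are assumptions for illustration, not the PR's actual code):

```python
import json
import os

# Sketch of MXFP4 detection based on the checkpoint's config.json.
# The "quant_method" key is an assumption about the Hugging Face
# quantization_config layout; the real auto_round helper may inspect
# other fields (e.g. model_type) as well.
def detect_quant_method(model_path: str):
    cfg_file = os.path.join(model_path, "config.json")
    if not os.path.isfile(cfg_file):
        return None
    with open(cfg_file) as f:
        cfg = json.load(f)
    qcfg = cfg.get("quantization_config") or {}
    return qcfg.get("quant_method")
```

A detected `"mxfp4"` result would then trigger building `Mxfp4Config(dequantized=True)` before the model is loaded.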
```python
def _is_mxfp4_model(model_path: str) -> bool:
    """Check if the model is quantized with MXFP4."""
    supported_model_types = ["gpt_oss"]
```
The supported model types are hardcoded in this function. Consider making this a module-level constant or configuration parameter to improve maintainability and make it easier to add support for additional model types in the future.
```python
try:
    quantization_config = Mxfp4Config(dequantized=True)
    logger.info("Detected MXFP4 quantized model, using Mxfp4Config(dequantized=True) for loading.")
except ImportError:
    logger.warning("Mxfp4Config not available in current transformers version, loading without dequantization.")
```
The warning message could be more actionable by suggesting which transformers version is required for MXFP4 support. Consider adding version information to help users understand what upgrade is needed.
Suggested change (replacing the single-line warning):

```python
required_tf_version = "4.46.0"
logger.warning(
    "Mxfp4Config is not available in the current transformers installation "
    f"(transformers=={transformers.__version__}). MXFP4 dequantization requires "
    f"transformers>={required_tf_version}. The model will be loaded without "
    "MXFP4 dequantization. Please upgrade transformers, for example with "
    f'`pip install -U "transformers>={required_tf_version}"`.'
)
```
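If the loader wanted to act on that version hint programmatically rather than only warn, a minimal dotted-version comparison could look like this (a sketch; production code would normally use `packaging.version.parse` instead of hand-rolled parsing):

```python
def meets_min_version(current: str, required: str) -> bool:
    """Compare dotted version strings numerically, so "4.46.0" >= "4.9.1"."""
    def parts(v: str):
        # Keep only numeric components, dropping suffixes like "dev0".
        return [int(p) for p in v.split(".") if p.isdigit()]
    return parts(current) >= parts(required)
```

Comparing component lists avoids the classic string-comparison bug where `"4.9.1" > "4.46.0"` lexicographically.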
Signed-off-by: He, Xin3 <xin3.he@intel.com>
Description
Please briefly describe your main changes and the motivation behind them.
Type of Change
Related Issues
Fixes or relates to #
Checklist Before Submitting