diff --git a/src/art/megatron/model_support/handlers/default_dense.py b/src/art/megatron/model_support/handlers/default_dense.py index bb5cffaa..15e9e40c 100644 --- a/src/art/megatron/model_support/handlers/default_dense.py +++ b/src/art/megatron/model_support/handlers/default_dense.py @@ -4,6 +4,7 @@ import torch from art.megatron.model_support.spec import ( + DEFAULT_DENSE_META, CompileWorkaroundConfig, LayerFamilyInstance, SharedExpertCompileState, @@ -11,9 +12,9 @@ class DefaultDenseHandler: - key = "default_dense" + key = DEFAULT_DENSE_META.key is_moe = False - native_vllm_lora_status = "disabled" + native_vllm_lora_status = DEFAULT_DENSE_META.native_vllm_lora_status def identity_lora_model_config(self, base_config: Any) -> Any: return base_config diff --git a/src/art/megatron/model_support/handlers/qwen3_5.py b/src/art/megatron/model_support/handlers/qwen3_5.py index 06ae9392..05dfd72c 100644 --- a/src/art/megatron/model_support/handlers/qwen3_5.py +++ b/src/art/megatron/model_support/handlers/qwen3_5.py @@ -14,6 +14,8 @@ _require_moe_experts, ) from art.megatron.model_support.spec import ( + QWEN3_5_DENSE_META, + QWEN3_5_MOE_META, CompileWorkaroundConfig, LayerFamilyInstance, ) @@ -38,8 +40,9 @@ class Qwen35BaseHandler(DefaultDenseHandler): + # Abstract base; the two leaves below override `key` from their own meta. 
key = "qwen3_5_base" - native_vllm_lora_status = "validated" + native_vllm_lora_status = QWEN3_5_DENSE_META.native_vllm_lora_status def identity_lora_model_config(self, base_config: Any) -> Any: return getattr(base_config, "text_config", base_config) @@ -364,11 +367,13 @@ def get_forward_kwargs(self, model: Any, **kwargs: Any) -> dict[str, Any]: class Qwen35DenseHandler(Qwen35BaseHandler): - key = "qwen3_5_dense" + key = QWEN3_5_DENSE_META.key + native_vllm_lora_status = QWEN3_5_DENSE_META.native_vllm_lora_status class Qwen35MoeHandler(Qwen35BaseHandler): - key = "qwen3_5_moe" + key = QWEN3_5_MOE_META.key + native_vllm_lora_status = QWEN3_5_MOE_META.native_vllm_lora_status is_moe = True def to_vllm_lora_tensors( diff --git a/src/art/megatron/model_support/handlers/qwen3_dense.py b/src/art/megatron/model_support/handlers/qwen3_dense.py index 5cf76e22..32d8e528 100644 --- a/src/art/megatron/model_support/handlers/qwen3_dense.py +++ b/src/art/megatron/model_support/handlers/qwen3_dense.py @@ -4,11 +4,12 @@ from art.megatron.model_support.handlers.qwen3_common import ( install_qwen3_text_preprocess_patch, ) +from art.megatron.model_support.spec import QWEN3_DENSE_META class Qwen3DenseHandler(DefaultDenseHandler): - key = "qwen3_dense" - native_vllm_lora_status = "validated" + key = QWEN3_DENSE_META.key + native_vllm_lora_status = QWEN3_DENSE_META.native_vllm_lora_status def install_preprocess_patch(self, model_chunks: Sequence[Any]) -> None: install_qwen3_text_preprocess_patch(model_chunks) diff --git a/src/art/megatron/model_support/handlers/qwen3_moe.py b/src/art/megatron/model_support/handlers/qwen3_moe.py index 61aa8fca..b9777fe7 100644 --- a/src/art/megatron/model_support/handlers/qwen3_moe.py +++ b/src/art/megatron/model_support/handlers/qwen3_moe.py @@ -7,7 +7,7 @@ from art.megatron.model_support.handlers.qwen3_common import ( install_qwen3_text_preprocess_patch, ) -from art.megatron.model_support.spec import CompileWorkaroundConfig +from 
art.megatron.model_support.spec import QWEN3_MOE_META, CompileWorkaroundConfig _QWEN3_MOE_COMPILE_WORKAROUND_FLAGS = ( "alltoall_dtoh", @@ -25,8 +25,8 @@ class Qwen3MoeHandler(DefaultMoeHandler): - key = "qwen3_moe" - native_vllm_lora_status = "validated" + key = QWEN3_MOE_META.key + native_vllm_lora_status = QWEN3_MOE_META.native_vllm_lora_status def to_vllm_lora_tensors( self, diff --git a/src/art/megatron/model_support/registry.py b/src/art/megatron/model_support/registry.py index 910718ce..ef8cd842 100644 --- a/src/art/megatron/model_support/registry.py +++ b/src/art/megatron/model_support/registry.py @@ -1,16 +1,39 @@ -from art.megatron.model_support.handlers import ( - DEFAULT_DENSE_HANDLER, - QWEN3_5_DENSE_HANDLER, - QWEN3_5_MOE_HANDLER, - QWEN3_DENSE_HANDLER, - QWEN3_MOE_HANDLER, -) +import importlib + from art.megatron.model_support.spec import ( + ALL_HANDLER_METAS, + DEFAULT_DENSE_META, + QWEN3_5_DENSE_META, + QWEN3_5_MOE_META, + QWEN3_DENSE_META, + QWEN3_MOE_META, DependencyFloor, + HandlerMeta, ModelSupportHandler, ModelSupportSpec, ) +# Handler modules top-level import megatron-core, so we keep the registry +# importable without megatron-core by deferring those imports until something +# actually needs a handler instance. The handler `key` / `native_vllm_lora_status` +# values shared between spec construction (here) and the handler classes live +# in `spec.py` (the `*_META` constants), which has no megatron dependency. 
+_HANDLER_METAS_BY_KEY: dict[str, HandlerMeta] = { + meta.key: meta for meta in ALL_HANDLER_METAS +} +_HANDLERS_BY_KEY: dict[str, ModelSupportHandler] = {} + + +def _load_handler(handler_key: str) -> ModelSupportHandler: + cached = _HANDLERS_BY_KEY.get(handler_key) + if cached is not None: + return cached + meta = _HANDLER_METAS_BY_KEY[handler_key] + handler = getattr(importlib.import_module(meta.module), meta.attr) + _HANDLERS_BY_KEY[handler_key] = handler + return handler + + _DENSE_TARGET_MODULES = ( "q_proj", "k_proj", @@ -48,15 +71,15 @@ ) DEFAULT_DENSE_SPEC = ModelSupportSpec( - key="default_dense", - handler_key=DEFAULT_DENSE_HANDLER.key, + key=DEFAULT_DENSE_META.key, + handler_key=DEFAULT_DENSE_META.key, default_target_modules=_DENSE_TARGET_MODULES, - native_vllm_lora_status=DEFAULT_DENSE_HANDLER.native_vllm_lora_status, + native_vllm_lora_status=DEFAULT_DENSE_META.native_vllm_lora_status, ) QWEN3_MOE_SPEC = ModelSupportSpec( - key="qwen3_moe", - handler_key=QWEN3_MOE_HANDLER.key, + key=QWEN3_MOE_META.key, + handler_key=QWEN3_MOE_META.key, model_names=( "Qwen/Qwen3-30B-A3B", "Qwen/Qwen3-30B-A3B-Base", @@ -64,12 +87,12 @@ "Qwen/Qwen3-235B-A22B-Instruct-2507", ), default_target_modules=_QWEN3_MOE_TARGET_MODULES, - native_vllm_lora_status=QWEN3_MOE_HANDLER.native_vllm_lora_status, + native_vllm_lora_status=QWEN3_MOE_META.native_vllm_lora_status, ) QWEN3_DENSE_SPEC = ModelSupportSpec( - key="qwen3_dense", - handler_key=QWEN3_DENSE_HANDLER.key, + key=QWEN3_DENSE_META.key, + handler_key=QWEN3_DENSE_META.key, model_names=( "Qwen/Qwen3-0.6B", "Qwen/Qwen3-0.6B-Base", @@ -87,34 +110,34 @@ "Qwen/Qwen3-32B-Base", ), default_target_modules=_DENSE_TARGET_MODULES, - native_vllm_lora_status=QWEN3_DENSE_HANDLER.native_vllm_lora_status, + native_vllm_lora_status=QWEN3_DENSE_META.native_vllm_lora_status, ) QWEN3_5_DENSE_SPEC = ModelSupportSpec( - key="qwen3_5_dense", - handler_key=QWEN3_5_DENSE_HANDLER.key, + key=QWEN3_5_DENSE_META.key, + handler_key=QWEN3_5_DENSE_META.key, 
model_names=( "Qwen/Qwen3.5-4B", "Qwen/Qwen3.5-27B", "Qwen/Qwen3.6-27B", ), default_target_modules=_QWEN3_5_DENSE_TARGET_MODULES, - native_vllm_lora_status=QWEN3_5_DENSE_HANDLER.native_vllm_lora_status, + native_vllm_lora_status=QWEN3_5_DENSE_META.native_vllm_lora_status, dependency_floor=DependencyFloor( megatron_bridge="e049cc00c24d03e2ae45d2608c7a44e2d2364e3d", ), ) QWEN3_5_MOE_SPEC = ModelSupportSpec( - key="qwen3_5_moe", - handler_key=QWEN3_5_MOE_HANDLER.key, + key=QWEN3_5_MOE_META.key, + handler_key=QWEN3_5_MOE_META.key, model_names=( "Qwen/Qwen3.5-35B-A3B", "Qwen/Qwen3.5-397B-A17B", "Qwen/Qwen3.6-35B-A3B", ), default_target_modules=_QWEN3_5_MOE_TARGET_MODULES, - native_vllm_lora_status=QWEN3_5_MOE_HANDLER.native_vllm_lora_status, + native_vllm_lora_status=QWEN3_5_MOE_META.native_vllm_lora_status, dependency_floor=DependencyFloor( megatron_bridge="e049cc00c24d03e2ae45d2608c7a44e2d2364e3d", ), @@ -143,14 +166,6 @@ for spec in PROBE_ONLY_MODEL_SUPPORT_SPECS for model_name in spec.model_names } -_HANDLERS_BY_KEY: dict[str, ModelSupportHandler] = { - DEFAULT_DENSE_HANDLER.key: DEFAULT_DENSE_HANDLER, - QWEN3_DENSE_HANDLER.key: QWEN3_DENSE_HANDLER, - QWEN3_MOE_HANDLER.key: QWEN3_MOE_HANDLER, - QWEN3_5_DENSE_HANDLER.key: QWEN3_5_DENSE_HANDLER, - QWEN3_5_MOE_HANDLER.key: QWEN3_5_MOE_HANDLER, -} - QWEN3_DENSE_MODELS = frozenset(QWEN3_DENSE_SPEC.model_names) QWEN3_MOE_MODELS = frozenset(QWEN3_MOE_SPEC.model_names) QWEN3_5_DENSE_MODELS = frozenset(QWEN3_5_DENSE_SPEC.model_names) @@ -195,7 +210,7 @@ def get_model_support_handler( def get_model_support_handler_for_spec( spec: ModelSupportSpec, ) -> ModelSupportHandler: - return _HANDLERS_BY_KEY[spec.handler_key] + return _load_handler(spec.handler_key) def default_target_modules_for_model( @@ -216,7 +231,7 @@ def native_vllm_lora_status_for_model( *, allow_unvalidated_arch: bool = False, ) -> str: - return get_model_support_handler( + return get_model_support_spec( base_model, 
allow_unvalidated_arch=allow_unvalidated_arch, ).native_vllm_lora_status diff --git a/src/art/megatron/model_support/spec.py b/src/art/megatron/model_support/spec.py index 1e5858c8..612664c7 100644 --- a/src/art/megatron/model_support/spec.py +++ b/src/art/megatron/model_support/spec.py @@ -1,9 +1,66 @@ from typing import Any, Literal, Protocol, Sequence -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field RolloutWeightsMode = Literal["lora", "merged"] NativeVllmLoraStatus = Literal["disabled", "wip", "validated"] + + +class HandlerMeta(BaseModel): + """Single source of truth for a registered handler's key, status, and + lazy-load address. Kept megatron-free so the registry and spec layer can + be imported without `megatron-core` installed.""" + + model_config = ConfigDict(frozen=True) + + key: str + native_vllm_lora_status: NativeVllmLoraStatus + module: str + attr: str + + +DEFAULT_DENSE_META = HandlerMeta( + key="default_dense", + native_vllm_lora_status="disabled", + module="art.megatron.model_support.handlers.default_dense", + attr="DEFAULT_DENSE_HANDLER", +) + +QWEN3_DENSE_META = HandlerMeta( + key="qwen3_dense", + native_vllm_lora_status="validated", + module="art.megatron.model_support.handlers.qwen3_dense", + attr="QWEN3_DENSE_HANDLER", +) + +QWEN3_MOE_META = HandlerMeta( + key="qwen3_moe", + native_vllm_lora_status="validated", + module="art.megatron.model_support.handlers.qwen3_moe", + attr="QWEN3_MOE_HANDLER", +) + +QWEN3_5_DENSE_META = HandlerMeta( + key="qwen3_5_dense", + native_vllm_lora_status="validated", + module="art.megatron.model_support.handlers.qwen3_5", + attr="QWEN3_5_DENSE_HANDLER", +) + +QWEN3_5_MOE_META = HandlerMeta( + key="qwen3_5_moe", + native_vllm_lora_status="validated", + module="art.megatron.model_support.handlers.qwen3_5", + attr="QWEN3_5_MOE_HANDLER", +) + +ALL_HANDLER_METAS: tuple[HandlerMeta, ...] 
= ( + DEFAULT_DENSE_META, + QWEN3_DENSE_META, + QWEN3_MOE_META, + QWEN3_5_DENSE_META, + QWEN3_5_MOE_META, +) SharedExpertCompileState = Literal[ "none", "shared_experts",