Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/art/megatron/model_support/handlers/default_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
import torch

from art.megatron.model_support.spec import (
DEFAULT_DENSE_META,
CompileWorkaroundConfig,
LayerFamilyInstance,
SharedExpertCompileState,
)


class DefaultDenseHandler:
key = "default_dense"
key = DEFAULT_DENSE_META.key
is_moe = False
native_vllm_lora_status = "disabled"
native_vllm_lora_status = DEFAULT_DENSE_META.native_vllm_lora_status

def identity_lora_model_config(self, base_config: Any) -> Any:
return base_config
Expand Down
11 changes: 8 additions & 3 deletions src/art/megatron/model_support/handlers/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
_require_moe_experts,
)
from art.megatron.model_support.spec import (
QWEN3_5_DENSE_META,
QWEN3_5_MOE_META,
CompileWorkaroundConfig,
LayerFamilyInstance,
)
Expand All @@ -38,8 +40,9 @@


class Qwen35BaseHandler(DefaultDenseHandler):
# Abstract base; the two leaves below override `key` from their own meta.
key = "qwen3_5_base"
native_vllm_lora_status = "validated"
native_vllm_lora_status = QWEN3_5_DENSE_META.native_vllm_lora_status

def identity_lora_model_config(self, base_config: Any) -> Any:
return getattr(base_config, "text_config", base_config)
Expand Down Expand Up @@ -364,11 +367,13 @@ def get_forward_kwargs(self, model: Any, **kwargs: Any) -> dict[str, Any]:


class Qwen35DenseHandler(Qwen35BaseHandler):
key = "qwen3_5_dense"
key = QWEN3_5_DENSE_META.key
native_vllm_lora_status = QWEN3_5_DENSE_META.native_vllm_lora_status


class Qwen35MoeHandler(Qwen35BaseHandler):
key = "qwen3_5_moe"
key = QWEN3_5_MOE_META.key
native_vllm_lora_status = QWEN3_5_MOE_META.native_vllm_lora_status
is_moe = True

def to_vllm_lora_tensors(
Expand Down
5 changes: 3 additions & 2 deletions src/art/megatron/model_support/handlers/qwen3_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from art.megatron.model_support.handlers.qwen3_common import (
install_qwen3_text_preprocess_patch,
)
from art.megatron.model_support.spec import QWEN3_DENSE_META


class Qwen3DenseHandler(DefaultDenseHandler):
key = "qwen3_dense"
native_vllm_lora_status = "validated"
key = QWEN3_DENSE_META.key
native_vllm_lora_status = QWEN3_DENSE_META.native_vllm_lora_status

def install_preprocess_patch(self, model_chunks: Sequence[Any]) -> None:
install_qwen3_text_preprocess_patch(model_chunks)
Expand Down
6 changes: 3 additions & 3 deletions src/art/megatron/model_support/handlers/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from art.megatron.model_support.handlers.qwen3_common import (
install_qwen3_text_preprocess_patch,
)
from art.megatron.model_support.spec import CompileWorkaroundConfig
from art.megatron.model_support.spec import QWEN3_MOE_META, CompileWorkaroundConfig

_QWEN3_MOE_COMPILE_WORKAROUND_FLAGS = (
"alltoall_dtoh",
Expand All @@ -25,8 +25,8 @@


class Qwen3MoeHandler(DefaultMoeHandler):
key = "qwen3_moe"
native_vllm_lora_status = "validated"
key = QWEN3_MOE_META.key
native_vllm_lora_status = QWEN3_MOE_META.native_vllm_lora_status

def to_vllm_lora_tensors(
self,
Expand Down
79 changes: 47 additions & 32 deletions src/art/megatron/model_support/registry.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,39 @@
from art.megatron.model_support.handlers import (
DEFAULT_DENSE_HANDLER,
QWEN3_5_DENSE_HANDLER,
QWEN3_5_MOE_HANDLER,
QWEN3_DENSE_HANDLER,
QWEN3_MOE_HANDLER,
)
import importlib

from art.megatron.model_support.spec import (
ALL_HANDLER_METAS,
DEFAULT_DENSE_META,
QWEN3_5_DENSE_META,
QWEN3_5_MOE_META,
QWEN3_DENSE_META,
QWEN3_MOE_META,
DependencyFloor,
HandlerMeta,
ModelSupportHandler,
ModelSupportSpec,
)

# Handler modules import megatron-core at module top level, so we keep the
# registry importable without megatron-core by deferring those imports until
# something actually needs a handler instance. The handler `key` /
# `native_vllm_lora_status` values shared between spec construction (here) and
# the handler classes live in `spec.py` (as `HandlerMeta` constants), which has
# no megatron dependency.
# Handler metadata indexed by handler key, built once from ALL_HANDLER_METAS.
_HANDLER_METAS_BY_KEY: dict[str, HandlerMeta] = {
meta.key: meta for meta in ALL_HANDLER_METAS
}
# Cache of handler objects that have already been imported, keyed by handler
# key. Starts empty; entries are added by `_load_handler` on first request.
_HANDLERS_BY_KEY: dict[str, ModelSupportHandler] = {}


def _load_handler(handler_key: str) -> ModelSupportHandler:
    """Return the handler registered under ``handler_key``, importing it lazily.

    The first request for a key imports the module named by the handler's
    metadata, reads the attribute it points at, and stores the result in
    ``_HANDLERS_BY_KEY`` so later requests skip the import.

    Raises:
        KeyError: if ``handler_key`` has no entry in ``_HANDLER_METAS_BY_KEY``.
    """
    handler = _HANDLERS_BY_KEY.get(handler_key)
    if handler is None:
        # Cache miss: resolve the metadata's module/attr address now.
        meta = _HANDLER_METAS_BY_KEY[handler_key]
        handler_module = importlib.import_module(meta.module)
        handler = getattr(handler_module, meta.attr)
        _HANDLERS_BY_KEY[handler_key] = handler
    return handler


_DENSE_TARGET_MODULES = (
"q_proj",
"k_proj",
Expand Down Expand Up @@ -48,28 +71,28 @@
)

DEFAULT_DENSE_SPEC = ModelSupportSpec(
key="default_dense",
handler_key=DEFAULT_DENSE_HANDLER.key,
key=DEFAULT_DENSE_META.key,
handler_key=DEFAULT_DENSE_META.key,
default_target_modules=_DENSE_TARGET_MODULES,
native_vllm_lora_status=DEFAULT_DENSE_HANDLER.native_vllm_lora_status,
native_vllm_lora_status=DEFAULT_DENSE_META.native_vllm_lora_status,
)

QWEN3_MOE_SPEC = ModelSupportSpec(
key="qwen3_moe",
handler_key=QWEN3_MOE_HANDLER.key,
key=QWEN3_MOE_META.key,
handler_key=QWEN3_MOE_META.key,
model_names=(
"Qwen/Qwen3-30B-A3B",
"Qwen/Qwen3-30B-A3B-Base",
"Qwen/Qwen3-30B-A3B-Instruct-2507",
"Qwen/Qwen3-235B-A22B-Instruct-2507",
),
default_target_modules=_QWEN3_MOE_TARGET_MODULES,
native_vllm_lora_status=QWEN3_MOE_HANDLER.native_vllm_lora_status,
native_vllm_lora_status=QWEN3_MOE_META.native_vllm_lora_status,
)

QWEN3_DENSE_SPEC = ModelSupportSpec(
key="qwen3_dense",
handler_key=QWEN3_DENSE_HANDLER.key,
key=QWEN3_DENSE_META.key,
handler_key=QWEN3_DENSE_META.key,
model_names=(
"Qwen/Qwen3-0.6B",
"Qwen/Qwen3-0.6B-Base",
Expand All @@ -87,34 +110,34 @@
"Qwen/Qwen3-32B-Base",
),
default_target_modules=_DENSE_TARGET_MODULES,
native_vllm_lora_status=QWEN3_DENSE_HANDLER.native_vllm_lora_status,
native_vllm_lora_status=QWEN3_DENSE_META.native_vllm_lora_status,
)

QWEN3_5_DENSE_SPEC = ModelSupportSpec(
key="qwen3_5_dense",
handler_key=QWEN3_5_DENSE_HANDLER.key,
key=QWEN3_5_DENSE_META.key,
handler_key=QWEN3_5_DENSE_META.key,
model_names=(
"Qwen/Qwen3.5-4B",
"Qwen/Qwen3.5-27B",
"Qwen/Qwen3.6-27B",
),
default_target_modules=_QWEN3_5_DENSE_TARGET_MODULES,
native_vllm_lora_status=QWEN3_5_DENSE_HANDLER.native_vllm_lora_status,
native_vllm_lora_status=QWEN3_5_DENSE_META.native_vllm_lora_status,
dependency_floor=DependencyFloor(
megatron_bridge="e049cc00c24d03e2ae45d2608c7a44e2d2364e3d",
),
)

QWEN3_5_MOE_SPEC = ModelSupportSpec(
key="qwen3_5_moe",
handler_key=QWEN3_5_MOE_HANDLER.key,
key=QWEN3_5_MOE_META.key,
handler_key=QWEN3_5_MOE_META.key,
model_names=(
"Qwen/Qwen3.5-35B-A3B",
"Qwen/Qwen3.5-397B-A17B",
"Qwen/Qwen3.6-35B-A3B",
),
default_target_modules=_QWEN3_5_MOE_TARGET_MODULES,
native_vllm_lora_status=QWEN3_5_MOE_HANDLER.native_vllm_lora_status,
native_vllm_lora_status=QWEN3_5_MOE_META.native_vllm_lora_status,
dependency_floor=DependencyFloor(
megatron_bridge="e049cc00c24d03e2ae45d2608c7a44e2d2364e3d",
),
Expand Down Expand Up @@ -143,14 +166,6 @@
for spec in PROBE_ONLY_MODEL_SUPPORT_SPECS
for model_name in spec.model_names
}
_HANDLERS_BY_KEY: dict[str, ModelSupportHandler] = {
DEFAULT_DENSE_HANDLER.key: DEFAULT_DENSE_HANDLER,
QWEN3_DENSE_HANDLER.key: QWEN3_DENSE_HANDLER,
QWEN3_MOE_HANDLER.key: QWEN3_MOE_HANDLER,
QWEN3_5_DENSE_HANDLER.key: QWEN3_5_DENSE_HANDLER,
QWEN3_5_MOE_HANDLER.key: QWEN3_5_MOE_HANDLER,
}

QWEN3_DENSE_MODELS = frozenset(QWEN3_DENSE_SPEC.model_names)
QWEN3_MOE_MODELS = frozenset(QWEN3_MOE_SPEC.model_names)
QWEN3_5_DENSE_MODELS = frozenset(QWEN3_5_DENSE_SPEC.model_names)
Expand Down Expand Up @@ -195,7 +210,7 @@ def get_model_support_handler(
def get_model_support_handler_for_spec(
spec: ModelSupportSpec,
) -> ModelSupportHandler:
return _HANDLERS_BY_KEY[spec.handler_key]
return _load_handler(spec.handler_key)


def default_target_modules_for_model(
Expand All @@ -216,7 +231,7 @@ def native_vllm_lora_status_for_model(
*,
allow_unvalidated_arch: bool = False,
) -> str:
return get_model_support_handler(
return get_model_support_spec(
base_model,
allow_unvalidated_arch=allow_unvalidated_arch,
).native_vllm_lora_status
Expand Down
59 changes: 58 additions & 1 deletion src/art/megatron/model_support/spec.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,66 @@
from typing import Any, Literal, Protocol, Sequence

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

RolloutWeightsMode = Literal["lora", "merged"]
NativeVllmLoraStatus = Literal["disabled", "wip", "validated"]


class HandlerMeta(BaseModel):
    """Immutable description of one registered handler.

    Bundles the handler's key, its native vLLM LoRA status, and the
    ``module``/``attr`` address from which the handler object can be
    imported on demand. This layer is kept free of megatron imports so that
    the registry and spec modules stay importable when `megatron-core` is
    not installed.
    """

    # Instances are frozen: fields cannot be reassigned after construction.
    model_config = ConfigDict(frozen=True)

    # Handler key; the registry indexes metas and handlers by this value.
    key: str
    # One of the `NativeVllmLoraStatus` literals.
    native_vllm_lora_status: NativeVllmLoraStatus
    # Dotted path of the module that defines the handler object.
    module: str
    # Name of the attribute inside `module` that holds the handler object.
    attr: str


# One HandlerMeta per registered handler. `module` and `attr` give the import
# address of the handler object, so it can be imported on demand instead of at
# module import time.

# Default dense handler; the only entry here whose status is "disabled".
DEFAULT_DENSE_META = HandlerMeta(
key="default_dense",
native_vllm_lora_status="disabled",
module="art.megatron.model_support.handlers.default_dense",
attr="DEFAULT_DENSE_HANDLER",
)

# Qwen3 dense handler.
QWEN3_DENSE_META = HandlerMeta(
key="qwen3_dense",
native_vllm_lora_status="validated",
module="art.megatron.model_support.handlers.qwen3_dense",
attr="QWEN3_DENSE_HANDLER",
)

# Qwen3 MoE handler.
QWEN3_MOE_META = HandlerMeta(
key="qwen3_moe",
native_vllm_lora_status="validated",
module="art.megatron.model_support.handlers.qwen3_moe",
attr="QWEN3_MOE_HANDLER",
)

# Qwen3.5 dense handler; shares the `qwen3_5` module with the MoE handler.
QWEN3_5_DENSE_META = HandlerMeta(
key="qwen3_5_dense",
native_vllm_lora_status="validated",
module="art.megatron.model_support.handlers.qwen3_5",
attr="QWEN3_5_DENSE_HANDLER",
)

# Qwen3.5 MoE handler; shares the `qwen3_5` module with the dense handler.
QWEN3_5_MOE_META = HandlerMeta(
key="qwen3_5_moe",
native_vllm_lora_status="validated",
module="art.megatron.model_support.handlers.qwen3_5",
attr="QWEN3_5_MOE_HANDLER",
)

# Every handler meta defined above, collected in one tuple.
ALL_HANDLER_METAS: tuple[HandlerMeta, ...] = (
DEFAULT_DENSE_META,
QWEN3_DENSE_META,
QWEN3_MOE_META,
QWEN3_5_DENSE_META,
QWEN3_5_MOE_META,
)
SharedExpertCompileState = Literal[
"none",
"shared_experts",
Expand Down