10 changes: 6 additions & 4 deletions pyproject.toml
@@ -26,15 +26,17 @@ dependencies = [
"transformers",
"jsonpatch",
"datasets",
"fm-training-estimator>=0.1.3",
"setuptools==70.0.0"
"setuptools"
]

[project.optional-dependencies]
estimator = [
"fm-training-estimator>=0.1.3",
]
dev = [
"black",
"black",
"mypy",
"pylint",
"pylint",
"pytest",
"pytest-asyncio",
]
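With fm-training-estimator moved into an optional `estimator` extra (installed with something like `pip install ".[estimator]"`), code paths that use it presumably have to tolerate its absence, in line with the existing `skip_estimator` hook in adapters.py. A minimal sketch, not part of this PR, assuming the importable module is named `fm_training_estimator`:

```python
from __future__ import annotations

# Guarded import of the now-optional dependency. The module name
# "fm_training_estimator" is an assumption, not taken from this diff.
try:
    import fm_training_estimator  # noqa: F401  # present only with the "estimator" extra
    HAS_ESTIMATOR = True
except ImportError:
    HAS_ESTIMATOR = False


def estimate_if_available(config: dict) -> dict | None:
    """Run the estimator when it is installed, otherwise signal a skip."""
    if not HAS_ESTIMATOR:
        return None
    # Call into fm-training-estimator here; omitted because its API is not
    # shown in this diff.
    return {}
```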
2 changes: 2 additions & 0 deletions src/tuning_config_recommender/actions/compute.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Optional

from loguru import logger
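The only functional change in this file is the `from __future__ import annotations` line, which enables PEP 563 postponed evaluation: annotations are stored as strings instead of being evaluated at import time, so built-in generics and union syntax such as `list[str]` or `str | None` stay usable on Python 3.8/3.9. A small illustrative sketch, not part of the PR:

```python
from __future__ import annotations


def pick(values: list[str], default: str | None = None) -> str | None:
    """Return the first item, or the default when the list is empty."""
    # With the future import, the annotations above are never evaluated,
    # so this module imports cleanly on 3.8/3.9 as well as 3.10+.
    return values[0] if values else default
```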
4 changes: 3 additions & 1 deletion src/tuning_config_recommender/adapters.py
@@ -112,6 +112,7 @@ def execute(
unique_tag,
paths,
skip_estimator=None,
fsdp_args_format="accelerate",
):
try:
if not data_config and not tuning_config.get("training_data_path", None):
@@ -162,7 +163,8 @@ def execute(
ir_clean.get("compute_config", {}), compute_config_path
)
launch_cmd = build_launch_command(
ir_clean, data_path, accel_path, dynamic_args
ir_clean, data_path, accel_path, dynamic_args,
fsdp_args_format=fsdp_args_format,
)
serializable_patches = []
for patch in patches:
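The adapter now threads `fsdp_args_format` straight through to `build_launch_command`. A hedged usage sketch follows: the import path is inferred from the file layout, the IR and paths are placeholder values, and the exact output depends on parts of the function not shown in this diff.

```python
from pathlib import Path

# Import path inferred from src/tuning_config_recommender/utils/adapter_utils.py.
from tuning_config_recommender.utils.adapter_utils import build_launch_command

# Placeholder inputs for illustration only; in the adapter these come from the
# cleaned intermediate representation (ir_clean) and the config files it writes.
ir = {"tuning_config": {"model_name_or_path": "my-model", "num_train_epochs": 1}}

launch_cmd = build_launch_command(
    ir,
    data_config_path=Path("data_config.yaml"),
    accelerate_config_path=Path("accelerate_config.yaml"),
    dynamic_args=None,
    fsdp_args_format="hftrainer",  # or "accelerate", the default
)
# With "hftrainer" the command starts with "python -m tuning.sft_trainer"
# instead of "accelerate launch --config_file ...".
print(launch_cmd)
```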
71 changes: 65 additions & 6 deletions src/tuning_config_recommender/utils/adapter_utils.py
@@ -69,19 +69,78 @@ def prepare_ir_for_accelerate(ir: dict):
return ir, dynamic


def _accel_to_fsdp_args(accel_cfg: dict) -> list[str]:
"""Convert accelerate_config FSDP settings to HF TrainingArguments --fsdp / --fsdp_config."""
fsdp_cfg = accel_cfg.get("fsdp_config", {})
if not fsdp_cfg and accel_cfg.get("distributed_type") != "FSDP":
return []

# Map accelerate sharding strategy to HF --fsdp flags.
strategy_map = {
"FULL_SHARD": "full_shard",
"SHARD_GRAD_OP": "shard_grad_op",
"HYBRID_SHARD": "hybrid_shard",
"HYBRID_SHARD_ZERO2": "hybrid_shard_zero2",
"NO_SHARD": "no_shard",
1: "full_shard",
2: "shard_grad_op",
3: "no_shard",
4: "hybrid_shard",
5: "hybrid_shard_zero2",
}
raw_strategy = fsdp_cfg.get("fsdp_sharding_strategy", "FULL_SHARD")
sharding = strategy_map.get(raw_strategy, "full_shard")

fsdp_flags = [sharding]
if fsdp_cfg.get("fsdp_auto_wrap_policy") == "TRANSFORMER_BASED_WRAP":
fsdp_flags.append("auto_wrap")
if fsdp_cfg.get("fsdp_offload_params", False):
fsdp_flags.append("offload")

# Build --fsdp_config JSON (strip fsdp_ prefixes for HF TrainingArguments).
prefix_map = {
"fsdp_auto_wrap_policy": "auto_wrap_policy",
"fsdp_backward_prefetch": "backward_prefetch",
"fsdp_backward_prefetch_policy": "backward_prefetch",
"fsdp_forward_prefetch": "forward_prefetch",
"fsdp_offload_params": "offload_params",
"fsdp_state_dict_type": "state_dict_type",
"fsdp_cpu_ram_efficient_loading": "cpu_ram_efficient_loading",
"fsdp_sync_module_states": "sync_module_states",
}
hf_fsdp_config = {}
for accel_key, hf_key in prefix_map.items():
if accel_key in fsdp_cfg:
hf_fsdp_config[hf_key] = fsdp_cfg[accel_key]

args = [f"--fsdp {fmt_cli_value(' '.join(fsdp_flags))}"]
if hf_fsdp_config:
args.append(f"--fsdp_config {fmt_cli_value(hf_fsdp_config)}")
return args


def build_launch_command(
ir: dict[str, Any],
data_config_path: Path,
accelerate_config_path: Path,
dynamic_args: list[str] = None,
fsdp_args_format: str = "accelerate",
) -> str:
try:
cmd = [
"accelerate launch",
f"--config_file {accelerate_config_path}",
*(dynamic_args or []),
"-m 'tuning.sft_trainer'",
]
if fsdp_args_format == "hftrainer":
cmd = [
"python -m tuning.sft_trainer",
]
# Convert accelerate FSDP config to HF TrainingArguments.
fsdp_args = _accel_to_fsdp_args(ir.get("accelerate_config", {}))
cmd.extend(fsdp_args)
else:
cmd = [
"accelerate launch",
f"--config_file {accelerate_config_path}",
*(dynamic_args or []),
"-m 'tuning.sft_trainer'",
]

for k, v in ir.get("tuning_config", {}).items():
if v is not None and k != "training_data_path":
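To make the new mapping concrete, the sketch below shows the kind of input `_accel_to_fsdp_args` reads and the shape of the arguments it emits; the exact quoting depends on `fmt_cli_value`, which is outside this diff. Note that when `fsdp_args_format="hftrainer"`, the generated command drops the accelerate `--config_file` and `dynamic_args` entirely and relies on these flags instead.

```python
# Illustrative accelerate-style FSDP section; the keys mirror those read by
# _accel_to_fsdp_args, the values are made up for the example.
accel_cfg = {
    "distributed_type": "FSDP",
    "fsdp_config": {
        "fsdp_sharding_strategy": "FULL_SHARD",  # or 1..5 / the other string names
        "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "fsdp_offload_params": False,
        "fsdp_state_dict_type": "FULL_STATE_DICT",
        "fsdp_cpu_ram_efficient_loading": True,
    },
}

# _accel_to_fsdp_args(accel_cfg) would then produce, roughly:
#   --fsdp <"full_shard auto_wrap">        # "offload" omitted since offload_params is False
#   --fsdp_config <{"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
#                   "offload_params": False,
#                   "state_dict_type": "FULL_STATE_DICT",
#                   "cpu_ram_efficient_loading": True}>
# with the exact quoting of each value determined by fmt_cli_value.
```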