Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,089 changes: 611 additions & 478 deletions autoparallel/_testing/models/dsv3.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions autoparallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
assert_has_no_collectives,
cleanup_graph,
fix_scatter_on_aliased_inputs,
functionalize_fresh_index_put_mutations,
update_joint_with_descriptors,
)
from .input_validation import (
Expand Down Expand Up @@ -446,6 +447,7 @@ def _apply_placement_common(self, sharding_placement):
from torch._inductor.fx_passes.post_grad import view_to_reshape

view_to_reshape(parallel_gm)
functionalize_fresh_index_put_mutations(parallel_gm)

mark_fsdp_all_gather_recomputation(
parallel_gm.graph, self.reshard_after_forward
Expand Down
23 changes: 23 additions & 0 deletions autoparallel/graph_passes/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,29 @@ def is_collective(node: torch.fx.Node) -> bool:
}


def functionalize_fresh_index_put_mutations(gm: torch.fx.GraphModule) -> bool:
    """Rewrite ``aten.index_put_`` on fresh tensors to functional ``aten.index_put``.

    A node is rewritten when all of the following hold:

    * it is a ``call_function`` node targeting ``aten.index_put_.default``;
    * the tensor it mutates is itself an FX node (not a constant);
    * that FX node is not a graph ``placeholder``; and
    * that FX node has exactly one user (the ``index_put_`` node itself).

    Only the node's ``target`` is swapped; its args, kwargs and users are left
    untouched.

    NOTE(review): "fresh" is approximated by "non-placeholder with a single
    user". A single-user base that is a view of another tensor, or a
    ``get_attr`` node, would also be rewritten -- confirm that such bases
    cannot reach this pass.

    Args:
        gm: Graph module whose graph is rewritten in place.

    Returns:
        ``True`` if at least one node was rewritten (the graph is then linted
        and ``gm`` recompiled), ``False`` if the graph was left unchanged.
    """
    inplace_op = torch.ops.aten.index_put_.default
    functional_op = torch.ops.aten.index_put.default

    changed = False
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target != inplace_op:
            continue
        # The mutated tensor is normally the first positional argument, but an
        # FX node may also carry it as the ``self`` keyword; indexing args[0]
        # unconditionally would raise IndexError in that case.
        base = node.args[0] if node.args else node.kwargs.get("self")
        if not isinstance(base, torch.fx.Node):
            continue
        # Graph inputs and tensors read by any other node are skipped.
        if base.op == "placeholder" or len(base.users) != 1:
            continue
        node.target = functional_op
        changed = True

    if changed:
        # Only pay for validation and code regeneration when something changed.
        gm.graph.lint()
        gm.recompile()
    return changed


def fix_scatter_on_aliased_inputs(graph: torch.fx.Graph) -> None:
"""Insert clone before scatter ops whose input has zero strides (aliased from expand).

Expand Down
9 changes: 8 additions & 1 deletion autoparallel/module_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,14 @@ def _assign_attr(
curr_mod.register_parameter(field, attr)
elif attr_kind == _AttrKind.BUFFER:
assert isinstance(attr, torch.Tensor)
curr_mod.register_buffer(field, attr)
ref_curr_mod = ref_module
for attr_name in prefix:
ref_curr_mod = getattr(ref_curr_mod, attr_name)
curr_mod.register_buffer(
field,
attr,
persistent=field not in ref_curr_mod._non_persistent_buffers_set,
)
else:
setattr(curr_mod, field, attr)

Expand Down
162 changes: 57 additions & 105 deletions examples/example_ds3_local_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,68 +8,42 @@

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.placement_types import Shard
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.testing._internal.distributed.fake_pg import FakeStore

from autoparallel._testing.models.dsv3 import (
DeepSeekV3Model,
DeepSeekV3ModelArgs,
MoEArgs,
)
from autoparallel._testing.models.dsv3 import DeepSeekV3Model, make_dsv3_config
from autoparallel.api import AutoParallel
from autoparallel.shardings.placement_options import NumericsLogger

_DEFAULT_DTENSOR_RNG_SEED = 0


def _seed_dtensor_rng(rng_seed: Optional[int]) -> None:
torch.manual_seed(_DEFAULT_DTENSOR_RNG_SEED if rng_seed is None else rng_seed)


def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
seq_len = 1024
if fake_evaluate:
# must symbolically evaluate to run on 32 dp ranks
# world_size = 2048

world_size = 256

fake_store = FakeStore()
torch.distributed.init_process_group(
"fake", store=fake_store, rank=0, world_size=world_size
)
local_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
_seed_dtensor_rng(rng_seed)
mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda",
(world_size // 64, 64),
mesh_dim_names=(
"dp",
"ep",
),
mesh_dim_names=("dp", "ep"),
)

config = DeepSeekV3ModelArgs(
vocab_size=102400,
max_seq_len=seq_len,
dim=2048,
inter_dim=10944,
moe_inter_dim=1408,
n_layers=1, # 27,
n_dense_layers=0, # 1,
n_heads=16,
moe_args=MoEArgs(
num_experts=64,
num_shared_experts=2,
top_k=6,
score_func="softmax",
route_norm=False,
score_before_experts=False,
mesh=mesh,
),
q_lora_rank=0,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
use_flex_attn=False,
attn_mask_type="causal",
)
config = make_dsv3_config(num_experts=64, max_seq_len=seq_len)
else:
dp_degree = 2
ep_degree = 2
Expand All @@ -82,49 +56,29 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
int(os.getenv("WORLD_SIZE")) == world_size
), f"Need at least {world_size} GPUs for real evaluation"
local_rank = int(os.getenv("LOCAL_RANK"))
torch.distributed.init_process_group(backend="nccl")
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
_seed_dtensor_rng(rng_seed)
torch.distributed.init_process_group(backend="nccl", device_id=device)
mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda",
(dp_degree, ep_degree),
mesh_dim_names=(
"dp",
"ep",
),
mesh_dim_names=("dp", "ep"),
)

config = DeepSeekV3ModelArgs(
vocab_size=2048,
max_seq_len=seq_len,
dim=256,
inter_dim=1024,
moe_inter_dim=256,
n_layers=4,
n_dense_layers=0,
n_heads=16,
moe_args=MoEArgs(
num_experts=4,
num_shared_experts=2,
top_k=2,
score_func="softmax",
route_norm=False,
score_before_experts=False,
mesh=mesh,
),
q_lora_rank=0,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
config = make_dsv3_config(
num_experts=4, top_k=2, n_layers=4, n_dense_layers=0, max_seq_len=seq_len
)

local_batch_size = 2
global_batch_size = local_batch_size * mesh.shape[0] * mesh.shape[1]
device = torch.device(f"cuda:{local_rank}")

# parallelize the model
with torch.device("meta"):
model = DeepSeekV3Model(config).bfloat16()
model = DeepSeekV3Model(
config,
mesh=mesh,
compute_dtype=torch.bfloat16,
)

def input_fn():
return torch.randint(
Expand All @@ -137,10 +91,15 @@ def input_fn():
numerics_logger = None
if rng_seed is not None:
numerics_logger = NumericsLogger(logs_dir)
with AutoParallel(model, input_fn, mesh, dynamic=True) as autop:
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
)
with AutoParallel(
model, input_fn, mesh, mp_policy=mp_policy, dynamic=True
) as autop:
autop.add_parameter_memory_constraint(low=None, high=None)

# x_sharding = (Shard(0), Replicate())
x_sharding = (Shard(0), Shard(0))

autop.add_input_constraints([x_sharding])
Expand All @@ -150,11 +109,6 @@ def input_fn():
parallel_mod = autop.apply_placement(sharding_placement)

parallel_mod.to_empty(device=device)
# run weight init on our sharded DTensor params
# TODO: plumb init_std through
# parallel_mod.init_weights(
# init_std=0.02, buffer_device="cuda"
# ) # maybe not correct value
parallel_mod.init_weights(buffer_device=device, seed=rng_seed)
if rng_seed is not None:
numerics_logger.log_model_weights(parallel_mod)
Expand All @@ -174,45 +128,44 @@ def input_fn():
full_batch.to(torch.float32), prefix="full batch input"
)

# Symbolically evaluate in case you want to test running a graph bigger than your gpu
if fake_evaluate:
# all gather on the tokens takes 128 GiB (4GiB * 32 ranks)
shape_env = ShapeEnv()
with FakeTensorMode(
allow_non_fake_inputs=True,
shape_env=shape_env,
):
# now let's run it
for x in microbatches:
with torch.autograd.set_multithreading_enabled(False):
if fake_evaluate:
shape_env = ShapeEnv()
with FakeTensorMode(
allow_non_fake_inputs=True,
shape_env=shape_env,
):
for x in microbatches:
out = parallel_mod(x)
out.backward(torch.ones_like(out))
else:
for i, x in enumerate(microbatches):
assert x.shape[0] == 2
out = parallel_mod(x)
assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
out.backward(torch.ones_like(out))
else:
for i, x in enumerate(microbatches):
assert x.shape[0] == 2
out = parallel_mod(x)
assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
out.backward(torch.ones_like(out))
if rng_seed is not None:
numerics_logger.log_diff(out, prefix=f"mb{i} fwd out")
if rng_seed is not None:
numerics_logger.log_diff(out, prefix=f"mb{i} fwd out")

if rng_seed is not None:
for k, v in parallel_mod.named_parameters():
numerics_logger.log_diff(v.grad, prefix=f"grad {k}")
if rng_seed is not None:
for k, v in parallel_mod.named_parameters():
numerics_logger.log_diff(v.grad, prefix=f"grad {k}")

print("All good!")

if torch.distributed.is_initialized():
torch.distributed.barrier()
torch.cuda.synchronize()
if torch.distributed.get_backend() == torch.distributed.Backend.NCCL:
torch.distributed.barrier(device_ids=[local_rank])
else:
torch.distributed.barrier()
torch.cuda.synchronize(device)
torch.distributed.destroy_process_group()


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(
description="Run DeepSeek V3 pipeline parallel example"
)
parser = argparse.ArgumentParser(description="Run DeepSeek V3 local_map example")
parser.add_argument(
"--fake-evaluate",
action="store_true",
Expand All @@ -235,7 +188,6 @@ def input_fn():

if args.rng_seed is not None:
torch.use_deterministic_algorithms(True)
torch.manual_seed(args.rng_seed)

run_test(
fake_evaluate=args.fake_evaluate, rng_seed=args.rng_seed, logs_dir=args.logs_dir
Expand Down
21 changes: 8 additions & 13 deletions tests/test_graph_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.placement_types import Replicate, Shard

from autoparallel._testing.models.dsv3 import (
DeepSeekV3Model,
DeepSeekV3ModelArgs,
MoEArgs,
)
from autoparallel._testing.models.dsv3 import DeepSeekV3Model, make_dsv3_config
from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs
from autoparallel.api import AutoParallel
from autoparallel.graph_passes.graph_clustering import get_identical_regions
Expand Down Expand Up @@ -215,14 +211,13 @@ def input_fn():

def _setup_ds3_local_map_autop(device_mesh_2d, n_layers=2):
global_batch_size = 2 * device_mesh_2d.shape[0] * device_mesh_2d.shape[1]
moe_args = MoEArgs(mesh=device_mesh_2d)
config = DeepSeekV3ModelArgs(
n_layers=n_layers,
n_dense_layers=0,
moe_args=moe_args,
)
config = make_dsv3_config(n_layers=n_layers, n_dense_layers=0)
with torch.device("meta"):
model = DeepSeekV3Model(config).bfloat16()
model = DeepSeekV3Model(
config,
mesh=device_mesh_2d,
compute_dtype=torch.bfloat16,
)
for module in model.modules():
if hasattr(module, "axis_name"):
module.axis_name = device_mesh_2d.mesh_dim_names[1]
Expand All @@ -231,7 +226,7 @@ def input_fn():
return torch.randint(
0,
config.vocab_size,
(global_batch_size, config.max_seq_len),
(global_batch_size, config.rope.max_seq_len),
device="cuda",
)

Expand Down
Loading
Loading