Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions .github/workflows/test_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,70 +13,70 @@

jobs:
test-cuda-single-gpu:
name: Test CUDA Single GPU (cuda12.6-py3.12)
name: Test CUDA Single GPU (cuda13.0-py3.12)
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
timeout: 60
runner: linux.g5.4xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: "12.6"
gpu-arch-version: "13.0"
submodules: recursive
script: |
conda create --yes --quiet --name py312 python=3.12
source $(conda info --base)/etc/profile.d/conda.sh
conda activate py312

pip install --quiet -r requirements-test.txt
# For some reason the spec above isn't working
pip uninstall -y torch
pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130
pip install --quiet .
pytest tests

examples-cuda-single-gpu:

Check warning

Code scanning / CodeQL

Workflow does not contain permissions Medium test

Actions job or workflow does not limit the permissions of the GITHUB_TOKEN. Consider setting an explicit permissions block, using the following as a minimal starting point: {}
name: Examples CUDA Single GPU (cuda12.6-py3.12)
name: Examples CUDA Single GPU (cuda13.0-py3.12)
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
timeout: 60
runner: linux.g5.4xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: "12.6"
gpu-arch-version: "13.0"
submodules: recursive
script: |
conda create --yes --quiet --name py312 python=3.12
source $(conda info --base)/etc/profile.d/conda.sh
conda activate py312

pip install --quiet -r requirements-test.txt
# For some reason the spec above isn't working
pip uninstall -y torch
pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130
pip install --quiet .
run_timed() { local start=$SECONDS; "$@"; echo "$* : $((SECONDS - start))s" >> /tmp/timings.txt; }
run_timed python examples/example_autoparallel.py
run_timed python examples/example_llama3.py
run_timed python examples/example_local_map.py
echo "========== Timings =========="
cat /tmp/timings.txt

test-cuda-multi-gpu:

Check warning

Code scanning / CodeQL

Workflow does not contain permissions Medium test

Actions job or workflow does not limit the permissions of the GITHUB_TOKEN. Consider setting an explicit permissions block, using the following as a minimal starting point: {}
name: Test CUDA Multi GPU (cuda12.6-py3.12)
name: Test CUDA Multi GPU (cuda13.0-py3.12)
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
timeout: 60
runner: linux.g5.12xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: "12.6"
gpu-arch-version: "13.0"
submodules: recursive
script: |
conda create --yes --quiet --name py312 python=3.12
source $(conda info --base)/etc/profile.d/conda.sh
conda activate py312

pip install --quiet -r requirements-test.txt
# For some reason the spec above isn't working
pip uninstall -y torch
pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu130
pip install --quiet .
python examples/example_dcp.py
torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py

Check warning

Code scanning / CodeQL

Workflow does not contain permissions Medium test

Actions job or workflow does not limit the permissions of the GITHUB_TOKEN. Consider setting an explicit permissions block, using the following as a minimal starting point: {}
36 changes: 34 additions & 2 deletions autoparallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

import copy
import logging
import operator
import time
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional, Union

import torch
Expand Down Expand Up @@ -57,7 +59,35 @@
logger = logging.getLogger(__name__)


def _boxed_nop_preserve_node_meta(fx_g, example_inputs):
def _boxed_nop_preserve_node_meta(fx_g, example_inputs, tag_forward=False):
if tag_forward:
# Tag the forward graph's OUTPUT values as "must save". These are
# the tensors the first min_cut decided to save for backward —
# only these should be saved in the second compilation.
# Uses the "custom" meta field (not "recompute") to avoid
# interfering with ac_joint_pass which uses "recompute" for
# activation checkpointing decisions.
output_node = next(n for n in fx_g.graph.nodes if n.op == "output")
for out in output_node.args[0]:
if not isinstance(out, torch.fx.Node) or out.op != "call_function":
continue
if out.target == operator.getitem:
# getitem metadata doesn't survive preserve_node_meta
# (Python builtin, not dispatched). Tag the parent
# multi-output op instead, keeping the getitem index so the
# second partitioner can replay the exact saved output.
parent = out.args[0]
if isinstance(parent, torch.fx.Node):
custom = parent.meta.setdefault("custom", {})
custom["ap_must_save"] = True
indices = custom.setdefault("ap_must_save_getitem_indices", [])
idx = out.args[1]
if idx not in indices:
indices.append(idx)
else:
out.meta.setdefault("custom", {})
out.meta["custom"]["ap_must_save"] = True

def run(args):
with torch.fx.traceback.preserve_node_meta():
return torch.fx.Interpreter(fx_g).boxed_run(args)
Expand Down Expand Up @@ -482,9 +512,11 @@ def apply_placement(self, sharding_placement):
self.parallel_gm.graph, self.reshard_after_forward
)

fw_compiler_fn = partial(self.compiler_fn, tag_forward=True)

self.parallel_model_fn = parallel_model_fn = aot_compile_joint_with_descriptors(
self.joint_with_descriptors,
fw_compiler=self.compiler_fn,
fw_compiler=fw_compiler_fn,
bw_compiler=self.compiler_fn,
)

Expand Down
Loading
Loading