Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions hpc_launcher/cli/torchrun_hpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ def main():
        "Ensuring that HIP vs ROCR can improve behavior of HF Accelerate and TorchTitan.",
)

parser.add_argument(
"--no-dist-init",
action="store_true",
default=False,
help="Do not call torch.distributed.init_process_group() in the torchrun-hpc trampoline.",
)

# Grab the rest of the command line to launch
# torchrun-hpc does not support running with a pre-generated batch script file
parser.add_argument("command", help="Command to be executed")
Expand Down Expand Up @@ -152,6 +159,10 @@ def main():
launch_args = [
"-u",
f"{os.path.abspath(folder_name)}/{trampoline_file}",
]
if args.no_dist_init:
launch_args.append("--no-dist-init")
launch_args += [
os.path.abspath(args.command),
]
launch_args += args.args
Expand Down
14 changes: 10 additions & 4 deletions hpc_launcher/torch/torchrun_hpc_trampoline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
def main():
# Strip off the name of this script and pass the rest to runpy
args = sys.argv[1:]
no_dist_init = False
if "--no-dist-init" in args:
no_dist_init = True
args = [a for a in args if a != "--no-dist-init"]

scheduler_type = os.getenv("TORCHRUN_HPC_SCHEDULER")
scheduler = get_schedulers()[scheduler_type]
Expand Down Expand Up @@ -67,6 +71,7 @@ def main():
os.environ["LOCAL_RANK"] = f"{local_device_id}"

torch_dist_initialized = dist.is_initialized()
did_init_dist = False
rdv_protocol = os.getenv("TORCHRUN_HPC_RDV_PROTOCOL")
if world_size > 1 or rdv_protocol == "mpi://":
if rdv_protocol == "mpi://":
Expand All @@ -86,7 +91,7 @@ def main():
f"MPI rendezvous protocol selected without installing mpi_rndv library."
)

if not torch_dist_initialized:
if not torch_dist_initialized and not no_dist_init:
if not backend:
raise Exception(
f"torchrun-hpc is unable to find a valid backend for torch distributed."
Expand All @@ -100,6 +105,7 @@ def main():
dist.init_process_group(
backend, init_method=rdv_protocol, world_size=world_size, rank=rank, device_id=torch.device(device, local_device_id)
)
did_init_dist = True

if rdv_protocol == "mpi://" and rank == 0:
print(
Expand Down Expand Up @@ -129,12 +135,12 @@ def main():
# If the mpi rendezvous protocol is set, this should be necessary but some packages still look for it
os.environ["MASTER_ADDR"] = "23456"

# Note that run_path will prepend the args[0] back onto the sys.argv so it needs to be stripped off first
sys.argv = sys.argv[1:]
# Forward the underlying script argv, but strip any trampoline-only flags.
sys.argv = args
# Run underlying script
runpy.run_path(args[0], run_name="__main__")

if dist.is_initialized():
if did_init_dist and dist.is_initialized():
# Deal with destroying the process group here
dist.destroy_process_group()

Expand Down