
Conversation

Member

@BlueCrescent BlueCrescent commented Nov 7, 2025

What does this PR do?

Adds support for multi-stage pipeline parallelism schedules, in particular interleaved 1F1B.
Issue #408

General Changes

  • Made code compatible with having multiple stages per rank.
  • Switched to interleaved 1F1B in some configs.
  • Note: In the warm-start test, the epsilon for the loss comparison was drastically increased.

Breaking Changes

  • Changes should be backwards compatible.

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

…odel.

Also made None returns more visible in get_module_class_from_name().
- Switched from abs=1e-16 to rel=1e-2 for loss comparisons. Need to investigate further why this is necessary for some configurations (see the tolerance sketch after this list).
- Additional configs and test setups, which are, however, commented out due to their long runtime.
- Easier configurability for expected checkpoint paths (for debugging/messing around).
- Better error logging.
Comment on lines +196 to +204
else:
    assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer."
    sd = get_optimizer_state_dict(
        model=app_state.model_parts[0],
        optimizers=app_state.optimizer,
        # NOTE: Flattening is required for pipeline parallelism to work correctly.
        # see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214
        options=StateDictOptions(flatten_optimizer_state_dict=True),
    )
Member Author

Should we remove this, since in the case of PP we now always have an optimizer list which takes care of the flattening?

Member

sounds good!

@BlueCrescent BlueCrescent marked this pull request as ready for review November 24, 2025 09:35
Comment on lines 59 to 69
@model_validator(mode="before")
@classmethod
def warn_deprecated_alias(cls, data: Any) -> Any:
    if isinstance(data, dict) and "wrapped_model" in data:
        warnings.warn(
            "Field 'wrapped_model' is deprecated. Use 'wrapped_model_or_parts' instead.",
            DeprecationWarning,
            stacklevel=3,
        )
    return data

Member Author

Should we use this deprecation warning? If yes, should we use it also in other configs where a field got renamed to plural?

Comment on lines +54 to +56
# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),
# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2),
# ("gpt2_train_num_steps_7_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),
Member Author

These are currently deactivated due to their long runtime. Should we activate them anyway?

Collaborator

The first and the third commented-out configs are the same, right?

Collaborator

I don't think that
("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 4, 2),
is necessary since we already test
("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_fsdp2.yaml", 8, 2),
which is the same setup + data parallelism, correct?

And since we have
("gpt2_train_num_steps_7_pp_tp.yaml", "gpt2_warm_start_from_step_4_grad_accu.yaml", 8, 1),
we can probably skip
# ("gpt2_train_num_steps_7_pp.yaml", "gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml", 4, 2),

Member Author
@BlueCrescent BlueCrescent Dec 1, 2025

Yeah, these configs are mostly useful for debugging with fewer ranks. Probably makes sense to have them turned off (or even delete them in the future).

Comment on lines +108 to +111
(  # FIXME wpe and drop probably should not get the higher weight
    ["transformer.wte", "transformer.wpe", "transformer.drop"],
    self._input_layer_equivalence,
),
Member Author

I added this FIXME; does anyone have an opinion on whether I can remove wpe and drop from this list?

@rrutmann rrutmann self-requested a review November 25, 2025 13:37
Collaborator
@rrutmann rrutmann left a comment

There are some tests failing for me:

/workspaces/modalities/tests/conversion/gpt2/test_conversion_model.py::test_convert_model_checkpoint_produces_same_logits_as_original[gpt2_config_test.yaml-False]
TypeError: check_model_inputs.<locals>.wrapped_fn() got an unexpected keyword argument 'input_ids'

/workspaces/modalities/tests/conversion/gpt2/test_convert_gpt2.py::test_converting_gpt2_does_not_change_outputs[gpt2_config_test.yaml-False]
TypeError: check_model_inputs.<locals>.wrapped_fn() got an unexpected keyword argument 'input_ids'

/workspaces/modalities/tests/fsdp2_parallelization/test_tensor_parallelism.py::TestTensorParallelism::test_tp_sharding[swiglu-fsdp2_config_path1-tp_config_path1]
torch.multiprocessing.spawn.ProcessExitedException: process 2 terminated with signal SIGABRT

As well as an error importing one of the tests:
______ ERROR collecting tests/checkpointing/test_checkpoint_conversion.py ______
tests/checkpointing/test_checkpoint_conversion.py:59: in <module>
    @pytest.mark.skipif(
/home/richard-rutmann/.local/lib/python3.11/site-packages/_pytest/mark/structures.py:401: in __call__
    store_mark(unwrapped_func, self.mark, stacklevel=3)
/home/richard-rutmann/.local/lib/python3.11/site-packages/_pytest/mark/structures.py:466: in store_mark
    warnings.warn(MARKED_FIXTURE, stacklevel=stacklevel)
E   pytest.PytestRemovedIn9Warning: Marks applied to fixtures have no effect
E   See docs: https://docs.pytest.org/en/stable/deprecations.html#applying-a-mark-to-a-fixture-function

Collaborator
@rrutmann rrutmann left a comment

Great work, thank you. A few tests are failing (see my comment), but aside from that, no major changes are required from my side.

@le1nux le1nux self-requested a review December 10, 2025 09:34
Member
@le1nux le1nux left a comment

1st batch of comments.

lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.
"""
self._model = model
self._model_parts = list(model) if isinstance(model, list) else [model]
Member

Suggested change:
- self._model_parts = list(model) if isinstance(model, list) else [model]
+ self._model_parts = model if isinstance(model, list) else [model]

Member Author

I think creating a new list here is safer in case an outside context accidentally changes the input list.

@staticmethod
def get_state_dict(app_state: AppState) -> dict[str, Any]:
-   """Returns the state dict of the model in the AppState object.
+   """Returns the flattened state dicts of the model parts in the AppState object.
Member

flattened keys or tensors?

Member Author

I guess flattened keys. Though I'm not sure I would call it that: we are mapping from a list of dicts to a single dict, whereas "flattened keys" sounds more like flattening a dict of dicts.

    dict[str, Any]: The state dict of the model in the AppState object.
    """
-   return get_model_state_dict(model=app_state.model)
+   return {k: v for sd in map(get_model_state_dict, app_state.model_parts) for k, v in sd.items()}
Member

are we sure that k is always unique across model parts? Should we maybe throw an exception if k is not unique?

Member Author

It is assumed that the model parts are distinct and thus have distinct keys. I'll modify the function to check that this is fulfilled.

Comment on lines +218 to 225
assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer."
set_optimizer_state_dict(
    model=app_state.model_parts[0],
    optimizers=app_state.optimizer,
    optim_state_dict=state_dict,
    options=StateDictOptions(flatten_optimizer_state_dict=True),
)

Member

Given your comment above, I assume the else case can also be removed here?

    component_config_type=component_config_type,
)
- comp_config = component_config_type(**config_dict, strict=True)
+ comp_config = component_config_type.model_validate(config_dict, extra="forbid")
Member

Good catch!

    scheduled_pipeline: Pipeline | None = None,
):
- if num_train_steps_done % evaluation_interval_in_steps == 0:
+ if num_train_steps_done % evaluation_interval_in_steps == 0 and num_train_steps_done > 0:
Member

Here we should add a note with the details of the error that we were experiencing otherwise.

Member Author

Added a corresponding TODO.
