
Torchtitan evals are slow because torch.compile recompiles for every unique input sequence length. #588

@pefontana

Description


Torchtitan evals are slow because torch.compile recompiles for every unique input sequence length.

Each eval item has a different sequence length, and the eval harness processes them one at a time. Every new shape triggers a torch.compile recompilation, which takes seconds per item.

torch.compile is unconditionally enabled in python/python/psyche/modelsttitan.py:305:

```python
job_config.compile.enable = True
job_config.compile.components = ["model", "loss"]
```
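One way to make this conditional is to gate compilation on whether the run is training. This is a minimal sketch, not the repo's actual API: the `configure_compile` helper and `training` flag are assumptions; only the `job_config.compile` fields come from the snippet above.

```python
from types import SimpleNamespace

def configure_compile(job_config, *, training: bool):
    """Hypothetical helper: enable torch.compile only for training runs,
    so eval forward passes stay in eager mode and never recompile.
    The job_config.compile fields mirror the hard-coded ones above."""
    job_config.compile.enable = training
    job_config.compile.components = ["model", "loss"] if training else []
    return job_config

# Stand-in for torchtitan's JobConfig, just to exercise the helper.
cfg = SimpleNamespace(compile=SimpleNamespace(enable=True, components=["model", "loss"]))
configure_compile(cfg, training=False)
```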

The eval harness (shared/eval/src/harness.rs:497-510) calls model.forward() with a different sequence length for each eval item:

```rust
let request_tensor = Tensor::from_slice(&full_request)
    .to(options.model.device())
    .unsqueeze(0);
let (logits, _) = {
    let _no_grad = tch::no_grad_guard();
    options.model.forward(&request_tensor, ...)
};
```

Tasks with acc_uncond (arc_easy, arc_challenge, mmlu_cf, piqa) do an additional forward pass per choice, making it even slower.

  • During training, the client runs evals incrementally (limit=10 per iteration in shared/client/src/state/evals.rs:282),
    but each batch still hits new sequence lengths and triggers recompilation, so eval progress is very slow.
  • Standalone evals with the evaluate example are similarly slow with Torchtitan on full datasets.
  • This only affects Torchtitan. Native Rust models (HfLlama, HfDeepseek) and HfAuto don't use torch.compile.

Possible solutions:

  1. Disable torch.compile during eval inference (e.g. skip compilation when labels is None, or add an explicit flag)
  2. Pad eval inputs to a small set of fixed bucket sizes so torch.compile sees only a few unique shapes
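Solution 2 could look like the following pure-Python sketch. The bucket sizes and pad token are illustrative assumptions, not values from the codebase.

```python
# Pad eval inputs up to a small set of bucket lengths so torch.compile
# only ever sees a handful of distinct shapes instead of one per item.
BUCKETS = [128, 256, 512, 1024, 2048]  # assumed sizes, tune per model

def bucket_length(seq_len: int) -> int:
    """Smallest bucket >= seq_len; fall back to seq_len itself for
    oversize inputs (costing one extra recompile per such length)."""
    for b in BUCKETS:
        if b >= seq_len:
            return b
    return seq_len

def pad_to_bucket(tokens: list[int], pad_id: int = 0) -> list[int]:
    """Right-pad a token sequence to its bucket length."""
    target = bucket_length(len(tokens))
    return tokens + [pad_id] * (target - len(tokens))
```

Note that padded positions would need to be excluded via the attention mask, and logits past the original length ignored when scoring choices.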
