Torchtitan evals are slow because `torch.compile` recompiles for every unique input sequence length.
Each eval item has a different sequence length, and the eval harness processes them one at a time. Every new shape triggers a `torch.compile` recompilation, which takes seconds per item.
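To illustrate the cost model, here is a minimal pure-Python sketch of a compile cache keyed on input shape (illustrative only; Dynamo's real guard/cache machinery is more sophisticated). Every unseen sequence length is a cache miss, so N distinct lengths mean N compilations:

```python
# Simulated shape-specialized compile cache (not torch internals).
compile_cache: dict[int, str] = {}
compile_count = 0

def forward(seq_len: int) -> str:
    """Pretend forward pass: 'compiles' a kernel the first time a shape is seen."""
    global compile_count
    if seq_len not in compile_cache:
        compile_count += 1  # the expensive step: seconds per recompilation
        compile_cache[seq_len] = f"kernel_for_len_{seq_len}"
    return compile_cache[seq_len]

# Eval items typically all have different lengths -> one recompile each.
eval_lengths = [137, 212, 98, 301, 137]  # one repeat to show a cache hit
for n in eval_lengths:
    forward(n)

print(compile_count)  # 4 unique lengths -> 4 "compilations"
```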
`torch.compile` is hardcoded on in `python/python/psyche/modelsttitan.py:305`:
```python
job_config.compile.enable = True
job_config.compile.components = ["model", "loss"]
```
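One way to avoid hardcoding this is to gate compilation on whether the model is being used for training. The following is a hedged sketch, not Torchtitan's actual config types: `CompileConfig` and `configure_compile` are hypothetical names that mirror the fields in the snippet above.

```python
from dataclasses import dataclass, field

# Hypothetical config shape mirroring job_config.compile above;
# the surrounding structure is assumed, not taken from Torchtitan.
@dataclass
class CompileConfig:
    enable: bool = False
    components: list = field(default_factory=list)

def configure_compile(compile_cfg: CompileConfig, training: bool) -> None:
    # Compile only for training; eval inference stays eager so that
    # varying sequence lengths don't trigger recompilation.
    compile_cfg.enable = training
    compile_cfg.components = ["model", "loss"] if training else []

cfg = CompileConfig()
configure_compile(cfg, training=False)
print(cfg.enable)  # False during eval
```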
The eval harness (`shared/eval/src/harness.rs:497-510`) calls `model.forward()` with a different sequence length for each eval item:
```rust
let request_tensor = Tensor::from_slice(&full_request)
    .to(options.model.device())
    .unsqueeze(0);
let (logits, _) = {
    let _no_grad = tch::no_grad_guard();
    options.model.forward(&request_tensor, ...)
};
```
Tasks with `acc_uncond` (`arc_easy`, `arc_challenge`, `mmlu_cf`, `piqa`) do an additional forward pass per choice, making them even slower.
- During training, the client runs evals incrementally (`limit=10` per iteration in `shared/client/src/state/evals.rs:282`), but each batch still hits new sequence lengths and triggers recompilation, so eval progress is very slow.
- Standalone evals with the `evaluate` example are also very slow with Torchtitan on full datasets.
- This only affects Torchtitan. Native Rust models (`HfLlama`, `HfDeepseek`) and `HfAuto` don't use `torch.compile`.
Possible solutions:
- Disable `torch.compile` during eval inference (e.g. check if `labels` is `None` and skip compile, or add a flag)
- Pad eval inputs to a set of fixed bucket sizes to reduce unique shapes
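The bucketing idea can be sketched in pure Python as follows. The bucket sizes and pad token id are illustrative assumptions, not values from the codebase; with this approach `torch.compile` sees at most `len(BUCKETS)` distinct shapes over the whole eval run (logits at padded positions would still need to be masked out when scoring).

```python
# Sketch: right-pad eval inputs to a small fixed set of lengths so that
# torch.compile only ever sees a handful of shapes.
BUCKETS = [128, 256, 512, 1024, 2048]  # example sizes, not from the codebase
PAD_ID = 0                             # hypothetical pad token id

def pad_to_bucket(tokens: list[int]) -> list[int]:
    """Right-pad a token sequence to the smallest bucket that fits it."""
    for size in BUCKETS:
        if len(tokens) <= size:
            return tokens + [PAD_ID] * (size - len(tokens))
    raise ValueError(f"sequence of length {len(tokens)} exceeds largest bucket")

padded = pad_to_bucket(list(range(200)))
print(len(padded))  # 256: padded up to the next bucket
```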