tiny record: add varlen + ttt stuff #67
Conversation
This is super interesting, and definitely something that can be worked on.

This PR is super interesting, thank you for submitting it! Thus far, all the PRs we have merged improve val loss and, in principle, should also improve downstream benchmarks. We plan to add those soon-ish. We want to stay disciplined about not hill-climbing validation loss without actually improving the downstream benchmark metrics -- basically so our data efficiency improvements have a chance of carrying over to downstream benchmarks as well. So we won't merge PRs that improve val loss but wouldn't, in principle, improve downstream benchmarks. So I would suggest:
Does that make sense? Let me know if you have any questions.
tiny record: document-based shuffling + varlen attention + ttt (3.272 BPB)
Adds a few things, intended hopefully as a record but also to start a discussion on clarifying the rules. Experiments/code are a bit scattered, but on the openai parameter golf I just submitted a record that was beaten 18 minutes before I submitted it, so I'm submitting now to avoid making the same mistake :)
Things added
Document-based shuffling
The old data shuffler had a fixed set of batches: every epoch it would shuffle the batch order, but the batches themselves never changed. Now, every epoch we shuffle the documents, concatenate them, and then slice the stream into batches. I found that using this during training was, on its own, worth ~0.003 BPB.
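A minimal sketch of what this could look like, assuming documents are stored as lists of token ids (the names `docs` and `epoch_batches` are hypothetical, not from the PR's code):

```python
import random

def epoch_batches(docs, batch_size, seq_len, seed):
    """Re-shuffle documents each epoch, concatenate them into one stream,
    then slice into [batch_size, seq_len] batches -- so the contents of
    each batch change every epoch, not just the batch order."""
    rng = random.Random(seed)              # seed with the epoch index
    order = list(range(len(docs)))
    rng.shuffle(order)                     # new document order each epoch
    stream = [tok for i in order for tok in docs[i]]  # flat token stream
    tokens_per_batch = batch_size * seq_len
    for start in range(0, len(stream) - tokens_per_batch + 1, tokens_per_batch):
        flat = stream[start:start + tokens_per_batch]
        yield [flat[r * seq_len:(r + 1) * seq_len] for r in range(batch_size)]
```

With a different seed per epoch, the same token stream gets cut at different points, so document boundaries land in different batches each time.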
Variable length attention
Instead of splitting the training sequence into `[B, L]` chunks where every token attends to all previous tokens (including those in other documents), use variable-length attention so tokens only attend within their own document. This reduces the training load (the model doesn't have to learn to ignore unrelated pre-BOS tokens), and it also makes things faster: 10 short (100-token) docs packed into a 1k-token buffer cost proportional to `10 * 100**2 = 100K` attention FLOPs vs `1000**2 = 1M` with dense attention. I still need to test how much this helps on its own.
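As a sketch of the bookkeeping involved: varlen attention kernels (e.g. flash-attn's `flash_attn_varlen_func`) take the packed buffer plus cumulative sequence boundaries, and the cost comparison above falls out of counting attended token pairs. The helper names here are illustrative, not from the PR:

```python
def cu_seqlens(doc_lens):
    """Cumulative document boundaries for a packed buffer -- the
    `cu_seqlens` format that varlen attention kernels expect."""
    cu = [0]
    for n in doc_lens:
        cu.append(cu[-1] + n)
    return cu

def attn_token_pairs(doc_lens):
    # varlen: each doc attends only within itself -> sum of per-doc squares
    return sum(n * n for n in doc_lens)

doc_lens = [100] * 10                      # 10 short docs in a 1k-token buffer
buffer_len = sum(doc_lens)                 # 1000
varlen_cost = attn_token_pairs(doc_lens)   # 10 * 100**2 = 100_000
dense_cost = buffer_len ** 2               # 1000**2 = 1_000_000
```

The 10x ratio grows as documents get shorter relative to the buffer, since dense cost is quadratic in the buffer length but varlen cost is quadratic only per document.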
Strided eval
For a document of length 8192, instead of splitting it into arbitrary 2048-length chunks (current behavior), score the first 2048 tokens, then score tokens 2048-2560 conditioned on the full context window before them, and so on. This avoids artificial loss spikes where the model is asked to score tokens from the middle of a document without seeing the context before them.
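The windowing above can be sketched as a generator of `(context_start, score_start, score_end)` triples -- score the first `max_len` tokens in one shot, then score `stride` new tokens at a time, each conditioned on the `max_len`-token window ending at them (function name and exact stride of 512 are my assumptions, inferred from the 2048/2560 example):

```python
def strided_eval_windows(doc_len, max_len=2048, stride=512):
    """Yield (context_start, score_start, score_end) windows so that every
    scored token sees up to max_len tokens of preceding context."""
    if doc_len <= max_len:
        yield (0, 0, doc_len)
        return
    yield (0, 0, max_len)                  # first window: score everything
    pos = max_len
    while pos < doc_len:
        end = min(pos + stride, doc_len)
        yield (end - max_len, pos, end)    # full max_len context, score stride
        pos = end
```

Each token is scored exactly once; the extra compute is re-encoding the overlapping context, which is the speed/loss tradeoff the chunk-size sweep below explores for TTT as well.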
Test-time training
This is the same as a parameter golf record where I train a LoRA for each sequence at eval time. Before predicting token `j`, you train on tokens `i < j` -- no leakage from future tokens, and no dependence between sequences in the validation set. To make this computationally manageable, we train in chunks: we only train on the context every `chunk_size` tokens. As you see below, there is a tradeoff between speed and loss (more chunks = better loss but slower).

Below are two plots. Each dot is one TTT hyperparameter configuration from a sweep, with the best config for each chunk size colored in. The left plot shows validation loss vs. time, and the right shows validation loss vs. chunk size.
Note that the strided-window improvement saturates quickly, but the TTT gains keep growing. Also note that the average sequence length in the validation set is ~700 tokens, so the loss gain per token of supervision here is very high.
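The chunked TTT schedule can be sketched as follows: chunk `k` is scored with the LoRA trained on all tokens before it, so nothing in a chunk's own future influences its score (the function name and the dict layout are mine; in practice you would likely train incrementally on just the newly revealed chunk rather than re-fitting the whole prefix):

```python
def ttt_schedule(seq_len, chunk_size):
    """For each eval chunk, which tokens the per-sequence LoRA may train on
    before scoring it. train_on is always a strict prefix of score, so there
    is no leakage from future tokens (i < j)."""
    steps = []
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        steps.append({"train_on": (0, start), "score": (start, end)})
    return steps
```

Smaller `chunk_size` means more training steps per sequence (better loss, slower), which is exactly the tradeoff the sweep above measures.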
1 hour results
These can easily be ported to the 1-hour setting (except that for TTT you need to figure out how to combine it with ensembles -- I think they could be synergistic).
Note that these aren't directly comparable to the current records. They are scored on exactly the same tokens, but at validation time even the standard results here don't have as many artificial mid-document breakpoints, which inflate the loss. I'm not submitting this as a record, but I had the results, so I figured I'd add them.
Things to consider
Adding the strided eval makes comparison to previous results unfair. At the very least, previous results should be rescored using this to avoid artificial mid-sequence loss spikes.
Personally, I think the "validation set" should be a collection of token sequences (documents) that must be processed independently. This would rule out most of the TTT approaches from the openai parameter golf (which have dependence between sequences), while allowing strided eval and doc-based TTT as outlined here. I think the alternative -- enforcing that batches have a given sequence length -- is arbitrary and limits interesting research (especially longer-context work).
I'm curious to hear people's thoughts.