
tiny record: add varlen + ttt stuff #67

Open

samacqua wants to merge 1 commit into qlabs-eng:main from samacqua:varlen

Conversation

@samacqua
Contributor

tiny record: document-based shuffling + varlen attention + ttt (3.272 BPB)

This adds a few things, intended hopefully as a record but also to start a discussion on clarifying the rules. Experiments/code are a bit scattered, but on the openai parameter golf I just submitted a record that was beaten 18 minutes before I submitted it, so I'm submitting now to avoid making the same mistake :)

```
# train first, this doesn't do ttt
torchrun --nproc_per_node=8 tiny/train.py --varlen --run-name varlen

# then eval
torchrun --nproc_per_node=2 tiny/eval.py --varlen --run-name varlen
=======================================
Standard Loss: 3.322792 | BPB: 1.079893
Strided  Loss: 3.301744 | BPB: 1.072715
TTT      Loss: 3.272587 | BPB: 1.063242
Delta (strided-TTT): loss=0.029157 bpb=0.009473
TTT eval time: 3453.9s
=======================================
```

Things added

Document-based shuffling

The old data shuffler had a fixed set of batches. Every epoch it would shuffle the batch order, but the batch contents never changed. Now, every epoch the documents themselves are shuffled, concatenated, and then cut into fresh batches. I found that using this during training on its own was worth ~0.003 BPB.
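
Roughly what that looks like, as a self-contained sketch (function and argument names are mine, not the actual tiny/ code):

```python
import random

def batch_stream(docs, batch_size, seq_len, bos_id, seed=0):
    """Illustrative document-based shuffling: reshuffle the documents
    themselves every epoch, re-concatenate the token stream, and cut it
    into fresh [batch_size, seq_len] batches, so batch contents change
    from epoch to epoch instead of only the batch order."""
    rng = random.Random(seed)
    while True:  # one pass of the outer loop == one epoch
        order = rng.sample(range(len(docs)), len(docs))
        stream = []
        for i in order:
            stream.append(bos_id)        # mark the document boundary
            stream.extend(docs[i])
        step = batch_size * seq_len
        for s in range(0, len(stream) - step + 1, step):
            flat = stream[s:s + step]
            yield [flat[b * seq_len:(b + 1) * seq_len] for b in range(batch_size)]
```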

Variable length attention

Instead of splitting the training sequence into [B, L] chunks where every token attends to all previous tokens (including those in other documents), use variable length attention so that tokens only attend within their own document. This reduces the training burden (the model doesn't have to learn to ignore unrelated pre-BOS tokens), and it also makes attention cheaper: 10 short (100-token) docs packed into a 1k-token buffer cost on the order of 10 * 100**2 = 0.1M attention score computations, versus 1000**2 = 1M with dense attention over the whole buffer.

I need to test how much this, on its own, helps.
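
The masking idea, in a deliberately naive form (a real implementation passes cu_seqlens to a varlen flash-attention kernel rather than materializing a T x T mask; the function and tensor names here are mine):

```python
import torch
import torch.nn.functional as F

def varlen_causal_attention(q, k, v, doc_ids):
    """Naive document-local causal attention.
    q, k, v: [n_heads, T, head_dim] for one packed buffer of T tokens;
    doc_ids: [T] integer id of the document each token belongs to.
    A real implementation would use a varlen flash-attention kernel
    instead of this O(T^2) boolean mask."""
    T = doc_ids.numel()
    same_doc = doc_ids.view(T, 1) == doc_ids.view(1, T)   # block-diagonal by document
    causal = torch.ones(T, T, dtype=torch.bool, device=doc_ids.device).tril()
    mask = same_doc & causal                               # causal *within* each document
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```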

Strided eval

For a document of length 8192, instead of splitting it into arbitrary 2048-length chunks (the current behavior), score the first 2048 tokens, then score tokens 2048-2560 conditioned on the full context before them, and so on. This avoids artificial loss spikes where the model is handed tokens from the middle of a document without the context that precedes them.
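
A sketch of the scoring loop, assuming a `model(x)` that returns `[1, L, vocab]` logits (helper name and details are mine, not the eval.py code):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def strided_doc_nll(model, tokens, window=2048, stride=512):
    """Sliding-window scoring of one document. The first `window` tokens are
    scored as usual; after that, each new block of `stride` tokens is scored
    conditioned on the `window - stride` real-context tokens preceding it."""
    nll_sum, n_scored = 0.0, 0
    prev_end = 0
    for begin in range(0, len(tokens) - 1, stride):
        end = min(begin + window, len(tokens))
        x = tokens[begin:end].unsqueeze(0)                     # [1, L]
        logits = model(x)                                      # assumed [1, L, vocab]
        nll = F.cross_entropy(logits[0, :-1].float(), x[0, 1:], reduction="none")
        keep = min(end - prev_end, nll.numel())                # only newly scored tokens
        nll_sum += nll[-keep:].sum().item()
        n_scored += keep
        prev_end = end
        if end == len(tokens):
            break
    return nll_sum / n_scored
```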

Test-time training

This is the same as a parameter golf record where I train a LoRA for each sequence at eval time. Before predicting token j, you train on tokens i < j -- no leakage from future tokens, no dependence between sequences in the validation set. To make this computationally manageable, we train in chunks: we only train on the accumulated context every chunk_size tokens. As you see below, there is a tradeoff between speed and loss (smaller chunks = more updates = better loss, but slower).
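
A sketch of the chunked TTT loop, assuming the adapter parameters (e.g. the LoRA) are the only ones with `requires_grad=True`; all names and hyperparameters here are mine, not the PR's code:

```python
import copy
import torch
import torch.nn.functional as F

def ttt_chunked_nll(model, tokens, chunk_size=256, ttt_steps=4, lr=1e-3):
    """Chunked test-time training on a single sequence. Before scoring each
    chunk, take a few gradient steps on the tokens that precede it (and only
    those), then score the chunk with the adapted weights. The adapted copy is
    thrown away afterwards, so there is no leakage across sequences."""
    work = copy.deepcopy(model)                                # fresh copy per sequence
    params = [p for p in work.parameters() if p.requires_grad] # e.g. only the LoRA
    opt = torch.optim.AdamW(params, lr=lr)
    nll_sum, n_scored = 0.0, 0
    for start in range(0, len(tokens) - 1, chunk_size):
        if start > 0:                                          # adapt on the prefix only
            prefix = tokens[:start].unsqueeze(0)
            for _ in range(ttt_steps):
                logits = work(prefix)                          # assumed [1, L, vocab]
                loss = F.cross_entropy(logits[0, :-1].float(), prefix[0, 1:])
                opt.zero_grad()
                loss.backward()
                opt.step()
        end = min(start + chunk_size, len(tokens))
        with torch.no_grad():                                  # score the new chunk only
            ctx = tokens[:end].unsqueeze(0)
            logits = work(ctx)
            nll = F.cross_entropy(logits[0, :-1].float(), ctx[0, 1:], reduction="none")
            new = nll[(start - 1 if start > 0 else 0):end - 1]
            nll_sum += new.sum().item()
            n_scored += new.numel()
    return nll_sum / n_scored
```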

Below are two plots. Each dot is one TTT hparam configuration over a sweep. The best config for each chunk size is colored in. On the left is validation loss v. time, and on the right is validation loss versus chunk size.

Note that the strided-window improvement saturates quickly, but the TTT improvement keeps growing. Also note that the average sequence length in the validation set is ~700 tokens, so the loss gain per token of supervision here is very high.

1 hour results

These can easily be ported to the 1 hour setting (except that for TTT you need to figure out how to combine it with the ensembles -- I think they could be synergistic).

```
Strided (stride=512) | Val BPB: 1.051322 | Val Loss: 3.231245
[equal] Standard | Val BPB: 1.042844 | Val Loss: 3.208622
[equal] Strided (stride=512) | Val BPB: 1.036539 | Val Loss: 3.185811
[weighted] Standard | Val BPB: 1.043151 | Val Loss: 3.209579
[weighted] Strided (stride=512) | Val BPB: 1.036846 | Val Loss: 3.186754
```

Note that these aren't directly comparable to the current records. They are scored on exactly the same tokens, but during validation even the standard results here have fewer artificial mid-doc breakpoints, which would otherwise inflate the loss. I'm not submitting this as a record, but I had the results so figured I'd add them.

Things to consider

Adding the strided eval makes comparison to previous results unfair. At the very least, previous results should be rescored using this to avoid artificial mid-sequence loss spikes.

Personally, I think that the "validation set" should be a collection of sequences of tokens (documents) that should be processed independently. This would rule out most of the TTT approaches in the openai parameter golf (dependence between sequences), and would allow strided eval/doc-based TTT as outlined here. I think the alternative -- enforcing that the batches are a given sequence length -- is arbitrary and limits interesting research (especially longer-context stuff).

I'm curious to hear people's thoughts.

@ChinmayK0607
Contributor

I think the alternative -- enforcing that the batches are a given sequence length -- is arbitrary and limits interesting research (especially longer-context stuff).

This is super interesting, and definitely something that can be worked on.

@akshayvegesna
Contributor

This PR is super interesting, thank you for submitting it!

Thus far, all the PRs we have merged improve val loss and, in principle, should also improve downstream benchmarks. We plan to add those soon-ish. We want to stay disciplined about not hill-climbing validation loss without actually improving the downstream benchmark metrics -- basically so our data efficiency improvements have a chance of applying to downstream benchmarks as well. So we won't merge PRs that improve val loss but, in principle, won't improve downstream benchmarks.

So I would suggest:

  • Document-based shuffling and variable length attention -- we should test these standalone on the tiny track and see if they give a gain. If so, let's merge.
  • Strided eval -- this is a nice improvement to evaluation, but it does not actually improve model performance, so we shouldn't count it as the record. We can track it, though. It won't help the downstream benchmarks.
  • TTT -- this is a very cool idea and a nice addition. But given it doesn't help with the normal downstream benchmarks, I would say we merge this into research/ttt/ and develop it further there.

Does that make sense? Lmk if you have any questions.

