As mentioned in #239, I have come to believe that the way I originally implemented automatic batch size determination in Heretic is flawed. There are several problems with the current mechanism:
- It wastes time with inference on test prompts, then throws away the results.
- It tries to determine a global batch size that applies to all operations, even though generating responses requires more VRAM than generating logprobs.
- It fails to adapt to changing memory pressure on the system. If another process starts to take up VRAM, Heretic will crash with an OOM error.
- It starts from a batch size of 1 and repeatedly doubles it until it hits an OOM error in order to determine the optimum. This is problematic for two reasons:
  - It can spend a lot of time doing inference and re-generating compute graphs for low batch sizes, even though the maximum of 128 often works.
  - It might miss major performance gains for sizes that aren't powers of 2. For example, a system might be able to support a batch size of 100, but not 128. Choosing 100 instead of 64 can give huge performance gains with the default configuration, because there are 100 evaluation prompts per set: at batch size 100 each set is processed in a single batch, while at 64 it takes two.
Here is how I think it should work instead:
- Every inference method (`get_responses`, `get_residuals`, `get_logprobs`) has its own associated batch size, initialized to a value like 128.
- There is no batch size determination process on program startup.
- Instead, when an inference method is called, it just runs with its current batch size. If inference throws an OOM error, it reduces the batch size and tries again. The reduced batch size is stored in a `Model` instance field, so it is remembered for the next invocation.
This avoids wasted work, optimizes the batch size separately for different methods, and dynamically adapts to changing memory availability at runtime.
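For illustration, here is a minimal sketch of what this could look like, assuming a PyTorch-based `Model` class. The `_run_batched` helper, the `batch_sizes` field, and the halving-on-OOM strategy are assumptions for the sake of the example, not the actual implementation:

```python
import torch

class Model:
    def __init__(self):
        # One independent batch size per inference method (hypothetical field).
        self.batch_sizes = {
            "get_responses": 128,
            "get_residuals": 128,
            "get_logprobs": 128,
        }

    def _run_batched(self, method_name, items, run_batch):
        """Run `run_batch` over `items`, shrinking this method's batch size on OOM."""
        results = []
        i = 0
        while i < len(items):
            batch_size = self.batch_sizes[method_name]
            try:
                results.extend(run_batch(items[i:i + batch_size]))
                i += batch_size
            except torch.cuda.OutOfMemoryError:
                if batch_size == 1:
                    raise  # Nothing left to reduce; a single item doesn't fit.
                # Remember the reduced size for all future invocations.
                self.batch_sizes[method_name] = max(1, batch_size // 2)
                torch.cuda.empty_cache()
        return results

    def get_logprobs(self, prompts):
        return self._run_batched("get_logprobs", prompts, self._logprobs_batch)

    def _logprobs_batch(self, prompts):
        ...  # actual forward pass over one batch of prompts
```

Halving on OOM is just one possible reduction strategy; stepping down by a smaller factor would preserve more of the non-power-of-2 gains mentioned above.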
While this discards the "benchmarking" aspect and thus might lead to situations where larger batch sizes are chosen while a smaller batch size could improve overall throughput, I have found those situations to be extremely rare in practice. For almost every (non-CPU) run I have ever seen, you get the highest generation speed by simply choosing the largest batch size that fits into your VRAM constraints.