[feat] add support for D2SD-mode VP-Drafter training#2
Conversation
|
Thanks for the implementation. A few thoughts on the D²SD approach itself: |
|
@zdaxie Thank you for your interest in our work and for raising these very valuable points. We find your comments highly insightful! Regarding the first question, we did not strictly control the compute cost of the two experiments to be exactly the same within a single iteration. In other words, what we actually controlled was the same verification budget, rather than the same per-iteration compute cost. Since D²SD indeed introduces a higher draft cost due to the second draft pass, in our experimental comparison we report not only the single-step accepted length, but also the end-to-end speedup. Our goal is to examine whether the additional draft cost introduced by the second draft pass would offset the overall end-to-end speedup gained from the increase in TPF brought by D²SD. Regarding the second question, we acknowledge that since the verification budget of D²SD is significantly higher than that of DFlash or MTP, its overall throughput will noticeably decrease under high-concurrency settings. The additional cost introduced by the second draft pass in real serving frameworks is also an issue we have been actively investigating. We plan to adapt D²SD to SGLang and vLLM in the near future, and we expect this method to be mainly effective in low-concurrency scenarios. Regarding the third question, we allocate the second-draft budget to branching mainly to exploit the parallelism of the second draft pass. We previously tried improving single-sequence quality by resampling from the most likely rejection boundary. However, the additional draft cost introduced by the second draft pass would offset the end-to-end speedup gained from the increase in TPF. By contrast, starting parallel drafts from multiple possible rejection boundaries can significantly improve TPF without substantially increasing the draft cost, especially in low-concurrency settings. As for further comparisons, after completing the adaptation to serving frameworks, we plan to conduct experiments under high-concurrency settings and compare D²SD end-to-end against most existing methods, in order to better identify the regime where D²SD has an advantage. Finally, thank you again for your constructive suggestions on our work, and we also sincerely appreciate the significant contributions that DSpark has made to the speculative decoding community. |
Motivation
This PR adds training support for the VP-Drafter used in D2SD (Dual Diffusion Draft Speculative Decoding). D2SD extends DFlash by using a first DFlash draft to estimate likely rejection boundaries, then training a second variable-prefix drafter to re-anchor at selected prefixes and generate alternative continuations.
The key training requirement is different from standard DFlash: the drafter must learn from variable-length visible prefixes instead of always seeing only the anchor token followed by masks. This PR implements that behavior as a DFlash training-mode branch.
References:
Modifications
Added D2-style feature support to dflash training.
The D2 mode samples a variable visible prefix length per sampled anchor block.
Prefix tokens are fed as real token embeddings; suffix positions remain masked and contribute to loss.
Added D2 prefix-length sampling controlled by
d2_prefix_weight_base.Added D2-specific loss masking so visible prefix positions are excluded from supervision.
Added loss decay offsets so exponential decay starts from the first masked suffix position.
Wired Qwen3 and Gemma4 DSpark/dflash models to read
enable_d2_featureandd2_prefix_weight_basefrom draft config.Wired Qwen3 and Gemma4 draft config builders to propagate D2 feature config from
model_args.Enabled D2 feature in all dflash configs:
config/dflash/dflash_qwen3_4b.pyconfig/dflash/dflash_qwen3_8b.pyconfig/dflash/dflash_qwen3_14b.pyconfig/dflash/dflash_gemma4_12b.pyAdded shared helper utilities for D2 prefix sampling, D2 noise embedding construction, and D2 eval-mask construction.
Kept the original DSpark/dflash training path unchanged when
enable_d2_feature=False.