🚀 The feature, motivation and pitch
`MLPSpeculator`-based speculative decoding was recently added in #4947, but the initial integration only supports single-GPU usage.
"Speculator" models will soon be available for larger target models that require multiple GPUs, so we would like to ensure that tensor parallelism (TP) can be used.
The first part of this issue is to test it together with #5414 and make any adjustments needed so that it works with TP=1 for the speculator and TP=N for the target model.
Following this, we can look at running the speculator itself with TP>1. That may be more involved, since it requires distributed coordination of the sampling of each speculated token in the `MLPSpeculator` loop. It might be possible to avoid additional communication here by having the speculator's sampler use a dedicated `torch.Generator` for its sampling and performing this sampling in tandem across the ranks.
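A minimal sketch of the dedicated-generator idea, under stated assumptions: every rank already holds the same full-vocabulary logits for each speculator head (e.g. after the vocab-parallel logits have been gathered), and every rank makes the same sequence of sampling calls. The toy model and the names `build_toy_heads`, `run_speculator` and `simulate_tp_ranks` are hypothetical; ranks are simulated in one process on CPU, and this is not vLLM's `MLPSpeculator` or sampler code. Because each rank seeds its own `torch.Generator` identically and uses it only for speculator sampling, all ranks draw the same tokens even when other code consumes the global RNG differently per rank, so the sampled tokens would not need to be broadcast.

```python
import torch

# Toy sizes for a hypothetical 3-head speculator (illustrative only).
VOCAB, HIDDEN, NUM_HEADS = 32, 8, 3


def build_toy_heads(seed: int = 0):
    """Deterministic toy weights standing in for the speculator's layers."""
    g = torch.Generator().manual_seed(seed)
    emb = torch.randn(VOCAB, HIDDEN, generator=g)
    heads = [torch.randn(HIDDEN, VOCAB, generator=g) for _ in range(NUM_HEADS)]
    return emb, heads


def run_speculator(last_token: int, emb, heads, gen: torch.Generator) -> list[int]:
    """Sample one token per head; each head conditions on the previous sample."""
    tokens = []
    prev = last_token
    for head in heads:
        logits = emb[prev] @ head  # full-vocab logits, assumed identical on every rank
        probs = torch.softmax(logits, dim=-1)
        # Only the dedicated generator is used, so every rank draws the same token.
        prev = int(torch.multinomial(probs, 1, generator=gen).item())
        tokens.append(prev)
    return tokens


def simulate_tp_ranks(num_ranks: int, sampling_seed: int, last_token: int = 5) -> list[list[int]]:
    """Run the speculator loop once per simulated rank and collect the drafts."""
    emb, heads = build_toy_heads()
    drafts = []
    for rank in range(num_ranks):
        # Rank-dependent use of the global RNG does not touch the dedicated generator.
        torch.rand(rank + 1)
        gen = torch.Generator().manual_seed(sampling_seed)
        drafts.append(run_speculator(last_token, emb, heads, gen))
    return drafts


if __name__ == "__main__":
    drafts = simulate_tp_ranks(num_ranks=4, sampling_seed=1234)
    print(drafts)  # every rank's draft is the same list of token ids
```

A real TP>1 speculator would additionally need the generator seeded identically on every rank and, on GPUs, sampling kernels that are deterministic for identical inputs; both are assumptions this sketch does not cover.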
@JRosenkranz already used `VocabParallelEmbedding` in the implementation, so the model layers themselves should work fine.
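For context on why the layers should already work under TP, here is a simplified single-process sketch of the vocab-parallel embedding technique behind `VocabParallelEmbedding`: each rank owns a contiguous slice of the vocabulary rows, looks up only the ids in its slice, zeroes the rest, and a sum all-reduce combines the partial results. The function names are hypothetical, the all-reduce is simulated with a plain sum over ranks, and vLLM's actual class handles additional details that this sketch omits.

```python
import torch


def shard_bounds(vocab_size: int, world_size: int, rank: int) -> tuple[int, int]:
    """Contiguous [start, end) range of vocabulary rows owned by `rank`."""
    per_rank = (vocab_size + world_size - 1) // world_size  # ceil division
    start = min(rank * per_rank, vocab_size)
    end = min(start + per_rank, vocab_size)
    return start, end


def vocab_parallel_lookup(ids: torch.Tensor, full_weight: torch.Tensor, world_size: int) -> torch.Tensor:
    """Simulate a vocab-parallel embedding lookup across `world_size` ranks."""
    partials = []
    for rank in range(world_size):
        start, end = shard_bounds(full_weight.shape[0], world_size, rank)
        if end <= start:
            continue  # this rank owns no rows
        shard = full_weight[start:end]  # the only rows this rank would store
        in_range = (ids >= start) & (ids < end)
        # Map owned ids to local row indices; send other ids to row 0 and mask them out.
        local_ids = torch.where(in_range, ids - start, torch.zeros_like(ids))
        out = shard[local_ids] * in_range.unsqueeze(-1).to(shard.dtype)
        partials.append(out)
    # Stand-in for all_reduce(SUM): exactly one rank contributes each non-zero row.
    return torch.stack(partials).sum(dim=0)
```

Since only the vocabulary dimension is sharded, each rank's lookup is purely local and a single all-reduce restores the full embedding for every token.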