
[Feature]: MLPSpeculator Tensor Parallel support #5809

@njhill

Description


🚀 The feature, motivation and pitch

MLPSpeculator-based speculative decoding was recently added in #4947, but the initial integration only covers single GPU usage.

There will soon be "speculator" models available for larger target models that require multiple GPUs, so we would like to ensure that TP can be used.

The first part of this issue would be testing it out in conjunction with #5414 and making necessary adjustments so that it will work with TP=1 for the speculator and TP=N for the target model.

Following this we can look at having the speculator itself run with TP>1, but that may be more involved since it will require some distributed coordination of the sampling of each speculated token in the MLPSpeculator loop. It might be possible to avoid additional communication here by having the sampler used by the speculator model use a dedicated torch.Generator for its sampling, and performing this sampling in tandem across the ranks.
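To illustrate the tandem-sampling idea, here is a minimal pure-Python sketch (using `random.Random` as a stand-in for a per-rank `torch.Generator`; the function name and shapes are hypothetical, not vLLM API): if every rank seeds its dedicated generator identically and draws in the same order, each rank samples the same speculated token at every step, so no cross-rank communication is needed.

```python
import random

def speculate_in_tandem(rank_count, probs_per_step, seed=1234):
    """Simulate each TP rank sampling speculated tokens with its own
    identically-seeded generator. Because every rank draws the same
    sequence of random numbers, the sampled tokens agree across ranks
    without any synchronization."""
    # One dedicated generator per rank, all seeded identically
    # (standing in for a dedicated torch.Generator in the real sampler).
    gens = [random.Random(seed) for _ in range(rank_count)]
    per_rank_tokens = [[] for _ in range(rank_count)]
    for probs in probs_per_step:  # one distribution per speculated token
        for rank, gen in enumerate(gens):
            # Sample a token id from the categorical distribution;
            # identical seeds and draw order -> identical results.
            token = gen.choices(range(len(probs)), weights=probs, k=1)[0]
            per_rank_tokens[rank].append(token)
    return per_rank_tokens

# Three ranks, two speculation steps over a 4-token vocabulary.
tokens = speculate_in_tandem(3, [[0.1, 0.2, 0.3, 0.4], [0.7, 0.1, 0.1, 0.1]])
# Every rank produced the same speculated tokens, so no sync is required.
assert all(t == tokens[0] for t in tokens)
```

The real sampler would of course draw from the speculator's logits on each rank; the key property is only that the generators stay in lockstep.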

@JRosenkranz already used VocabParallelEmbedding in the implementation so the model layers themselves should work fine.

cc @cadedaniel @sirejdua @JRosenkranz @tdoublep
