
Conversation

@pierrestock (Contributor) commented Dec 11, 2023

Adding support for the mistralai/Mixtral-8x7B-v0.1 and mistralai/Mixtral-8x7B-Instruct-v0.1 models as described in our blog post.

This is joint work between @zhuohan123 and @WoosukKwon from the vLLM project, and Mistral AI.

It integrates fast sparse mixture-of-experts kernels from the MegaBlocks project.
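For readers landing here, a minimal usage sketch of how the new models can be loaded through vLLM's offline Python API once this branch is installed; the tensor-parallel and sampling settings below are illustrative, not prescribed by this PR:

```python
# Minimal usage sketch (illustrative settings, not part of this PR):
# load Mixtral with vLLM's offline API and generate from a prompt.
from vllm import LLM, SamplingParams

llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    tensor_parallel_size=2,  # adjust to the number of GPUs available
)
sampling_params = SamplingParams(temperature=0.7, max_tokens=256)

outputs = llm.generate(
    ["[INST] Explain a sparse mixture of experts in two sentences. [/INST]"],
    sampling_params,
)
print(outputs[0].outputs[0].text)
```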

@zhuohan123 (Member) left a comment

LGTM! Thank you for your contribution and the official support of vLLM from Mistral AI!!

@zhuohan123 merged commit b5f882c into vllm-project:main Dec 11, 2023
    hidden_states: Optional[torch.Tensor],
    sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
    hidden_states = self.norm(hidden_states)
Collaborator comment on the diff above:

nit: Can we do this in forward?
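To make the nit concrete, here is a toy sketch (a hypothetical module, not the vLLM Mixtral code) of the difference: applying the final norm at the end of forward() means sample() no longer has to normalize the hidden states itself.

```python
# Toy illustration of the suggestion (hypothetical module, not vLLM code):
# do the final normalization inside forward() rather than in sample().
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self, hidden: int = 8, vocab: int = 16):
        super().__init__()
        self.layer = nn.Linear(hidden, hidden)
        self.norm = nn.LayerNorm(hidden)   # stand-in for the final RMSNorm
        self.lm_head = nn.Linear(hidden, vocab)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize here, so every consumer of forward() sees normalized states.
        return self.norm(self.layer(x))

    def sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # No self.norm() call needed here anymore.
        return self.lm_head(hidden_states).argmax(dim=-1)

model = TinyModel()
print(model.sample(model(torch.randn(2, 8))))
```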

@zhuohan123 (Member) replied:

Yes, I merged the PR ASAP from my phone to let people use the model. I think one other small thing is to add Mixtral to the supported model list. I am AFK now, can you help fix this if possible?

@draganjovanovich commented:

I installed from the latest main, installed stk, megablocks, the latest flash_attn, transformers, etc., and got the following error:

  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/vllm/engine/ray_utils.py", line 32, in execute_method
    return executor(*args, **kwargs)
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/vllm/worker/worker.py", line 88, in profile_num_available_blocks
    self.model_runner.profile_run()
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 321, in profile_run
    self.execute_model(seqs, kv_caches)
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 279, in execute_model
    hidden_states = self.model(
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/vllm/model_executor/models/mixtral.py", line 488, in forward
    hidden_states = layer(
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/vllm/model_executor/models/mixtral.py", line 439, in forward
    r = self.block_sparse_moe(self.ffn_norm(h))
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/vllm/model_executor/models/mixtral.py", line 353, in forward
    x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins,
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/ubuntu/mambaforge/lib/python3.10/site-packages/stk/backend/autocast.py", line 28, in decorate_fwd
    return fwd(*args, **kwargs)
TypeError: PaddedGatherOp.forward() takes 6 positional arguments but 7 were given

@WoosukKwon (Collaborator) commented:

@draganjovanovich Thanks for reporting the error! Please re-install megablocks with pip install git+https://github.com/stanford-futuredata/[email protected].

@draganjovanovich commented Dec 11, 2023

Np, I tried installing git+https://github.com/stanford-futuredata/[email protected], but it fails in the build step.

/tmp/pip-install-1xstsqs1/grouped-gemm_8825c4917c174628acfa2bd08ed8c9d2/third_party/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp(587): error: identifier "cuTensorMapEncodeTiled" is undefined
          CUresult result = cuTensorMapEncodeTiled(

Failed to build grouped_gemm

Now I have created a new env and am starting from scratch. I will comment if it succeeds.

@tgale96 commented Dec 11, 2023

Hi, it looks like the errors you're getting are from MegaBlocks. Can you share more details on your environment? The latest issue looks like it may be caused by the CUDA toolkit version you're using (maybe let's start a separate issue?).

@Ededu1984 commented:

I tried to install MegaBlocks and got this error:

CalledProcessError: Command 'pip --disable-pip-version-check install git+https://github.com/stanford-futuredata/[email protected]' returned non-zero exit status 1.

@tgale96 commented Dec 11, 2023

Hi Edson, you should be able to run Mixtral now by just installing MegaBlocks with pip install megablocks. Could you give that a try?

See related issues:
#2017
#2022

hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
Co-authored-by: Pierre Stock <[email protected]>
Co-authored-by: Zhuohan Li <[email protected]>