
Mixtral GPTQ with TP=2 not generating output #2728

Closed · SebastianBodza opened this issue Feb 2, 2024 · 17 comments · Fixed by #2760

Comments

@SebastianBodza

In the new vLLM 0.3 release, Mixtral with GPTQ no longer generates any output. Loading the model works fine, but the process gets stuck when calling llm.generate.

import vllm
from vllm import SamplingParams
import torch

model_name = "TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ"
llm = vllm.LLM(model=model_name, quantization="gptq", dtype=torch.float16,
               tensor_parallel_size=2, max_model_len=16000,
               revision="gptq-4bit-32g-actorder_True",
               gpu_memory_utilization=0.75, enforce_eager=False)
# formatted_prompt is defined elsewhere in the original script
llm.generate(formatted_prompt[0], sampling_params=SamplingParams(temperature=0.1, max_tokens=100))

Currently using:

  • python 3.11.*
  • vllm 0.3.0+cu123 (also tried the current git repo)
  • cuda 12.3

The worker seems to be stuck in the llm_engine.step() function / _run_workers call.

@hanzhi713
Contributor

Can you try adding disable_custom_all_reduce=True? Just ruling out a possibility.

@adi

adi commented Feb 3, 2024

I have the same issue with AWQ and this config:

llm = LLM(
    model="TheBloke/Mixtral-8x7B-Instruct-v0.1-AWQ",
    tensor_parallel_size=2,
    gpu_memory_utilization=0.75,
    max_model_len=256,
    disable_custom_all_reduce=True,
    enforce_eager=True
)

@adi

adi commented Feb 3, 2024

This works fine for me:

llm = LLM(
    model="TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ",
    quantization="gptq",
    dtype=torch.float16,
    tensor_parallel_size=2,
    max_model_len=16384,
    revision="gptq-4bit-32g-actorder_True",
    gpu_memory_utilization=0.75,
    disable_custom_all_reduce=True,
    enforce_eager=True)

EDIT: Not all generations return even in this configuration. Example:

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is"
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)
outputs = llm.generate(prompts, sampling_params)  # generate call not shown in the original comment
Prompt: 'Hello, my name is', Generated text: "a, and I'm new to this community.\n\nI'm here to ask a few questions about the game, and hopefully, find an answer to them.\n\nI've read that there is a way to switch between the 1st and 3rd person view, but I can't seem to find it. Is there such a feature, or am I missing something?\n\nAlso, how do I switch the way I walk? I see that I can either run or walk really slow, but I don't know how to do a regular walking speed. Is this another feature I missed, or is this something that's missing from the game?\n\nSorry if these are dumb questions, and thank you in advance for your time.\n\nEDIT: I've also noticed that some of the menu options are in a different language than my system's default, so I assume that it's a Japanese version of the game. Does anyone know if there is a way to change the language to English?"
Prompt: 'The president of the United States is', Generated text: ' making statements that are true, but this is not enough to make them convincing. On the subject of the FBI investigation into Russian interference in the 2016 presidential election, for instance, President Donald Trump said Wednesday that "we\'ll see what happens" and added, "I wish him luck."\n\nWhat Trump said is true, but it is not convincing because he is not behaving as if he believes what he is saying. Trump made the statements at a joint news conference with the prime minister of Kuwait, while the FBI investigation is led by his own appointee, Deputy Attorney General Rod Rosenstein, and is being conducted by a career prosecutor, Special Counsel Robert Mueller.\n\nTrump added that he has "nothing to hide." But he is behaving as if he has something to hide.\n\nThe president\'s critics have accused him of "obstructing justice" because he has fired people who were leading investigations of his campaign. The critics say Trump\'s firing of FBI Director James Comey, who is leading the investigation, was "obstruction of justice" because the president did so after Comey told Congress that he was investigating possible collusion between the president\'s campaign'
Prompt: 'The capital of France is', Generated text: ''
Prompt: 'The future of AI is', Generated text: 'haky/Getty Images\n\n, and it will be with us for a long time.\n\nBut it will be a future that we mold and shape, because what we do with AI is up to us. The choices that we make are informed by our thoughts and beliefs, and that means that if we choose to make AI more humane, then it can be.\n\nIn this post, I want to explore five AI trends that are making AI more humane.\n\n1. Diversity\n\nAn AI system that reflects our diversity is an AI system that reflects our humanity.\n\nThere are many things that we can do to encourage more diversity in AI. We can ensure that our AI teams are diverse, and that means that we have to work hard to increase diversity in the STEM fields. We can also create AI systems that recognize and respect differences in gender, ethnicity, sexual orientation, and ability.\n\nBut diversity in AI isn’t just about people. It’s also about the data that we use to train our AI systems. Diverse data will result in AI systems that are more humane.\n\nDiverse data is representative data. It ensures that an AI system will work well for a wide'

@ivysdad

ivysdad commented Feb 4, 2024

> I have the same issue with AWQ and this config:
>
> llm = LLM(
>     model="TheBloke/Mixtral-8x7B-Instruct-v0.1-AWQ",
>     tensor_parallel_size=2,
>     gpu_memory_utilization=0.75,
>     max_model_len=256,
>     disable_custom_all_reduce=True,
>     enforce_eager=True
> )

I think there is an issue with that AWQ quant specifically; try using this one: https://huggingface.co/casperhansen/mixtral-instruct-awq

It worked for me! Just be sure to set enforce_eager=True like you did before.

@SebastianBodza
Author

@hanzhi713 Thanks! It works with the disable_custom_all_reduce option.

@hanzhi713
Contributor

@SebastianBodza Which GPU model are you using?

@SebastianBodza
Author

@hanzhi713 2x RTX 3090 with CUDA 12.3

@hanzhi713
Contributor

@SebastianBodza Can you try this potential fix? #2760

Remember to recompile vLLM from source with pip install -e .

@SebastianBodza
Author

SebastianBodza commented Feb 5, 2024

Still not working for me. The GPUs are stuck at 150 W and never return. I don't know if it is relevant, but I have to run with NCCL_P2P_DISABLE=1 to load the model.

@SinanAkkoyun

> I think there is an issue with that AWQ quant specifically; try using this one: https://huggingface.co/casperhansen/mixtral-instruct-awq
>
> It worked for me! Just be sure to set enforce_eager=True like you did before.

For me, this also resolved the issue

@hanzhi713
Contributor

> Still not working for me. The GPUs are stuck at 150 W and never return. I don't know if it is relevant, but I have to run with NCCL_P2P_DISABLE=1 to load the model.

@SebastianBodza What error do you observe when you don't set NCCL_P2P_DISABLE=1? This might be relevant: if NCCL can't use P2P, custom all-reduce shouldn't use it either, and custom all-reduce can't function without P2P.

@SebastianBodza
Author

When NCCL_P2P_DISABLE=1 is not set, model loading freezes, just like in #1801.

@hanzhi713
Contributor

@SebastianBodza I see. Then it's expected that it will also freeze with custom all-reduce enabled. The underlying cause might be the same as in #1801, where the driver is buggy.

@SebastianBodza
Author

@hanzhi713 Thanks for the fixes. I am not too sure they were necessary; however, SinanAkkoyun seems to approve them, sorry for that. Could we add a check for P2P and either fall back to disable_custom_all_reduce or throw an error?

For me, vLLM 0.3.0 works with NVIDIA driver 535.154.05.
Anyone relying on the official NVIDIA stream for RHEL 9 should just downgrade.

@SinanAkkoyun

@SebastianBodza Hi, I didn't approve any changes; I just confirmed that the other AWQ model provided, plus enforce_eager=True, works for me. Does it work for you?

@hanzhi713
Contributor

> @hanzhi713 Thanks for the fixes. I am not too sure they were necessary; however, SinanAkkoyun seems to approve them, sorry for that. Could we add a check for P2P and either fall back to disable_custom_all_reduce or throw an error?
>
> For me, vLLM 0.3.0 works with NVIDIA driver 535.154.05. Anyone relying on the official NVIDIA stream for RHEL 9 should just downgrade.

@SebastianBodza Unfortunately, I don't think there's a good way to detect that. We already check P2P support via the CUDA runtime API (NCCL likely uses the same check), and that check passes. Basically, the problem is that the runtime reports P2P support, but the underlying implementation is buggy.
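
For context, the runtime-level check described above amounts to something like the following minimal sketch (not vLLM's actual code; the function name is hypothetical). It asks the CUDA runtime whether each device can access the other's memory, a query that can return True even when the driver's P2P path is broken:

import torch

def static_p2p_check(dev_a: int = 0, dev_b: int = 1) -> bool:
    """Return True if the CUDA runtime claims P2P works in both directions."""
    # torch.cuda.can_device_access_peer wraps the CUDA runtime query; on buggy
    # drivers it may report True even though real peer-to-peer copies misbehave.
    return (
        torch.cuda.can_device_access_peer(dev_a, dev_b)
        and torch.cuda.can_device_access_peer(dev_b, dev_a)
    )

if __name__ == "__main__":
    print("Runtime reports P2P support:", static_p2p_check())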

@andrewgross

andrewgross commented Feb 13, 2024

I ran into this with exllamav2 as well; the maintainer added a fix that does a quick check to see whether it is safe to move tensors between devices, and uses that to decide whether to move them directly or via the CPU. The NVIDIA drivers on 4090s seem to have consistent issues reporting this info properly. A sketch of that kind of check follows the links below.

https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10

turboderp-org/exllamav2#85
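
For illustration, here is a minimal sketch in the spirit of that compat check (not a copy of exllamav2's code; the function name is illustrative). It performs a real device-to-device copy and verifies the data survives the round trip, which catches drivers that report P2P support but corrupt transfers:

import torch

def functional_p2p_check(dev_a: int = 0, dev_b: int = 1) -> bool:
    """Copy a small tensor dev_a -> dev_b -> dev_a and compare with the original."""
    src = torch.arange(64, device=f"cuda:{dev_a}", dtype=torch.float32)
    moved = src.to(f"cuda:{dev_b}")   # direct GPU-to-GPU copy
    back = moved.to(f"cuda:{dev_a}")  # and back again
    torch.cuda.synchronize(dev_a)
    torch.cuda.synchronize(dev_b)
    return bool(torch.equal(src, back))

if __name__ == "__main__":
    print("Functional P2P copy OK:", functional_p2p_check())

If the round trip fails, a caller could fall back to routing tensors through the CPU or to disabling custom all-reduce, as discussed above.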
