Skip to content

Conversation

ElizaWszola
Copy link
Contributor

@ElizaWszola ElizaWszola commented Oct 28, 2024

This PR enables HQQ support (float zero points) in float16 kernels. This PR supports 4-bit quantization and group_size 64.

unit tests:

pytest tests/kernels/test_marlin_gemm.py -k test_hqq_marlin_gemm

offline inference model:

nm-testing/Llama-3.2-1B-Instruct-HQQ

or

from vllm import LLM
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

model_path = "unsloth/Llama-3.2-1B-Instruct"
quant_config = HqqConfig(nbits=4, group_size=64, axis=1)

model = AutoModelForCausalLM.from_pretrained(model_path,
                                             torch_dtype=torch.float16,
                                             cache_dir='.',
                                             device_map="cuda:0",
                                             quantization_config=quant_config,
                                             low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

qp = "llama-3.2-1b-instruct_hqq"
model.save_pretrained(qp)
tokenizer.save_pretrained(qp)

llm = LLM(model=qp)

Signed-off-by: ElizaWszola <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@ElizaWszola ElizaWszola changed the title HQQ support [Model] HQQ support Oct 30, 2024
marlin_tile_size=self.marlin_tile_size)


class HQQQweightParameter(PackedvLLMParameter):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dsikka @mgoin @LucasWilkinson what do you guys think of inheriting from the PackedParameter to handle the details of HQQ? I think this is a reasonable approach.

Perhaps this should live inside hqq_marlin.py

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ElizaWszola would like to understand how we're extending this class to see if we can clean-up some of the weight loading logic.

Copy link
Contributor Author

@ElizaWszola ElizaWszola Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dsikka HQQ qweights have a data format where each row corresponds to 2 groups. I'm first reshaping HQQ qweights (and correspoinding scales/zp) to their actual shapes in qweight multiplication, and then load. Additionally, I store shard offsets and sizes because they are needed for unpacking HQQ from 4-bits to 8-bits (I need to unpack shard-by-shard, so I must know where they are). I'll add a comment on this in the code.

Copy link
Collaborator

@robertgshaw2-redhat robertgshaw2-redhat left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Eliza! Nice job avoiding any changes to llama.py. We have much better encapsulated the HQQ logic.

Questions:

  • do we have any end-to-end perf benchmarks?
  • do we have any lm-eval runs for correctness?
  • can you add a test for weight loading? @dsikka can you point out an example

The only other piece is that we have not adapted the Kernel abstraction proposed by @LucasWilkinson. @LucasWilkinson do you think this would be appropriate?

See https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/gptq_marlin.py for an example of the Kernel abstraction

@ElizaWszola
Copy link
Contributor Author

ElizaWszola commented Oct 31, 2024

  • do we have any end-to-end perf benchmarks?
  • do we have any lm-eval runs for correctness?
  • can you add a test for weight loading?

@robertgshaw2-neuralmagic I've run some lm-eval: I'll add the relevant configs as soon as it's verified that the results look good. I can also add other mentioned tests.

Copy link
Contributor

@dsikka dsikka left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

touching base offline about parameter testing



# Zero points and scales in HQQ must also be reshaped to their actual shapes.
class HQQZeroScaleParameter(GroupQuantScaleParameter):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like we just need this class to reshape the weight before weight loading? We could just add an optional property to do this in the parent classes so that we dont have to maintain another parameter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any other parameter in the code that already gets reshaped like this?

@ElizaWszola
Copy link
Contributor Author

/ready

@robertgshaw2-redhat robertgshaw2-redhat added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 6, 2024
@mgoin mgoin changed the title [Model] HQQ support [Model][Quantization] HQQ support through Marlin kernel expansion Nov 14, 2024
Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically looks good to me, although good for @tlrmchlsmth to go through as well

Copy link
Member

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ElizaWszola When you merge latest main, could you sanity check the wheel size before landing?

Copy link
Member

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm -- just need to address nits and merge and then I think it's g2g

Copy link

mergify bot commented Nov 18, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ElizaWszola.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 18, 2024
@mergify mergify bot removed the needs-rebase label Nov 19, 2024
@tlrmchlsmth
Copy link
Member

failures look related to #10456

Copy link
Member

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me!

@tlrmchlsmth
Copy link
Member

@ElizaWszola When you merge latest main, could you sanity check the wheel size before landing?

BTW just checked this, and check-wheel size goes from 189.25 MB -> 190.04 MB

@simon-mo simon-mo merged commit b00b33d into vllm-project:main Nov 19, 2024
69 of 71 checks passed
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
@JaheimLee
Copy link

JaheimLee commented Jan 14, 2025

Hi, I use hqq model with lora. But it will raise ValueError(f"Unsupported base layer: {base_layer}") from here. Why isn't 'B' in base_layer even though it was expanded from marlin? Is it a bug? @ElizaWszola @tlrmchlsmth @mgoin

@mgoin
Copy link
Member

mgoin commented Jan 14, 2025

I'm not sure how that function works, @varun-sundar-rabindranath @jeejeelee would you have an idea?

@varun-sundar-rabindranath
Copy link
Contributor

looks like the weights are not registered as B in HQQ Marlin ?
hqq_marlin.py

layer.register_parameter("W_q", qweight)

marlin.py
layer.register_parameter("B", qweight)

@varun-sundar-rabindranath
Copy link
Contributor

I have a PR for this #12090
@JaheimLee I tested that it works using a random but compatible LoRA adapter I found in HF.
Can you share the model, adapter, and command that failed for you ? I can test with that also.
@mgoin fyi

anko-intel pushed a commit to HabanaAI/vllm-fork that referenced this pull request Feb 10, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants