
Conversation

@HandH1998 (Contributor) commented Jun 3, 2024

We have proposed QQQ, a W4A8 quantization solution, and integrated it into vLLM. QQQ not only achieves model accuracy comparable to the leading W4A8, W8A8, and W4A16 quantization methods but also significantly accelerates inference, delivering up to 2.24x, 2.10x, and 1.25x speedups over FP16, W8A8, and W4A16 (Marlin), respectively.

News or Update

  • [2024/06/17] Update!!! We release the QQQ paper on arXiv.
  • [2024/06/07] We fuse dynamic activation quantization into QQQLinearMethod to support QQQ for all models. Note that this reduces inference performance somewhat. You should use this repo to reproduce the inference speed results of our paper.
  • [2024/06/03] We integrate QQQ into vLLM and release the code.

Usage

You can export the quantized model weights with the QQQ repo (only LLaMA is supported for now). Our paper will be published on arXiv soon. Here is an offline inference example.

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is a",
    "A pig",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
model = "your_quantized_model_path"      # placeholder: path to the QQQ-quantized model
tokenizer = "your_tokenizer_path"        # placeholder: path to the matching tokenizer

# Create an LLM.
llm = LLM(
    model=model,
    tokenizer=tokenizer,
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Experiments

Here we provide our experimental results obtained with this repo.

Settings

  • vLLM v0.4.1
  • 1x A100 80GB
  • CUDA 11.8

Model Performance

We evaluated the model performance on WikiText2 and five zero-shot tasks.
[Figure: model performance on WikiText2 and five zero-shot tasks]

Throughput

We compared the same-batch throughput of quantized LLaMA-2 models under various batch sizes, with an input sequence length of 1024 and an output sequence length of 128.
[Figure: throughput speedup of quantized LLaMA-2 models under various batch sizes]
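
For reference, a rough sketch of how such a same-batch throughput measurement can be run with vLLM (this is not the exact benchmark script; the model path, batch size, and the crude ~1024-token prompt are placeholders):

import time
from vllm import LLM, SamplingParams

batch_size = 16
prompt = "The quick brown fox jumps over the lazy dog. " * 110   # roughly 1024 tokens
# Force exactly 128 generated tokens per sequence.
sampling_params = SamplingParams(temperature=0.0, max_tokens=128, ignore_eos=True)

llm = LLM(model="your_quantized_model_path")
start = time.time()
outputs = llm.generate([prompt] * batch_size, sampling_params)
elapsed = time.time() - start
print(f"generation throughput: {batch_size * 128 / elapsed:.1f} tokens/s")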

W4A8 GEMM Performance

We implement the W4A8 GEMM based on the Marlin GEMM. Thanks for their great work! Here is the speedup over the PyTorch FP16 GEMM (which calls CUTLASS) for all GEMMs under different numbers of input tokens. The weight matrix size is (N=8192, K=21760).
[Figure: W4A8 GEMM speedup over FP16 GEMM vs. number of input tokens]
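
As a point of reference, here is a small sketch (assuming an A100 with CUDA available; not our benchmark code) of how the FP16 baseline timing for a given number of input tokens M can be measured against the weight shape quoted above:

import time
import torch

M, N, K = 16, 8192, 21760     # M = number of input tokens; weight shape as above
a = torch.randn(M, K, dtype=torch.float16, device="cuda")
b = torch.randn(K, N, dtype=torch.float16, device="cuda")

for _ in range(10):           # warm-up
    torch.matmul(a, b)
torch.cuda.synchronize()

start = time.time()
for _ in range(100):
    torch.matmul(a, b)
torch.cuda.synchronize()
print(f"FP16 GEMM: {(time.time() - start) / 100 * 1e3:.3f} ms")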

@robertgshaw2-redhat self-assigned this Jun 3, 2024
@robertgshaw2-redhat (Collaborator)

Thanks for the PR!

The team at NM will review this. We are going to have to split this work out into a few pieces.

Will get back to you

@robertgshaw2-redhat (Collaborator)

@HandH1998 I synced up with @comaniac

This PR does 2 things:

  1. Introduces W4A8 GEMM for QQQLinearMethod
  2. Introduces the concept of A8 compute for layers that are not "Linear"

For (2), we need a more holistic plan to support this concept, and NM and Anyscale are already working on it. We do not want to just hack up one model file to support this, as it would become very unmaintainable.

So, in terms of path forward for this PR, please remove the A8 compute from LayerNorm, Attention, and SiluAndMul and we should not make any material changes to llama.py.

This will allow QQQ to land quickly with the Linear running at A8.

@alexm-neuralmagic will review the kernels when he has some time this week.

We can then incorporate the other layers as part of the broader effort we have going on in a separate PR

@zhyncs (Contributor) commented Jun 4, 2024

remove the A8 compute from LayerNorm, Attention, and SiluAndMul

@robertgshaw2-neuralmagic This will reduce some of the performance gains. Could you share a more general approach and timeline? Perhaps our team can collaborate to implement the corresponding features in this PR, and then you can turn it into a general version later.
The existing W4A8 solutions in the industry do not meet our accuracy requirements, which is why we proposed this PR. To maximize performance while preserving accuracy, we have made every effort to make this quantization truly usable in an online production environment, not just for publishing papers.

@mgoin (Member) commented Jun 4, 2024

@zhyncs The goal of making this W4A8 optimization "production-ready" is exactly why I also think it is a good idea to land the first step as simply having this only be enabled for Linear modules as QQQLinearMethod. With this, we can enable the method for all models, rather than just supporting Llama.

We have an RFC tracking Int8 W8A8 that lays out some of the high-level goals, including graph fusion through torch.compile #3975. We have been landing many kernels recently and hope to share the torch.compile prototype soon!

@HandH1998 (Contributor, Author)

@zhyncs The goal of making this W4A8 optimization "production-ready" is exactly why I also think it is a good idea to land the first step as simply having this only be enabled for Linear modules as QQQLinearMethod. With this, we can enable the method for all models, rather than just supporting Llama.

We have an RFC tracking Int8 W8A8 that lays out some of the high-level goals, including graph fusion through torch.compile #3975. We have been landing many kernels recently and hope to share the torch.compile prototype soon!

I see what you want to do. I hope you can finish supporting dynamic activation quantization successfully. I will try to split the activation quantization out of LayerNorm, Attention, and SiluAndMul and put it in QQQLinearMethod. Note that this will reduce inference performance somewhat.

@robertgshaw2-redhat (Collaborator)

@zhyncs The goal of making this W4A8 optimization "production-ready" is exactly why I also think it is a good idea to land the first step as simply having this only be enabled for Linear modules as QQQLinearMethod. With this, we can enable the method for all models, rather than just supporting Llama.
We have an RFC tracking Int8 W8A8 that lays out some of the high-level goals, including graph fusion through torch.compile #3975. We have been landing many kernels recently and hope to share the torch.compile prototype soon!

I see what you want to do. I hope you can finish supporting dynamic activation quantization successfully. I will try to split the activation quantization out of LayerNorm, Attention, and SiluAndMul and put it in QQQLinearMethod. Note that this will reduce inference performance somewhat.

SG - then we can add in the quantized execution of non-Linear layers as part of our broader project

@HandH1998 (Contributor, Author)

@robertgshaw2-neuralmagic We have finished the mentioned work. This may help you achieve your high-level goals.

@robertgshaw2-redhat (Collaborator)

Thank you! This looks much cleaner

@robertgshaw2-redhat (Collaborator)

@HandH1998 - where are these kernels from? Is there an open source library that we are tracking for these?

I know these are adapted from Marlin so I think we could ramp up on them quickly, but I wanted to understand if this is something that we [Neural Magic + rest of the vLLM team] need to maintain or if we are just simply tracking a remote repo

@zhyncs (Contributor) commented Jun 7, 2024

It's adapted from Marlin and modified as necessary.
https://github.com/HandH1998/vllm/blob/b4b677138732ac952be8dbc87e44a5a58dffeb73/csrc/quantization/qqq/qqq_gemm_kernel.cu#L2-L7

need to maintain

yep.

simply tracking a remote repo

There seems to be no need at the moment.

@robertgshaw2-redhat (Collaborator)

It's adapted from Marlin and modified as necessary. https://github.com/HandH1998/vllm/blob/b4b677138732ac952be8dbc87e44a5a58dffeb73/csrc/quantization/qqq/qqq_gemm_kernel.cu#L2-L7

need to maintain

yep.

simply tracking a remote repo

There seems to be no need at the moment.

I'm just trying to understand who is responsible for maintaining these kernels moving forward.

We currently have two models:

  • With Punica / FlashAttention / AWQ, we track a different repo maintained by a 3rd party (and pull the kernels into vLLM as needed)
  • With Marlin / CUTLASS / Triton / paged_attention, we maintain them ourselves

So I am trying to understand which case this is.

@zhyncs (Contributor) commented Jun 7, 2024

which case this is

I think it should be the latter.

On another note, slightly off topic: the performance of AWQ in vLLM is almost the worst among the mainstream frameworks. I also do not recommend continuing to rely on a 3rd party; rather, we should maintain our own set of kernels.

@robertgshaw2-redhat (Collaborator)

which case this is

I think it should be the latter.

On another note, slightly off topic: the performance of AWQ in vLLM is almost the worst among the mainstream frameworks. I also do not recommend continuing to rely on a 3rd party; rather, we should maintain our own set of kernels.

sg - we are reviewing

re: AWQ. I think we can run AWQ models with Marlin if we support zero points in Marlin. We would welcome a contribution

@alexm-redhat (Collaborator) left a comment

I reviewed the modified Marlin kernel. Great work, guys! I left some comments.

Collaborator

We have found in our tests that the "stream" PTX is crashing on H100. Is it possible for you to use "cp.async.cg" and "cp.async.ca" without any createpolicy or cache hints? You could also test it on H100 yourself and see if it crashes for you.

@HandH1998 (Contributor, Author) commented Jun 8, 2024

I will try to modify functions cp_async4_stream and cp_async1_stream as you suggest next week. I think it should be easy to do.

Collaborator

Thanks for documenting this! It is helpful

Collaborator

Is "s1" statically defined with the quantization of the model or it is computed dynamically during inference?

Contributor Author

The latter. It is the dynamic per-token quantization scale, which needs to be computed online.
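
For context, here is a minimal sketch of that dynamic per-token INT8 activation quantization (shapes and names are assumptions; it mirrors the dynamic_quant helper quoted later in this thread): one scale per token row, computed from the row-wise absmax at runtime.

import torch

# x: activations for a batch of tokens, shape [num_tokens, hidden_size] (assumed).
x = torch.randn(8, 4096, dtype=torch.float16)

# s1: dynamic per-token scale, computed online from the row-wise absmax.
s1 = x.abs().amax(dim=-1, keepdim=True).float() / 127.0
x_int8 = torch.clamp(torch.round(x.float() / s1), -128, 127).to(torch.int8)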

Collaborator

During execution, the code uses only FragS_GROUP or FragS_CHANNEL. You could use if constexpr to avoid always defining both, or are you relying on the compiler to eliminate them anyway?

Contributor Author

Actually, it only uses FragS_CHANNEL for per-channel dequantization, while it needs both FragS_GROUP and FragS_CHANNEL for per-group weight quantization. I will consider your suggestion.

Contributor Author

Hi, I am trying to define the register array frag_s3 inside an if constexpr block, but then it is not visible outside that scope. Do you know a way to solve this?

Collaborator

Why did you need to apply "dequant_per_channel" here and not on the final result (as is done in the original Marlin, before the write to C)?

Collaborator

I think I understand why you did it this way. You actually have a scale applied at the end on the output, and here you simply convert 4 bit to 8 bit by extracting bits. But maybe I'm missing something.

Contributor Author

You are right. Here dequant_per_channel just converts INT4 to INT8 by left shifting 4 bits.

Collaborator

It would be good to document here that global_reduce works on int32 elements (the original Marlin used half), and also to mention that this is why you needed to add the temporary buffer C (the original Marlin reduced directly into the output buffer).

Contributor Author

I will fix it later.

Collaborator

You can change this code to dynamically detect the L1 cache size (like we did in GPTQ Marlin and Marlin24). This will give you a better estimate of the L1 cache, since sometimes you will need more L1 due to scales or increased batching.

Contributor Author

Good suggestion! I will modify it.

Collaborator

Could you rename it to "marlin_qqq_cuda"? I think that better describes what is happening there.

Contributor Author

OK.

Collaborator

Same here: "marlin_qqq_gemm" is a better name in my opinion.

Contributor Author

OK.

@robertgshaw2-redhat (Collaborator) commented Jun 7, 2024

Another thing - what is the model serialization format? And how do we make models in this format?

@zhyncs (Contributor) commented Jun 7, 2024

Another thing - what is the model serialization format? And how do we make models in this format?

https://github.com/HandH1998/QQQ/blob/4ab12906bb144ca8977a077f7f191e218fc2e038/examples/quant_model.py#L76-L83

https://github.com/handh1998/qqq?tab=readme-ov-file#quantize-model

Due to the changes relative to the original pull request, the performance will differ slightly from the data in the image.

Collaborator

Isn't the mask supposed to be 0x0f0f0f0f? Why do you extract the higher 4 bits?

Contributor Author

For per-channel weight dequantization, I place the INT4 weight into the upper 4 bits of an INT8 by multiplying by 16, essentially performing a left shift by 4 bits. I use the mentioned mask together with line 624 to achieve this. Note that the 8 INT4 weights of one INT32 are shuffled offline to ensure that every thread of a warp fetches the correct weights.
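
As a small illustration of that statement (not the kernel code, and assuming signed, symmetric INT4 values as elsewhere in this thread): placing an INT4 value in the upper 4 bits of a byte is the same as multiplying it by 16, so the reinterpreted INT8 value is exactly 16x the INT4 value, and the extra factor of 16 can be folded into the per-channel scale.

import torch

# Signed INT4 values in [-8, 7] (assumed symmetric quantization).
int4_vals = torch.tensor([-8, -3, 0, 5, 7], dtype=torch.int8)

# Moving each value into the upper nibble of its byte (mask 0xF0 per byte)
# is a left shift by 4, i.e. a multiplication by 16.
upper_nibble = int4_vals << 4
assert torch.equal(upper_nibble, int4_vals * 16)

# The factor of 16 disappears once it is folded into the per-channel scale.
scale = 0.02
assert torch.allclose(upper_nibble.float() * (scale / 16), int4_vals.float() * scale)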

@zhyncs (Contributor) commented Jun 7, 2024

Hi @alexm-neuralmagic, these are the illustrations from the QQQ paper, drawn by @HandH1998, which we hope will help you review the code. The paper will be published on arXiv soon; please stay tuned.

@alexm-redhat (Collaborator)

@zhyncs thanks for the figure, this is really helpful

@HandH1998 (Contributor, Author)

So I am trying to understand which case this is

You guys can maintain qqq_gemm the same way you handle Marlin. And remember to keep the copyright information at https://github.com/HandH1998/vllm/blob/b4b677138732ac952be8dbc87e44a5a58dffeb73/csrc/quantization/qqq/qqq_gemm_kernel.cu#L2-L7. If I have other optimization ideas, I will let you know.

@HandH1998 (Contributor, Author)

@robertgshaw2-neuralmagic @alexm-neuralmagic We have fixed most of the issues you mentioned. For more details, please refer to our paper.

@HandH1998 (Contributor, Author)

@robertgshaw2-neuralmagic @alexm-neuralmagic We have rebased our code on vLLM v0.5.0 and hope you guys can review it.

@tlrmchlsmth (Member) left a comment

Could you add some unit tests for this kernel?

csrc/ops.h Outdated
Comment on lines 97 to 112
Member

Could you mark the operands that aren't modified as const? I know that torch::Tensor::data_ptr doesn't respect constness, so this will be a somewhat nonfunctional change, but it will make it easier to define torch metafunctions, which are necessary to support this function in torch.compile.

torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, torch::Tensor const& b_q_weight,
                              torch::Tensor& s1, torch::Tensor& s2,
                              torch::Tensor& s3, torch::Tensor& workspace,
                              int64_t size_m, int64_t size_n, int64_t size_k);

Comment on lines 229 to 260
Member

Are these exactly the same as the functions in marlin_cuda_kernel.cu? If so, could you factor these out and place them in a common location to be included by both?

Comment on lines 502 to 425
Member

nit: dependicies -> dependencies and maintining -> maintaining

I think codespell should catch this, so make sure you can run format.sh on this to get the build green :)

@HandH1998 (Contributor, Author)

@tlrmchlsmth We are adding the GEMM unit test and fixing the issues. We will release the code next Monday.

@HandH1998 force-pushed the w4a8 branch 2 times, most recently from f30daa6 to e28f2bb on June 24, 2024 05:40
@zhyncs (Contributor) commented Jun 24, 2024

Hi @robertgshaw2-neuralmagic @alexm-neuralmagic @tlrmchlsmth, could you help review the latest code? Thanks. If everything goes smoothly, after this PR is merged, our team (@HandH1998) plans to continue working on INT4-FP8 QQQ (W4A8) based on this.

@HandH1998 (Contributor, Author) commented Jul 24, 2024

@brisker Thanks for your report! I have confirmed it is a bug in the pack function: we missed the clamp operation for per-channel quantization, so when the quant scale is not calculated from the absmax, the diff appears. I have now fixed it in https://github.com/HandH1998/QQQ/commit/ae9531ce7a4d92080704df1e2b32f18a79600ae7. You can run your script again to verify it.
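
For readers following along, here is a minimal sketch (not the actual QQQ pack() code) of why the clamp matters: when the per-channel scale is derived from a clipped absmax (clip_ratio < 1) rather than the true absmax, some round(w / scale) values fall outside the INT4 range unless they are clamped.

import torch

torch.manual_seed(0)
w = torch.randn(4096, 2048)

q_max = 7                     # symmetric signed INT4
clip_ratio = 0.9              # scale NOT derived from the true absmax
scale = w.abs().amax(dim=-1, keepdim=True) * clip_ratio / q_max

q = torch.round(w / scale)
print((q.abs() > q_max).sum().item())   # > 0: these values overflow without a clamp
q = torch.clamp(q, -q_max, q_max)       # the fix: keep values in the INT4 range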

@brisker commented Jul 24, 2024

@HandH1998 After the bug fix, I tried it again, and it is OK now.

@HandH1998 (Contributor, Author)

@tlrmchlsmth Thanks for your review! I plan to resolve conflicts on Thursday. Then you can go ahead.

@brisker commented Jul 25, 2024

@HandH1998
Since many models are trained in BF16, is BF16 supported in this W4A8 PR?

I mean, although the GEMM is W4A8, the quantized weights must be converted from FP16 weights, so there may be accuracy risks, since many models are trained in BF16 and the quantization tuning (fake-quant) process may also have to be done in BF16.

https://github.com/HandH1998/QQQ/blob/main/QQQ/gptq/qlinear/qlinear_marlin.py#L181

@HandH1998 (Contributor, Author)

@brisker BF16 is not supported in this PR. We don't have enough time to support it for now, though we would like to put it on our list. If you are interested, we welcome a PR to our QQQ repo. This mainly involves changes to the offline quant and the per-group dequant kernel; the former should be easy to achieve.

@HandH1998 (Contributor, Author)

@tlrmchlsmth I have resolved conflicts.

@brisker commented Jul 29, 2024

@HandH1998

A few days ago I said it was OK, but that was for W4A8 with no grouping.

Today I tried W4A8 with group size 128 (gs128), and now assert(torch.allclose(out1, out2, atol=1e-3, rtol=1e-3)) fails. This time, both absmax and absmax*0.4999 fail.

I think there may be another bug.

The new get_scale, w4_quant, and qqq_linear init code is here:

qqq_linear = QQQ_Linear(
                4,
                group_size,
                2048,
                4096,
                False,
                weight_dtype=torch.float16
            )
def get_scale(x,group_size=-1):
    if group_size==-1:
        q_max = (2 ** (4-1)) - 1
        x_absmax = x.abs().amax(dim=-1, keepdim=True)*clip_ratio
        # if self.lwc:
        #     x_absmax = x_absmax * self.sigmoid(self.absbound_factor)

        scale = x_absmax / q_max
        return scale,None
    elif group_size==128:
        q_max = 2 ** (4) - 1
        zero_point = (q_max +1)/2
        reshaped_x = x.view(x.shape[0],int(x.shape[1]//group_size),group_size)
        xmax = reshaped_x.abs().amax(dim=-1, keepdim=True)*clip_ratio
        # if self.lwc:
        #     xmax = xmax * self.sigmoid(self.absbound_factor)
        xmin = -xmax
        scale = (xmax - xmin) / q_max
        # scale = scale.clamp(min=CLIPMIN, max=CLIPMAX)

        quant_value = torch.clamp(torch.round(reshaped_x / scale) + zero_point, 0, q_max)
        x_dequant =  scale * (quant_value - zero_point)
        x_dequant = x_dequant.view_as(x)


        ############ second step ##########
        second_step_q_max = 127  # int8
        absmax = x_dequant.abs().amax(dim=1, keepdim=True)
        scale_1 = absmax / second_step_q_max
        # import pdb;pdb.set_trace()
        # scale_1 = scale_1.clamp(min=CLIPMIN, max=CLIPMAX)
        return scale[:,:,0],scale_1

    else:
        raise RuntimeError
def w4_quant(x,group_size=-1):
    if group_size==-1:
        q_max = (2 ** (4-1)) - 1
        x_absmax = x.abs().amax(dim=-1, keepdim=True)*clip_ratio
        # x_absmax = x_absmax * 0.4999
        # if self.lwc:
        #     x_absmax = x_absmax * self.sigmoid(self.absbound_factor)
        scale = x_absmax / q_max
        # scale = scale.clamp(min=CLIPMIN, max=CLIPMAX)

        quant_value = torch.clamp(torch.round(x / scale), -q_max, q_max)

        x_dequant = quant_value * scale
        return x_dequant
    elif group_size==128:
        q_max = 2 ** (4) - 1
        zero_point = (q_max +1)/2
        reshaped_x = x.view(x.shape[0],int(x.shape[1]//group_size),group_size)
        # xmax = reshaped_x.amax(dim=-1, keepdim=True)
        # xmin = reshaped_x.amin(dim=-1,keepdim=True)  # bad direct w4a8-g128 acc, even worse than group=-1

        xmax = reshaped_x.abs().amax(dim=-1, keepdim=True)*clip_ratio
        # if self.lwc:
        #     xmax = xmax * self.sigmoid(self.absbound_factor)
        xmin = -xmax

        scale = (xmax - xmin) / q_max
        # scale = scale.clamp(min=CLIPMIN, max=CLIPMAX)

        quant_value = torch.clamp(torch.round(reshaped_x / scale) + zero_point, 0, q_max)
        x_dequant =  scale * (quant_value - zero_point)
        x_dequant = x_dequant.view_as(x)

        ###########


        ############ second step ##########
        second_step_q_max = 127  # int8
        absmax = x_dequant.abs().amax(dim=1, keepdim=True)
        scale_1 = absmax / second_step_q_max
        # scale_1 = scale_1.clamp(min=CLIPMIN, max=CLIPMAX)

        quant_value_1 = torch.clamp(torch.round(x_dequant / scale_1), -second_step_q_max, second_step_q_max)

        final_x_dequant = quant_value_1 * scale_1

        return final_x_dequant

    else:
        raise RuntimeError

@HandH1998 @zhyncs I got a W4A8 LLaMA-2-7B model with WikiText2 ppl = 5.7 when tested using fake-quant, but when I convert the weights into the QQQ W4A8 format and test again with the CUDA kernel, the ppl becomes 10000. So I ran the following code to test the consistency between fake-quant and the QQQ W4A8 kernel, and I found a difference that cannot be neglected.

(I'm sure the conversion is right, because before tuning the two ppls are close and both reasonable: 7.5 and 7.53.) The QuantLinear class is extracted from here and the dynamic_quant function is extracted from here.

Any advice on this? I think this may be a potential bug.

I think this issue is important because the carefully tuned accuracy lives in the fake-quant domain, but there is a gap between fake-quant and the CUDA kernel, which means the carefully tuned accuracy cannot be the real accuracy. We cannot tune W4A8 accuracy in the CUDA-kernel domain, right?

from copy import deepcopy
import torch
import torch.nn as nn
import sys 
import QuantLinear

def dynamic_quant(x: torch.Tensor):
    quant_scale = x.abs().max(dim=-1, keepdim=True)[0].div(127.0).to(torch.float)
    quant_x = (x / quant_scale).round().clamp(-128, 127).to(torch.int8)
    dequant_x = quant_x*quant_scale
    return quant_x, quant_scale,dequant_x

def get_scale(x):
    q_max = (2 ** (4-1)) - 1
    x_absmax = x.abs().amax(dim=-1, keepdim=True)

    scale = x_absmax / q_max

    return scale,None


def w4_quant(x):
    q_max = (2 ** (4-1)) - 1
    x_absmax = x.abs().amax(dim=-1, keepdim=True)

    scale = x_absmax / q_max


    quant_value = torch.clamp(torch.round(x / scale), -q_max, q_max)

    x_dequant = quant_value * scale

    return x_dequant
ori_fc = nn.Linear(2048,4096,bias=False).cuda().half()

qqq_linear = QuantLinear(
                4,
                -1,
                2048,
                4096,
                False,
                weight_dtype=torch.float16
            )
qqq_linear.cuda()

input = torch.Tensor(10,2048).normal_().cuda().half()
quant_input, quant_input_scale,dequant_input = dynamic_quant(input)


scale,scale_extra = get_scale(ori_fc.weight.data)
qqq_linear.pack(ori_fc, scale, scale_extra)
ori_fc.weight.data = w4_quant(ori_fc.weight.data)

with torch.no_grad():
    out1 = ori_fc(dequant_input.half())

with torch.no_grad():
    out2 = qqq_linear.forward(input)
  
print((out1==out2).sum(),out1.numel())

@tlrmchlsmth (Member)

@brisker this wouldn't be a problem with any of the code in this PR, right, but rather an issue in https://github.com/HandH1998/QQQ?

@tlrmchlsmth (Member)

BTW, I just tried rerunning the failing jobs

@mgoin added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) on Jul 29, 2024
@mgoin (Member) commented Jul 29, 2024

Marked the PR as ready and resolved the merge conflict; let's see if it is green now.

@HandH1998 (Contributor, Author)

@tlrmchlsmth @mgoin The failing job uses lm_eval to evaluate our model HandH1998/QQQ-Llama-3-8b-g128. I previously used the command bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 250 -f 5 -t 1 to get the evaluation result for .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml on my host. The result is 0.484 for exact_match,strict-match and 0.492 for exact_match,flexible-extract, which does not match the CI result (0.448 for exact_match,strict-match).
Today I changed my attention backend from xformers to vllm-flash-attn and got a new result: 0.424 for exact_match,strict-match and 0.428 for exact_match,flexible-extract. It seems that the attention backend has a non-negligible impact on the final evaluation result. In any case, the result on my host still does not align with the CI result.
I wonder if you guys can reproduce the CI result. Please share your results here. Thanks!

@mgoin (Member) commented Jul 30, 2024

Hi @HandH1998, I think you would see a more stable result if you increased the number of samples. Generally we use -l 1000 instead of 250.
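
For example, assuming the same setup as the command quoted above, that would be bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1.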

@HandH1998 (Contributor, Author)

@tlrmchlsmth @mgoin Hi, I have solved the CI issue by using -l 1000 instead of 250. However, new CI issues came up. I think the new issues are not related to our code. Could you please take a look?

@mgoin (Member) left a comment

Thanks for letting us know. The PP test is flaky, so I will attempt to land this.

@RanchiZhao commented Aug 19, 2024

So if I use vLLM for QQQ, what is the speedup ratio compared to the GPTQ-Marlin kernel in this PR?
I also wonder: does this PR's QQQ in vLLM only support a 16-bit KV cache? Here the author mentioned an FP8 KV cache?
