Conversation

@younesbelkada (Contributor) commented on Apr 24, 2024

What does this PR do?

This PR adds support for EETQ quantized linear layers in PEFT.
EETQ was recently added to transformers and offers fast 8-bit quantized inference: huggingface/transformers#30262

Fixes: #1643 (Support EETQ QLoRA)

Learn more about EETQ here: https://github.com/NetEase-FuXi/EETQ

TODO

  • Add EETQ to the Dockerfile
  • Add tests

cc @SunMarc @BenjaminBossan @pacman100
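
For context, here is a rough sketch of the workflow this PR enables: quantize a base model to 8-bit with EETQ via transformers, then attach LoRA adapters with PEFT. The checkpoint name, target modules, and LoRA hyperparameters below are illustrative assumptions, not taken from the PR.

```python
# Illustrative sketch of EETQ + LoRA, not the PR's own code.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, EetqConfig

# Quantize the base model to int8 with EETQ.
quantization_config = EetqConfig("int8")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # example checkpoint, chosen only for illustration
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Attach LoRA adapters on top of the EETQ-quantized linear layers.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # OPT attention projections; adjust per model
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```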

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@younesbelkada marked this pull request as ready for review on April 24, 2024, 09:41
@BenjaminBossan (Member) left a comment


Thanks a lot for adding EETQ support. This implementation looks really smooth, nice!

Could you please add an entry to the quantization docs of PEFT? Maybe mention some pros/cons of EETQ compared to the other quantization methods, or add a reference to where this is better explained (their README wasn't very helpful in this regard).

Did you see the CI error:

Error: The version '3.9' with architecture 'arm64' was not found for macOS 14.4.1.

Any idea what that is about?



if is_eetq_available():
    from eetq import EetqLinear
Member

Should we have a lazy import as for bnb, or is it not necessary for EETQ?

Contributor Author

Makes sense!

Member

What I mean is: should we indent all the code below to be inside of if is_eetq_available():? Or is it not necessary because, unlike bnb, EETQ does not initialize CUDA?

Contributor Author

I think it does, so let's indent it to be on the safe side.
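
To make the outcome of this thread concrete, here is a minimal sketch of the guarded-import pattern, assuming a helper along the lines of is_eetq_available (re-implemented below so the snippet is self-contained) and a hypothetical is_eetq_target function:

```python
# Sketch only: everything that touches eetq stays inside the availability check,
# so importing this module on a machine without eetq never initializes CUDA.
import importlib.util


def is_eetq_available() -> bool:
    # Stand-in for PEFT's import helper: True if the eetq package is installed.
    return importlib.util.find_spec("eetq") is not None


if is_eetq_available():
    from eetq import EetqLinear

    def is_eetq_target(target) -> bool:  # hypothetical helper name
        # Any code that references EetqLinear must live inside this guarded block.
        return isinstance(target, EetqLinear)
```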


self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)

Member

Would merging currently work with EETQ? I would assume not. Maybe we can raise an error when users try it?

return result

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
    raise ValueError("Merging LoRA layers is not supported for Eetq layers.")
Member

Let's also add unmerge for completeness. I also wonder if ValueError is best or if it should be TypeError (maybe AttributeError?).

Contributor Author

Makes sense!
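
A minimal sketch of how the resolved layer could look, assuming ValueError is kept and a matching unmerge is added; the class name and everything outside the two methods are illustrative only:

```python
# Illustrative stand-in for the PR's EETQ LoRA layer; only the error-raising
# merge/unmerge behaviour discussed above is sketched here.
from __future__ import annotations

from typing import Optional


class EetqLoraLinearSketch:
    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        # EETQ keeps base weights in a packed quantized format, so folding the
        # LoRA delta into them is not supported; fail loudly instead of silently.
        raise ValueError("Merging LoRA layers is not supported for Eetq layers.")

    def unmerge(self) -> None:
        # Added for completeness, mirroring the reviewer's suggestion.
        raise ValueError("Unmerging LoRA layers is not supported for Eetq layers.")
```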



@BenjaminBossan (Member) left a comment


Thanks so much for adding EETQ, LGTM. I have two nits regarding the documentation; it's up to you whether you want to fix them.

import torch
from transformers import EetqConfig

config = EetqConfig("int8")
Member

This probably requires the latest transformers, right? Maybe worth adding the min version?
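
As a sketch of the kind of guard this nit points at, the check below refuses to run the docs example on too-old transformers releases; the version threshold is a placeholder assumption, since pinning the exact minimum is precisely what the comment asks for:

```python
# Placeholder version guard; the real minimum is whichever transformers release
# first ships EetqConfig, which this sketch does not claim to know.
import importlib.metadata

from packaging import version

MIN_TRANSFORMERS_FOR_EETQ = version.parse("4.40.0")  # assumed placeholder threshold


def check_transformers_supports_eetq() -> None:
    installed = version.parse(importlib.metadata.version("transformers"))
    if installed < MIN_TRANSFORMERS_FOR_EETQ:
        raise ImportError(
            f"The EETQ example needs transformers >= {MIN_TRANSFORMERS_FOR_EETQ}, "
            f"but {installed} is installed."
        )
```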

@younesbelkada merged commit d0fa70a into main on Apr 26, 2024
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request on May 13, 2025:
* v1

* fix tests

* fix unneeded change

* fix unneeded change

* fix unneeded change

* fix

* fix CI

* fix docker image

* fix docker image

* add docs

* lazy import

* raise when merge

* raise when merge

* Update eetq.py

* merge

* style

* add unmerge

* indent

* Update docs/source/developer_guides/quantization.md

Co-authored-by: Benjamin Bossan <[email protected]>

* add details about transformers

---------

Co-authored-by: Benjamin Bossan <[email protected]>