[TPU][Quantization] TPU W8A8
#11785
Merged
Changes from all commits (73 commits)
3b0c8a6  w8a8 working
36fc1db  format
d83c04c  added all kernels
af9d0f4  format
0f9fd21  working on cuda
7b3203f  added mixed precision directory
bf50fa4  formatting
226ef52  cache current state - w8a16 running oom
bb7c741  [TPU] Ensure torch._sync(param) is called after param.data.copy_() (WoosukKwon)
cf842bd  yapf (WoosukKwon)
67039bc  [TPU] Correctly profile peak memory usage (WoosukKwon)
0695f77  Upgrade PyTorch XLA (WoosukKwon)
11cf82f  Merge branch 'main' into tpu-peak-mem (WoosukKwon)
e016e38  stash
717b859  Merge branch 'main' into compressed-tensors-tpu
c848735  proper merge
1539915  add mixed precision
f00412a  format
b0a6b70  stash
e812d7e  Merge branch 'tpu-peak-mem' into compressed-tensors-tpu
764dda1  stash
87b2ae6  remove name
e813ff8  revert woosuk change
8cfaa1b  format
bbc9741  update
eb3f39e  fix nit
bb2fbe1  update
14ccb90  fix spurious
4092be2  stash branch for brittany
1aaa628  Merge branch 'main' into tpu-w8a8
48aa54b  revert
4efe915  fix
e98b79c  updated
5a89668  reduce cruft
57cbf5c  reduce cruft
3451c4d  updated
0c2e62a  update comment
172c9ca  revert spurious change
938ca81  remove cruft
9e18911  cruft reduction
5f58ec7  update docs
af9f298  added integration test
6fe2f62  updated
f2c0beb  Add bias back
8b29718  add bias support
1e2a373  updated
2a359ef  stash
f7e8975  Merge branch 'main' into remove-async-stream
0d4c3fd  fix
57340d2  update
38291d5  trigger test in CI
ead1e94  fix AZP
cea5e54  fixed!
940ddde  Merge branch 'tpu-w8a8' of https://github.com/neuralmagic/vllm into t…
84a5b29  fix azp adju
a1d7b4a  make docker command look better on gh
2b4ecfd  remove torch warnings
186c108  stash
7e8598a  Merge branch 'tpu-w8a8' of https://github.com/neuralmagic/vllm into t…
de773cd  fix AZP
3a53d7d  merged
0be5f69  added
cb69ba7  fix formatting
3896f6c  remove comment
33e1e13  formatted
dde72d6  add llama to ci
d7a9c93  Merge branch 'main' into tpu-w8a8
db9f795  Update supported_hardware.md (robertgshaw2-redhat)
09ad869  Update supported_hardware.md (robertgshaw2-redhat)
b74c88a  ixed docs build
da4369e  Merge branch 'tpu-w8a8' of https://github.com/neuralmagic/vllm into t…
5ddcac2  Merge branch 'main' into tpu-w8a8
f353c43  fix CI
New file (49 additions):

```python
from dataclasses import dataclass

import lm_eval
import pytest

TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03


@dataclass
class GSM8KAccuracyTestConfig:
    model_name: str
    excepted_value: float

    def get_model_args(self) -> str:
        return (f"pretrained={self.model_name},"
                "max_model_len=4096,max_num_seqs=32")


# NOTE: Accuracy scores measured on GPUs.
ACCURACY_CONFIGS = [
    GSM8KAccuracyTestConfig(
        model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8",
        excepted_value=0.76),  # no bias
    # NOTE(rob): We cannot re-initialize VLLM in the same process for TPU,
    # so only one of these tests can run in a single call to pytest. As
    # a follow up, move this into the LM-EVAL section of the CI.
    # GSM8KAccuracyTestConfig(
    #     model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8",
    #     excepted_value=0.66),  # bias in QKV layers
]


@pytest.mark.parametrize("config", ACCURACY_CONFIGS)
def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig):

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=config.get_model_args(),
        tasks="gsm8k",
        batch_size="auto",
    )

    EXPECTED_VALUE = config.excepted_value
    measured_value = results["results"][TASK][FILTER]
    assert (measured_value - RTOL < EXPECTED_VALUE
            and measured_value + RTOL > EXPECTED_VALUE
            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
```
vllm/model_executor/layers/quantization/kernels/__init__.py
74 changes: 0 additions & 74 deletions
The entire file was deleted. Its 74 lines reappear unchanged, except for the updated import paths, as vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py below.
File renamed without changes.
vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
74 changes: 74 additions & 0 deletions
```python
from typing import List, Optional, Type

import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import (  # noqa: E501
    ExllamaLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import (  # noqa: E501
    MacheteLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import (  # noqa: E501
    MarlinLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import (  # noqa: E501
    MPLinearKernel, MPLinearLayerConfig)
from vllm.platforms import current_platform

# in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
    MacheteLinearKernel,
    MarlinLinearKernel,
    ExllamaLinearKernel,
]


def choose_mp_linear_kernel(
        config: MPLinearLayerConfig,
        compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
    """
    Choose an MPLinearKernel that can implement the given config for the
    given compute capability. Attempts to choose the best kernel in terms
    of performance.

    Args:
        config (MPLinearLayerConfig): Description of the linear layer to be
          implemented.
        compute_capability (Optional[int], optional): The compute capability
          of the target device, if None uses `current_platform` to get the
          compute capability. Defaults to None.

    Raises:
        ValueError: If no kernel can implement the given config.

    Returns:
        Type[MPLinearKernel]: Chosen kernel.
    """
    if compute_capability is None:
        if current_platform is None:
            raise ValueError("Cannot determine compute capability")
        _cc = current_platform.get_device_capability()
        compute_capability = _cc[0] * 10 + _cc[1]

    failure_reasons = []
    for kernel in _POSSIBLE_KERNELS:
        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
            failure_reasons.append(
                f' {kernel.__name__} disabled by environment variable')
            continue

        if kernel.get_min_capability() > compute_capability:
            failure_reasons.append(
                f"{kernel.__name__} requires capability "
                f"{kernel.get_min_capability()}, current compute capability "
                f"is {compute_capability}")
            continue

        can_implement, failure_reason = kernel.can_implement(config)
        if can_implement:
            return kernel
        else:
            failure_reasons.append(
                f' {kernel.__name__} cannot implement due to: {failure_reason}'
            )

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "WNA16 linear layer. Reasons: \n" + '\n'.join(failure_reasons))
```
File renamed without changes.
File renamed without changes.
File renamed without changes.
Review comment: NOTE for reviewer - this file is not changed, it is just moved.