import torch
-from torch.testing._internal.common_utils import TestCase, IS_FBCODE
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
from torch.testing._internal.optests import opcheck
-import torchao
-from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2
-import unittest
-from parameterized import parameterized
+from torchao.utils import is_fbcode
+from torchao.prototype.quant_llm import from_scaled_tc_fpx
import pytest

+if is_fbcode():
+    pytest.skip("Skipping the test in fbcode since we don't have TARGET file for kernels")
+
try:
    import torchao.ops
except RuntimeError:
    pytest.skip("torchao.ops not available")


-# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...):
-# test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace)
-@pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning")
-@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels")
class TestOps(TestCase):
-    def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device):
-        # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t.
-        fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int)
-        fp16_scale = torch.rand(OC).half() + 0.5
-        fp16_activation = torch.rand(BS, IC).half() + 0.5
-        return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device)
-
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_fp6_llm_linear(self):
+    def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
+        # Randomly initialize each byte
+        nbits = 1 + ebits + mbits
+        fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
+        scale = torch.rand(OC).half() + 0.5
+        fp16_act = torch.rand(BS, IC).half() + 0.5
+        return fpx_weight.to(device), scale.to(device), fp16_act.to(device)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @parametrize("ebits,mbits", [(3, 2), (2, 2)])
+    def test_quant_llm_linear(self, ebits, mbits):
        BS = 2
        OC = 256
        IC = 256
        splitK = 1
-        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
+        fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")

        # smoke test
-        torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
+        torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)

        # comprehensive testing
        test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
-        opcheck(torch.ops.torchao.fp6_llm_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils)
+        opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, fpx_weight, scale, splitK), test_utils=test_utils)
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
+    @parametrize("ebits,mbits", [(3, 2), (2, 2)])
+    def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
+        # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py
+        fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda")
+
+        results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK)

-    # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
-    @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK):
-        fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda")
+        fp16_weight = from_scaled_tc_fpx(fpx_weight, ebits, mbits, scale).half()
+        results_fp16 = fp16_act @ fp16_weight.T

-        results_fp6 = torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK)
+        error = (results_fpx - results_fp16).abs().mean()
+        gt = results_fp16.abs().mean()
+        relative_error = error / gt
+        assert relative_error < 1e-3

-        fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
-        results_fp16 = fp16_activation @ fp16_weight.T

-        error = (results_fp6 - results_fp16).abs()
-        relative_error = error / results_fp16.abs()
-        assert relative_error.mean() < 1e-2
+instantiate_parametrized_tests(TestOps)


if __name__ == "__main__":
-    unittest.main()
+    run_tests()
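
A note on the packed width used by _create_fpx_inputs: each FPx element carries 1 + ebits + mbits bits, so one row of IC elements occupies IC // 8 * nbits bytes. The sketch below only illustrates that arithmetic; packed_bytes_per_row is a hypothetical helper, not part of torchao.

def packed_bytes_per_row(ebits: int, mbits: int, ic: int) -> int:
    nbits = 1 + ebits + mbits   # sign + exponent + mantissa bits per element
    return ic // 8 * nbits      # ic elements * nbits bits, expressed in bytes

assert packed_bytes_per_row(3, 2, 256) == 192   # FP6 (E3M2): 256 * 6 / 8
assert packed_bytes_per_row(2, 2, 256) == 160   # FP5 (E2M2): 256 * 5 / 8

For the FP6 case this matches the storage the removed helper allocated as IC // 16 * 3 int32 values per row (256 // 16 * 3 * 4 = 192 bytes); the new helper views the same byte budget as uint8 and generalizes it to other bit widths.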
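
The correctness check is also reformulated: the removed test averaged per-element relative errors, mean(|results_fp6 - results_fp16| / |results_fp16|), against a 1e-2 threshold, while the new test compares aggregate magnitudes, relative_error = mean(|results_fpx - results_fp16|) / mean(|results_fp16|), against a tighter 1e-3 threshold. Dividing after averaging avoids the instability of per-element ratios when individual reference entries are close to zero.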
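
On the harness change: parameterized.expand and unittest give way to the parametrization utilities in torch.testing._internal.common_utils. A minimal sketch of that pattern follows, assuming only those helpers; the class and test names are made up for illustration.

from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

class ExampleTests(TestCase):
    @parametrize("x,y", [(1, 2), (3, 4)])
    def test_addition_commutes(self, x, y):
        self.assertEqual(x + y, y + x)

# Generate the concrete per-parameter test methods from the @parametrize decorators.
instantiate_parametrized_tests(ExampleTests)

if __name__ == "__main__":
    run_tests()

This is why the diff adds the module-level instantiate_parametrized_tests(TestOps) call and switches the __main__ entry point from unittest.main() to run_tests().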