Skip to content

Commit a323bbe

Browse files
committed
Update Float8Tensor for GRPO training in unsloth
**Summary:** Support a few extra ops that are called during the GRPO training loop in unsloth/vLLM for Float8Tensor.

**Test Plan:**
```
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_matmul_lora
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_to_dtype_layout
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_has_compatible_shallow_copy_type
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_transpose
```
1 parent e418066 commit a323bbe

File tree

3 files changed

+210
-18
lines changed

3 files changed

+210
-18
lines changed

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 114 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import copy
88
import unittest
99
from contextlib import nullcontext
10-
from typing import Tuple
10+
from typing import Tuple, Type
1111

1212
import torch
1313
from torch._inductor.utils import run_and_get_code
@@ -18,6 +18,7 @@
1818
from torchao.quantization import (
1919
Float8DynamicActivationFloat8WeightConfig,
2020
Float8WeightOnlyConfig,
21+
Granularity,
2122
PerRow,
2223
PerTensor,
2324
quantize_,
@@ -72,6 +73,24 @@ def forward(self, x):
7273
return self.conv(x)
7374

7475

76+
class ToyLoRAModel(torch.nn.Module):
    """Toy bias-free linear layer with an additive low-rank (LoRA-style) branch.

    The base projection is spelled out as ``torch.matmul`` against the
    transposed linear weight rather than calling ``self.linear``, so the
    forward pass goes through ``matmul`` and ``t`` on the weight.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        lora_rank: inner dimension shared by ``lora_A`` and ``lora_B``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        lora_rank: int = 8,
    ):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=False)
        # Low-rank factors: (in_features, rank) followed by (rank, out_features).
        a_init = torch.randn(in_features, lora_rank)
        self.lora_A = torch.nn.Parameter(a_init)
        b_init = torch.randn(lora_rank, out_features)
        self.lora_B = torch.nn.Parameter(b_init)

    def forward(self, x):
        """Return ``x @ W.T + (x @ lora_A) @ lora_B``."""
        base = torch.matmul(x, self.linear.weight.t())
        low_rank = (x @ self.lora_A) @ self.lora_B
        return base + low_rank
92+
93+
7594
# TODO: move the tests in test_affine_quantized_float.py here once all implementations have been migrated
7695
@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
7796
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -105,16 +124,75 @@ def test_fp8_linear_variants(
105124
dtype: torch.dtype,
106125
mode: str,
107126
compile: bool,
108-
granularity,
127+
granularity: Granularity,
109128
kernel_preference: KernelPreference,
110129
sizes: Tuple,
130+
):
131+
self._test_fp8_matmul_variants(
132+
dtype,
133+
mode,
134+
compile,
135+
granularity,
136+
kernel_preference,
137+
sizes,
138+
ToyLinearModel,
139+
)
140+
141+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(
    not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only"])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@common_utils.parametrize(
    "kernel_preference",
    [KernelPreference.AUTO, KernelPreference.TORCH, KernelPreference.FBGEMM],
)
# Inputs are (M,..), K, N
@common_utils.parametrize(
    "sizes",
    [
        ((128,), 256, 128),
        ((32, 128), 64, 256),
    ],
)
def test_fp8_matmul_lora(
    self,
    dtype: torch.dtype,
    mode: str,
    compile: bool,
    granularity: Granularity,
    kernel_preference: KernelPreference,
    sizes: Tuple,
):
    """Run the shared fp8 matmul checks with ``ToyLoRAModel`` as the model under test."""
    # All parametrized values are forwarded unchanged; only the model class
    # is chosen by this test.
    self._test_fp8_matmul_variants(
        dtype=dtype,
        mode=mode,
        compile=compile,
        granularity=granularity,
        kernel_preference=kernel_preference,
        sizes=sizes,
        model_class=ToyLoRAModel,
    )
179+
180+
def _test_fp8_matmul_variants(
181+
self,
182+
dtype: torch.dtype,
183+
mode: str,
184+
compile: bool,
185+
granularity: Granularity,
186+
kernel_preference: KernelPreference,
187+
sizes: Tuple,
188+
model_class: Type[torch.nn.Module],
111189
):
112190
if (
113191
isinstance(granularity, PerTensor)
114192
and kernel_preference == KernelPreference.FBGEMM
115193
):
116194
return unittest.skip(
117-
"per tensor with fbgemm kernel preferece does not work yet"
195+
"per tensor with fbgemm kernel preference does not work yet"
118196
)
119197

120198
error_message = None
@@ -145,7 +223,7 @@ def test_fp8_linear_variants(
145223
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
146224

147225
# Create a linear layer with bfloat16 dtype
148-
model = ToyLinearModel(K, N).eval().to(dtype).to("cuda")
226+
model = model_class(K, N).eval().to(dtype).to("cuda")
149227

150228
quantized_model = copy.deepcopy(model)
151229

@@ -758,6 +836,38 @@ def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):
758836

759837
self.assertEqual(sliced_dequantized, sliced_original)
760838

839+
def test_to_dtype_layout(self):
    """``aten.to.dtype_layout`` with unchanged dtype/layout and ``device="cpu"`` yields a CPU tensor."""
    hp = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
    src = Float8Tensor.from_hp(hp)
    moved = torch.ops.aten.to.dtype_layout(
        src,
        dtype=src.dtype,
        layout=src.layout,
        device="cpu",
    )
    # dtype and layout are passed through unchanged; only the device differs.
    self.assertEqual(moved.dtype, src.dtype)
    self.assertEqual(moved.layout, src.layout)
    self.assertEqual(moved.device, torch.device("cpu"))
848+
849+
def test_has_compatible_shallow_copy_type(self):
    """Shallow-copy compatibility requires matching tensor types and matching shapes."""
    compatible = torch._has_compatible_shallow_copy_type

    hp_a = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
    hp_b = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
    hp_narrow = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    fp8_a = Float8Tensor.from_hp(hp_a)
    fp8_b = Float8Tensor.from_hp(hp_b)
    fp8_narrow = Float8Tensor.from_hp(hp_narrow)

    # Plain tensor vs Float8Tensor, in either order, is not compatible.
    self.assertFalse(compatible(hp_a, fp8_b))
    self.assertFalse(compatible(fp8_a, hp_b))
    # Same type and same shape.
    self.assertTrue(compatible(fp8_a, fp8_b))
    # Wrong shape
    self.assertFalse(compatible(fp8_a, fp8_narrow))
861+
862+
def test_transpose(self):
    """``t()`` transposes ``qdata`` and ``scale`` exactly and swaps the block size entries."""
    hp = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
    quantized = Float8Tensor.from_hp(hp)
    transposed = quantized.t()

    # Both data attributes must match a plain transpose bit-for-bit.
    for attr in ("qdata", "scale"):
        torch.testing.assert_close(
            getattr(transposed, attr),
            getattr(quantized, attr).t(),
            atol=0,
            rtol=0,
        )

    self.assertEqual(quantized.block_size, (1, 512), atol=0, rtol=0)
    self.assertEqual(transposed.block_size, (512, 1), atol=0, rtol=0)
870+
761871

762872
common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
763873

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -249,21 +249,59 @@ def from_hp(
249249
implements_torch_function = Float8Tensor.implements_torch_function
250250

251251

252-
@implements([aten.linear.default])
253-
@implements_torch_function([torch.nn.functional.linear])
252+
@implements(aten.linear.default)
@implements_torch_function(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    """Route ``linear(input, weight, bias)`` through the shared float8 matmul path.

    ``linear`` multiplies by the transposed weight, so the weight is
    transposed here before being handed to ``_float8_mm_impl``.

    Args:
        func: the overridden op or function (unused).
        types: tensor types taking part in the dispatch (unused).
        args: ``(input, weight)`` or ``(input, weight, bias)``.
        kwargs: keyword arguments of the call; ``bias`` is read from here when
            it was not passed positionally.

    Returns:
        The result of ``_float8_mm_impl(input, weight.t(), bias)``.
    """
    input_tensor, weight_tensor = args[0], args[1]
    if len(args) > 2:
        bias = args[2]
    else:
        # ``F.linear(x, w, bias=b)`` reaches the torch-function override with
        # the bias in kwargs; reading only ``args`` would silently drop it.
        bias = kwargs.get("bias") if kwargs else None
    return _float8_mm_impl(input_tensor, weight_tensor.t(), bias)
261+
262+
263+
@implements(aten.matmul.default)
@implements_torch_function(torch.matmul)
def _(func, types, args, kwargs):
    """Route ``matmul(input, weight)`` through the shared float8 matmul path.

    Only the first two positional arguments are used, and the weight is
    passed on as given, without a transpose.
    """
    lhs = args[0]
    rhs = args[1]
    return _float8_mm_impl(lhs, rhs)
268+
269+
270+
@implements(aten.mm.default)
@implements_torch_function(torch.mm)
def _(func, types, args, kwargs):
    """Route ``mm(input, weight)`` through the shared float8 matmul path.

    Only the first two positional arguments are used, and the weight is
    passed on as given, without a transpose.
    """
    lhs = args[0]
    rhs = args[1]
    return _float8_mm_impl(lhs, rhs)
275+
276+
277+
@implements(aten.addmm_.default)
def _(func, types, args, kwargs):
    """In-place ``addmm_``: accumulate the float8 matmul of ``args[1]`` and ``args[2]`` into ``args[0]``.

    Only the default scaling (``alpha == 1`` and ``beta == 1``) is supported;
    any other value trips an assertion.

    Returns:
        ``args[0]`` after ``add_`` has been applied with the matmul result.
    """
    accumulator = args[0]
    mat1 = args[1]
    mat2 = args[2]

    alpha = kwargs.get("alpha", 1)
    assert alpha == 1, "only alpha=1 is supported"
    beta = kwargs.get("beta", 1)
    assert beta == 1, "only beta=1 is supported"

    product = _float8_mm_impl(mat1, mat2)
    return accumulator.add_(product)
288+
289+
290+
def _float8_mm_impl(
291+
input_tensor: Float8Tensor,
292+
weight_tensor: Float8Tensor,
293+
bias: Optional[torch.Tensor] = None,
294+
) -> torch.Tensor:
260295
assert isinstance(weight_tensor, Float8Tensor), (
261296
f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
262297
)
263298

264299
act_quant_kwargs = weight_tensor.act_quant_kwargs
265300
# quantize activation, if `act_quant_kwargs` is specified
266301
if act_quant_kwargs is not None:
302+
assert not isinstance(input_tensor, TorchAOBaseTensor), (
303+
"input tensor was already quantized"
304+
)
267305
input_tensor = _choose_quant_func_and_quantize_tensor(
268306
input_tensor, act_quant_kwargs
269307
)
@@ -290,6 +328,7 @@ def _(func, types, args, kwargs):
290328
assert is_sm_at_least_90(), "Expected SM90+ for fbgemm_gpu_genai"
291329
mm_config = weight_tensor.mm_config
292330
assert mm_config is not None
331+
weight_tensor = weight_tensor.t()
293332

294333
out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
295334
xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
@@ -324,23 +363,16 @@ def _(func, types, args, kwargs):
324363
assert kernel_choice == "torch"
325364
scaled_mm_config = weight_tensor.mm_config
326365
assert scaled_mm_config is not None
327-
out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
366+
out_shape = (*input_tensor.shape[:-1], weight_tensor.shape[1])
328367

329368
# Extract tensor data and scales
330369
inpt_data = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
331370
w_data = weight_tensor.qdata
332371
input_scale = input_tensor.scale
333372
w_scale = weight_tensor.scale
334373

335-
# Handle rowwise scaling
336-
if _is_rowwise_scaled(weight_tensor):
337-
assert _is_rowwise_scaled(input_tensor), (
338-
"Input tensor must be rowwise block size"
339-
)
340-
w_scale = w_scale.transpose(-1, -2)
341-
342374
input_scale = preprocess_scale(input_scale, input_tensor.shape)
343-
inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
375+
inpt_data, w_data = preprocess_data(inpt_data, w_data, scaled_mm_config)
344376

345377
return addmm_float8_unwrapped_inference(
346378
inpt_data,
@@ -357,9 +389,11 @@ def _(func, types, args, kwargs):
357389
)
358390
# when input is not `Float8Tensor`, we expect that it is not quantized
359391
# so this is float8 weight only quantization
360-
return torch.nn.functional.linear(
361-
input_tensor, weight_tensor.dequantize(), bias
362-
)
392+
out = torch.matmul(input_tensor, weight_tensor.dequantize())
393+
if bias is not None:
394+
return out + bias
395+
else:
396+
return out
363397

364398

365399
@implements_torch_function(torch.bmm)
@@ -677,6 +711,7 @@ def _(func, types, args, kwargs):
677711
assert original_shape[-1] == size[-1], (
678712
f"Only support reshaping when last dimension matches, requested: reshaping from {original_shape} to {size}"
679713
)
714+
# TODO: this seems wrong, we should merge the first two dimensions instead
680715
qdata = self.qdata.reshape(*size)
681716
scale = self.scale.reshape(*size)
682717
block_size = self.block_size.copy()
@@ -785,6 +820,23 @@ def _(func, types, args, kwargs):
785820
return return_and_correct_aliasing(func, args, kwargs, new)
786821

787822

823+
@implements(aten.t.default)
def _(func, types, args, kwargs):
    """Transpose a tensor whose ``block_size`` has exactly two entries.

    ``qdata`` and ``scale`` are both transposed and the two ``block_size``
    entries are swapped; the remaining attributes are carried over unchanged.
    """
    assert len(args) == 1
    src = args[0]
    assert len(src.block_size) == 2
    rows_block, cols_block = src.block_size

    qdata_t = src.qdata.t()
    scale_t = src.scale.t()
    transposed = src.__class__(
        qdata_t,
        scale_t,
        (cols_block, rows_block),
        src.mm_config,
        src.act_quant_kwargs,
        src.kernel_preference,
        src.dtype,
    )
    return return_and_correct_aliasing(func, args, kwargs, transposed)
838+
839+
788840
Float8Tensor.__module__ = "torchao.quantization"
789841

790842
# Allow a model with Float8Tensor weights to be loaded with `weights_only=True`

torchao/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,36 @@ def _implements_common_tensor_ops(cls):
508508
implements_torch_function = cls.implements_torch_function
509509
aten = torch.ops.aten
510510

511+
@implements(torch.ops.aten.to.dtype_layout)
def _(func, types, args, kwargs):
    """Handle ``aten.to.dtype_layout`` for device moves only.

    Apart from the tensor itself, everything must be passed by keyword, and
    only ``dtype``, ``layout`` and ``device`` are accepted. ``dtype`` and
    ``layout`` must equal the tensor's current values; anything else trips an
    assertion. If the requested device equals the current one, the input
    tensor itself is returned.
    """
    # only support kwargs for now
    assert len(args) == 1
    tensor = args[0]

    # only support dtype, layout, and device for now
    assert all(name in ["dtype", "layout", "device"] for name in kwargs)

    # only support same dtype and layout
    # different dtype and layout has undefined behavior
    if "dtype" in kwargs:
        assert kwargs["dtype"] == tensor.dtype
    if "layout" in kwargs:
        assert kwargs["layout"] == tensor.layout

    # if device is the same, treat this like a no-op
    target_device = kwargs.get("device")
    if target_device == tensor.device:
        return tensor

    def _move(inner):
        # Re-issue the op on each piece of data with only the device set.
        return func(inner, device=target_device)

    moved = tensor._apply_fn_to_data(_move)
    return return_and_correct_aliasing(func, args, kwargs, moved)
531+
532+
# This is called during _apply() to see if we can shallow
# copy the content of one tensor into another. For now,
# we only allow shallow copy if both tensors are of the
# same type and have the same shape.
@implements_torch_function(torch._has_compatible_shallow_copy_type)
def _(func, types, args, kwargs):
    """Return whether the two tensors have exactly the same type and the same shape."""
    assert len(args) == 2
    first, second = args
    # Exact type match is required; a subclass relationship is not enough.
    if not (type(first) == type(second)):
        return False
    return first.shape == second.shape
540+
511541
@implements_torch_function(
512542
[
513543
torch.Tensor.contiguous,

0 commit comments

Comments
 (0)