@@ -263,7 +263,7 @@ def _(func, types, args, kwargs):
 @implements_torch_function(torch.matmul)
 def _(func, types, args, kwargs):
     input_tensor, weight_tensor = args[0], args[1]
-    return _float8_linear_impl(input_tensor, weight_tensor.t())
+    return _float8_mm_impl(input_tensor, weight_tensor)
 
 
 @implements(aten.addmm_.default)
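
For context on how the override in the hunk above is reached, here is a hypothetical sketch of the dispatch path; it does not use torchao's actual implements_torch_function machinery, and the class name and print are illustrative only. A __torch_function__ tensor subclass intercepts torch.matmul, which is where the commit now forwards (input, weight) straight to _float8_mm_impl instead of pre-transposing into _float8_linear_impl.

import torch

# Hypothetical stand-in for the real Float8Tensor override registration.
class FakeFloat8Tensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func is torch.matmul:
            # the commit forwards (input, weight) to _float8_mm_impl here,
            # instead of calling _float8_linear_impl(input, weight.t())
            print("torch.matmul override hit")
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

x = torch.randn(2, 4)
w = torch.randn(4, 3).as_subclass(FakeFloat8Tensor)
torch.matmul(x, w)  # prints: torch.matmul override hit
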
@@ -273,10 +273,24 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    out = _float8_linear_impl(input_tensor, weight_tensor.t())
+    out = _float8_mm_impl(input_tensor, weight_tensor)
     return output_tensor.copy_(out)
 
 
+def _float8_mm_impl(
+    input_tensor: torch.Tensor,
+    weight_tensor: torch.Tensor,
+) -> torch.Tensor:
+    assert isinstance(weight_tensor, Float8Tensor), (
+        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
+    )
+    # Only support matmul(x, w.t()) for now
+    is_transposed = weight_tensor.qdata.stride(-2) < weight_tensor.qdata.stride(-1)
+    if not is_transposed:
+        raise ValueError("matmul with non-transposed Float8Tensor not supported yet")
+    return _float8_linear_impl(input_tensor, weight_tensor.t())
+
+
 def _float8_linear_impl(
     input_tensor: torch.Tensor,
     weight_tensor: torch.Tensor,
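
A minimal sketch of why the stride comparison in _float8_mm_impl above detects the matmul(x, w.t()) pattern; it uses a plain torch.Tensor in place of Float8Tensor.qdata, but the stride relationship is the same for any 2D strided tensor: a contiguous tensor has a larger row stride than column stride, and .t() swaps the two.

import torch

def is_transposed(t: torch.Tensor) -> bool:
    # same stride test as _float8_mm_impl, applied to a plain tensor here
    return t.stride(-2) < t.stride(-1)

w = torch.randn(4, 8)        # contiguous: strides (8, 1)
wt = w.t()                   # transposed view: strides (1, 8)

assert not is_transposed(w)  # 8 < 1 is False: regular layout
assert is_transposed(wt)     # 1 < 8 is True: came from .t()
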
@@ -286,10 +300,10 @@ def _float8_linear_impl(
         f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
     )
 
-    # During the backward pass, we transpose the weight tensor,
-    # so if the weight tensor was originally rowwise quantized,
-    # now it becomes colwise. In this case, simply dequantize
-    # the tensor and do a bf16 matmul
+    # If we perform a matmul during the backward pass (e.g. in a LoRA matmul
+    # autograd.Function), the weight tensor will be transposed. If the weight
+    # tensor was originally rowwise quantized, now it becomes colwise.
+    # In this case, simply dequantize the tensor and do a bf16 matmul
     is_colwise = (
         weight_tensor.block_size[0] == weight_tensor.shape[0]
         and weight_tensor.block_size[1] == 1
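
To illustrate the is_colwise condition above (the shapes and block sizes below are illustrative assumptions, not values from this commit): a rowwise-quantized weight of shape (N, K) carries block_size (1, K), one scale per row; after .t() the shape becomes (K, N) and the block_size (K, 1), which is exactly the layout the condition matches.

def is_colwise(shape, block_size):
    # one scale per column: each block spans the whole first dim
    # and exactly one element of the last dim
    return block_size[0] == shape[0] and block_size[1] == 1

N, K = 256, 512
assert not is_colwise((N, K), (1, K))  # rowwise weight: one scale per row
assert is_colwise((K, N), (K, 1))      # same weight after .t(): colwise
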