@@ -1094,20 +1094,31 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
1094
1094
return return_and_correct_aliasing (
1095
1095
func , args , kwargs , args [0 ]._apply_fn_to_data (torch .detach )
1096
1096
)
1097
- if func is aten .clone .default :
1097
+ elif func is aten .clone .default :
1098
1098
return return_and_correct_aliasing (
1099
1099
func , args , kwargs , args [0 ]._apply_fn_to_data (torch .clone )
1100
1100
)
1101
- if func is aten .t .default :
1101
+ elif func is aten .t .default :
1102
1102
"""we don't need to repack the weight and just rely on external
1103
1103
shape being changed and record the status of transpose/no-transpose
1104
1104
"""
1105
1105
args [0 ].transposed = not args [0 ].transposed
1106
1106
return return_and_correct_aliasing (func , args , kwargs , args [0 ])
1107
-
1108
- raise NotImplementedError (
1109
- f"Float8AQTLayout dispatch: attempting to run { func } , this is not supported"
1110
- )
1107
+ elif func is aten .slice .Tensor :
1108
+ self , dim , start , end , step = fill_defaults (args , 5 , [0 , None , None , 1 ])
1109
+ if dim == 0 :
1110
+ return return_and_correct_aliasing (
1111
+ func , args , kwargs , args [0 ]._apply_fn_to_data (lambda x : aten .slice .Tensor (x , dim , start , end , step ))
1112
+ )
1113
+ elif dim == 1 :
1114
+ assert len (self .scale .shape ) == 1 , f"slice dim==1 only works when len(scale.shape) == 1 currently, got: { self .scale .shape } "
1115
+ return Float8AQTLayout (aten .slice .Tensor (self .float8_data , dim , start , end , step ), self .scale , None , self .layout_type )
1116
+ else :
1117
+ raise NotImplementedError (f"Float8AQTLayout dispatch: attempting to run { func } , with dim={ dim } , that is not supported" )
1118
+ else :
1119
+ raise NotImplementedError (
1120
+ f"Float8AQTLayout dispatch: attempting to run { func } , this is not supported"
1121
+ )
1111
1122
1112
1123
__torch_function__ = torch ._C ._disabled_torch_function_impl
1113
1124
@@ -1644,6 +1655,28 @@ def _linear_fp8_act_fp8_weight_impl(
1644
1655
use_fast_accum = scaled_mm_config .use_fast_accum ,
1645
1656
).reshape (out_shape )
1646
1657
1658
def _linear_fp_act_fp8_weight_check(
    input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
    weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
    bias: Optional[torch.Tensor],
) -> bool:
    """Return True when this dispatch applies: a plain (non-subclass)
    floating-point activation paired with a float8 affine-quantized weight.

    ``bias`` is accepted for dispatch-signature uniformity and does not
    influence the decision.
    """
    # Activation must be a native float tensor, not a traceable wrapper subclass.
    if is_traceable_wrapper_subclass(input_tensor):
        return False
    if not input_tensor.is_floating_point():
        return False
    # Weight must be an affine-quantized tensor carrying a float8 layout.
    if not isinstance(weight_tensor, AffineQuantizedTensor):
        return False
    if not isinstance(weight_tensor.layout_type, Float8LayoutType):
        return False
    if weight_tensor.layout_tensor.dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
        return False
    # Accept a block size spanning the full weight shape (presumably per-tensor
    # scaling — confirm against the quantizer) or row-wise scaled weights.
    return weight_tensor.shape == weight_tensor.block_size or _is_rowwise_scaled(weight_tensor)
1673
+
1674
+ def _linear_fp_act_fp8_weight_impl (
1675
+ input_tensor : torch .Tensor ,
1676
+ weight_tensor : AffineQuantizedTensor ,
1677
+ bias : Optional [torch .Tensor ],
1678
+ ):
1679
+ return torch .nn .functional .linear (input_tensor , weight_tensor .dequantize (), bias )
1647
1680
1648
1681
def _linear_fp_act_int4_weight_sparse_marlin_check (input_tensor , weight_tensor , bias ):
1649
1682
return (
@@ -1694,6 +1727,7 @@ def _register_aqt_quantized_linear_dispatches():
1694
1727
(_linear_int8_act_int8_weight_semi_structured_sparse_check , _linear_int8_act_int8_weight_semi_structured_sparse_impl ),
1695
1728
(_linear_int8_act_int8_weight_block_sparse_check , _linear_int8_act_int8_weight_block_sparse_impl ),
1696
1729
(_linear_fp8_act_fp8_weight_check , _linear_fp8_act_fp8_weight_impl ),
1730
+ (_linear_fp_act_fp8_weight_check , _linear_fp_act_fp8_weight_impl ),
1697
1731
(_linear_bf16_act_uint4_weight_check , _linear_bf16_act_uint4_weight_impl ),
1698
1732
(_linear_fp_act_int8_weight_check , _linear_fp_act_int8_weight_impl ),
1699
1733
(_linear_f16_act_floatx_weight_check , _linear_f16_act_floatx_weight_impl ),
0 commit comments