@@ -58,10 +58,13 @@ def from_plain(
     ):
         pass
 
+    @torch._dynamo.disable
     def __repr__(self):
-        int_data, scale, zero_point = self.get_plain()
-        layout_type = self.get_layout_type()
-        return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"
+        # This is a hack: torch.compile tries to trace __repr__, which then calls `dequantize` and raises an error.
+        # Removing the dequantize call makes the error go away.
+        # int_data, scale, zero_point = self.get_plain()
+        # layout_type = self.get_layout_type()
+        return f"{self.__class__.__name__}"  # (int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})
 
     def _get_to_kwargs(self, *args, **kwargs):
         device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
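Note on the `@torch._dynamo.disable` pattern used throughout this diff: the decorator tells Dynamo to run the function eagerly instead of tracing it, so Python-heavy paths like `__repr__` stay out of the compiled graph. A minimal self-contained sketch (the function names here are illustrative, not from torchao):

import torch

@torch._dynamo.disable
def debug_repr(t: torch.Tensor) -> str:
    # Dynamo never traces into a disabled function, so arbitrary Python
    # (string formatting, data-dependent logic) is safe here.
    return f"Tensor(shape={tuple(t.shape)}, dtype={t.dtype})"

@torch.compile
def f(x):
    print(debug_repr(x))  # runs eagerly; the add below is still compiled
    return x + 1

f(torch.ones(4))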
@@ -152,10 +155,13 @@ def __init__(
         self.quant_max = quant_max
         self.zero_point_domain = zero_point_domain
 
+    @torch._dynamo.disable
     def __repr__(self):
         return (
-            f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
-            f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
+            f"{self.__class__.__name__}"
+            # Same hack here
+            # f"(data={self.dequantize()}, shape={self.shape}, "
+            # f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
         )
 
     def dequantize(self, output_dtype=None):
@@ -552,6 +558,8 @@ class MarlinSparseAQTLayout(AQTLayout):
     __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
     __torch_function__ = classmethod(_dispatch__torch_function__)
 
+    @staticmethod
+    @torch._dynamo.disable
     def __new__(
         cls,
         int_data: torch.Tensor,
@@ -573,6 +581,7 @@ def __new__(
         shape = int_data.shape
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
+    @torch._dynamo.disable
     def __init__(
         self,
         int_data: torch.Tensor,
@@ -593,8 +602,24 @@ def __init__(
         self.group_size = group_size
         self.num_bits = num_bits
 
+    def __tensor_flatten__(self):
+        return ["int_data", "scale", "zero_point", "meta"], [self.layout_type, self.original_shape, self.group_size, self.num_bits]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data = tensor_data_dict["int_data"]
+        scale = tensor_data_dict["scale"]
+        zero_point = tensor_data_dict["zero_point"]
+        meta = tensor_data_dict["meta"]
+        layout_type, original_shape, group_size, num_bits = tensor_attributes
+        return cls(int_data, scale, zero_point, meta, layout_type, original_shape, group_size, num_bits)
+
+    @torch._dynamo.disable
     def get_plain(self):
         from torchao.sparsity.marlin import unpack_from_marlin_24  # avoid circular import
+        unpack_from_marlin_24 = torch._dynamo.disable(unpack_from_marlin_24)
         int_data_expanded, scales_expanded = unpack_from_marlin_24(
             self.int_data,
             self.scale,
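The `__tensor_flatten__`/`__tensor_unflatten__` pair added above implements the contract torch.compile uses to decompose a wrapper tensor subclass into its inner plain tensors plus static metadata, and to rebuild it afterwards. A minimal sketch of the protocol on a toy subclass (the class and attribute names are illustrative, not from torchao):

import torch

class ScaledTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, inner, scale):
        return torch.Tensor._make_wrapper_subclass(cls, inner.shape, dtype=scale.dtype)

    def __init__(self, inner, scale):
        self.inner, self.scale = inner, scale

    def __tensor_flatten__(self):
        # inner tensor attribute names first, then a list of non-tensor metadata
        return ["inner", "scale"], []

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, metadata, outer_size, outer_stride):
        # rebuild the subclass from the flattened pieces
        return cls(inner_tensors["inner"], inner_tensors["scale"])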
@@ -606,6 +631,7 @@ def get_plain(self):
         return int_data_expanded, scales_expanded, self.zero_point
 
     @classmethod
+    @torch._dynamo.disable
     def from_plain(
         cls,
         int_data: torch.Tensor,
@@ -674,7 +700,7 @@ def _apply_fn_to_data(self, fn):
 @MarlinSparseAQTLayout.implements(aten.detach.default)
 def block_sparse_detach(func, types, args, kwargs):
     return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))
-
+ 
 
 @register_layout_cls(TensorCoreTiledLayoutType)
 class TensorCoreTiledAQTLayout(AQTLayout):
@@ -920,7 +946,7 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weight_tensor, bias):
     tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
     # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
     y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
-        w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16
+        w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16,
     ).t()
     y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
         *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
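The `w_scales` fusion in the hunk above relies on a per-row scale commuting with the matmul: scaling the rows of the weight before the product equals scaling the rows of the product, which is why the scales can be passed as `alpha` to the sparse mm instead of a separate elementwise multiply. A plain-torch sketch of the identity:

import torch

W = torch.randint(-8, 8, (4, 6)).float()   # stands in for the int8 weight
X = torch.randn(6, 3)
w_scales = torch.rand(4, 1)                # one scale per output row

unfused = (W * w_scales) @ X               # scale weights, then matmul
fused = (W @ X) * w_scales                 # matmul, then scale rows
torch.testing.assert_close(unfused, fused)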
@@ -1037,6 +1063,7 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
 
 def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
     return (
+        isinstance(weight_tensor, AffineQuantizedTensor) and
         _aqt_is_uint4(weight_tensor) and
         input_tensor.dtype == torch.float16 and
         len(weight_tensor.shape) == 2 and
@@ -1046,11 +1073,13 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias):
 
 def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
     from torchao.sparsity.marlin import marlin_24_workspace, const
+    assert isinstance(weight_tensor, AffineQuantizedTensor)
 
     sparse_w_int4 = weight_tensor.layout_tensor.int_data
     scale = weight_tensor.layout_tensor.scale
     meta = weight_tensor.layout_tensor.meta
     original_shape = weight_tensor.layout_tensor.original_shape
+    print("original_shape", original_shape)
     num_bits = weight_tensor.layout_tensor.num_bits
 
     # Saves batch size for reshaping back to original shape after the matmul
@@ -1059,13 +1088,15 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
     batch_size = -1
     if input_tensor.dim() == 3:
         batch_size = input_tensor.size(0)
-        input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1]).contiguous()
+        input_tensor = input_tensor.reshape(-1, input_tensor.shape[-1])
 
     size_m = input_tensor.shape[0]
     size_n = original_shape[1]
     size_k = input_tensor.shape[1]
     workspace_24 = marlin_24_workspace(original_shape[1])
 
+    print(size_m, size_n, size_k)
+
     # Pad input_tensor dim 1 to a multiple of the marlin tile size (16)
     if size_k % const.TILE != 0:
         pad_size = find_multiple(size_k, const.TILE)
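For reference, the pad-to-tile step above rounds `size_k` up to the next multiple of `const.TILE` (16) and zero-pads the activation's last dim. A standalone sketch, assuming `find_multiple(n, k)` returns the smallest multiple of k that is >= n (as in torchao.utils; redefined locally here):

import torch
import torch.nn.functional as F

def find_multiple(n: int, k: int) -> int:
    return n if n % k == 0 else n + k - (n % k)

TILE = 16
x = torch.randn(8, 30, dtype=torch.float16)
pad_size = find_multiple(x.shape[-1], TILE)   # 32
x = F.pad(x, (0, pad_size - x.shape[-1]))     # zero-pad the last dim
assert x.shape[-1] % TILE == 0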
@@ -1076,11 +1107,9 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
         input_tensor, sparse_w_int4, meta, scale,
         workspace_24, num_bits, size_m, size_n, size_k
     )
-    torch.cuda.synchronize()
 
-    # Reshape back to original shape
     if batch_size != -1:
-        out = out.reshape(batch_size, -1, out.shape[-1])
+        out = out.view(batch_size, -1, out.shape[-1])
 
     if bias is not None:
         out += bias.to(out.dtype)
@@ -1113,14 +1142,14 @@ def _(func, types, args, kwargs):
     # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
     # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
     # make the branches easier to understand in `_quantized_linear_op`
-    try:
-        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
-    except:
-        if isinstance(input_tensor, AffineQuantizedTensor):
-            input_tensor = input_tensor.dequantize()
-        if isinstance(weight_tensor, AffineQuantizedTensor):
-            weight_tensor = weight_tensor.dequantize()
-        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+    # try:
+    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
+    # except:
+    #     if isinstance(input_tensor, AffineQuantizedTensor):
+    #         input_tensor = input_tensor.dequantize()
+    #     if isinstance(weight_tensor, AffineQuantizedTensor):
+    #         weight_tensor = weight_tensor.dequantize()
+    #     return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
 
 @implements(aten.addmm.default)
 def _(func, types, args, kwargs):
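The hunk above disables the general dequantize-and-fall-back path, so a dispatch miss in `_quantized_linear_op` now surfaces as an exception instead of being silently recovered. For reference, the pattern that was commented out looks like this (a sketch assuming the surrounding module's `torch` and `AffineQuantizedTensor` imports; `except Exception:` is used here instead of the original bare `except:`, which hides every failure, real bugs included):

def linear_with_fallback(input_tensor, weight_tensor, bias):
    try:
        # fast path: a specialized quantized kernel picked by dispatch
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
    except Exception:
        # slow path: dequantize back to floating point and use plain linear
        if isinstance(input_tensor, AffineQuantizedTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, AffineQuantizedTensor):
            weight_tensor = weight_tensor.dequantize()
        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)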