@@ -175,8 +175,6 @@ def from_float(
         dtype = torch.int16
         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, dtype)
-        # int_data = (input_float / scale).to(torch.int8)
-        print("initial:", scale.shape, " int data:", int_data.shape)
         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
         layout_tensor = layout_tensor_ctr(int_data, scale, layout_type)
         return cls(layout_tensor, input_float.shape)
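
As a side note on what the from_float path above computes: below is a rough, self-contained sketch of symmetric per-tensor quantization in plain PyTorch. It is not the torchao choose_qparams_affine / quantize_affine primitives themselves (those additionally handle block_size and mapping_type), only an illustration of the scale / zero_point / int_data relationship; the helper names are made up for this sketch.

# Conceptual sketch only, not the torchao API: pick a scale from the data
# range, then round-and-clamp into the target integer domain (here int16).
import torch

def choose_qparams_symmetric(x: torch.Tensor, dtype=torch.int16):
    qmax = torch.iinfo(dtype).max
    scale = x.abs().amax() / qmax          # one scale for the whole tensor
    zero_point = torch.zeros_like(scale)   # symmetric mapping: zero_point is 0
    return scale, zero_point

def quantize_symmetric(x: torch.Tensor, scale, zero_point, dtype=torch.int16):
    qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    return torch.clamp(torch.round(x / scale + zero_point), qmin, qmax).to(dtype)

x = torch.randn(4, 8)
scale, zp = choose_qparams_symmetric(x)
int_data = quantize_symmetric(x, scale, zp)
dequant = (int_data.to(torch.float32) - zp) * scale   # round-trip check
print((x - dequant).abs().max())                      # small quantization error
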
@@ -309,6 +307,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )

+        # Tensor parallel support START
         elif func in [aten._to_copy.default, aten.clone.default]:
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
@@ -334,6 +333,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
        elif func is aten.t.default:
            return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type))

+        # Tensor parallel support END
+
        raise NotImplementedError(
            f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported"
        )
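
The START/END markers added above bracket the ops that tensor parallel sharding exercises on the layout tensor: detach, _to_copy/clone, and t. Below is a minimal, hypothetical wrapper subclass (WrapperTensor is made up for this sketch, not PlainMyDTypeLayout) showing the same dispatch plumbing, assuming torch.utils._python_dispatch.return_and_correct_aliasing behaves as it does in this file. Note how aten.t.default is handled by flipping a transposed flag rather than moving data, mirroring the hunk above.

# A minimal sketch of the detach / clone / t dispatch plumbing, under the
# assumptions stated above; not the torchao implementation.
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

aten = torch.ops.aten

class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data, transposed=False):
        # report the transposed shape without touching the underlying data
        shape = data.shape[::-1] if transposed else data.shape
        return torch.Tensor._make_wrapper_subclass(cls, shape, dtype=data.dtype)

    def __init__(self, data, transposed=False):
        self.data_ = data
        self.transposed = transposed

    def __repr__(self):
        return f"WrapperTensor(shape={tuple(self.shape)}, transposed={self.transposed})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func is aten.detach.default:
            return return_and_correct_aliasing(
                func, args, kwargs, WrapperTensor(args[0].data_, args[0].transposed)
            )
        if func in (aten._to_copy.default, aten.clone.default):
            return return_and_correct_aliasing(
                func, args, kwargs, WrapperTensor(args[0].data_.clone(), args[0].transposed)
            )
        if func is aten.t.default:
            # record the transpose as a flag flip instead of moving data
            return return_and_correct_aliasing(
                func, args, kwargs, WrapperTensor(args[0].data_, not args[0].transposed)
            )
        raise NotImplementedError(f"WrapperTensor dispatch: {func} not supported")

w = WrapperTensor(torch.randn(4, 8))
print(w.t())       # transposed flag flipped, shape reported as (8, 4)
print(w.clone())   # data copied, flag preserved
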