@@ -1,6 +1,6 @@
 """The group quantization config"""
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
 
 import numpy as np
 from tvm import DataType, DataTypeCode
@@ -128,7 +128,7 @@ def _dequantize(
     ):
         tir_bin_mask = tir.const((1 << DataType(self.quantize_dtype).bits) - 1, self.storage_dtype)
         tir_max_int = tir.const(self.max_int_value, self.model_dtype)
-        dequantized_weight = te.compute(
+        return te.compute(
             shape=[weight.shape[0], weight.shape[1] * self.num_elem_per_storage]
             if out_shape is None
             else out_shape,
@@ -149,8 +149,8 @@ def _dequantize(
                 ),
                 scale[i, j // self.group_size],
             ),
+            name="decode",
         )
-        return dequantized_weight
 
     def quantize_weight(self, weight: NDArray) -> List[NDArray]:
         """
@@ -186,8 +186,10 @@ def quantize_weight(self, weight: NDArray) -> List[NDArray]:
             if target is None:
                 target = Target.from_device(dev)
             with target:
-                mod = dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
-                    dl.gpu.Reduction(), dl.gpu.GeneralReduction(), dl.gpu.Fallback()
+                mod = dl.ApplyDefaultSchedule(  # type: ignore # pylint: disable=not-callable
+                    dl.gpu.Reduction(),
+                    dl.gpu.GeneralReduction(),
+                    dl.gpu.Fallback(),
                 )(mod)
         elif device_type == "cpu":
             target = "llvm"
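
The `ApplyDefaultSchedule` call reflowed above is dlight's module pass for scheduling an IRModule: each listed rule is tried in order on every unscheduled PrimFunc, with `Fallback` as the catch-all. A rough sketch of how such a pass is applied, using a toy TVMScript module and an assumed CUDA target on a recent TVM version (none of this is code from the commit):

import tvm
from tvm import dlight as dl
from tvm.script import tir as T
from tvm.target import Target

@tvm.script.ir_module
class ToyModule:
    @T.prim_func
    def add_one(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
        # A trivial elementwise kernel; Fallback will pick it up below.
        for i in range(128):
            with T.block("add"):
                vi = T.axis.spatial(128, i)
                B[vi] = A[vi] + T.float32(1.0)

with Target("cuda"):
    scheduled = dl.ApplyDefaultSchedule(  # rules are tried in order per PrimFunc
        dl.gpu.Reduction(),
        dl.gpu.GeneralReduction(),
        dl.gpu.Fallback(),
    )(ToyModule)
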
@@ -400,7 +402,7 @@ def from_multilinear(
             out_dtype=multi_linear.out_dtype,
         )
 
-    def forward(self, x: nn.Tensor) -> nn.Tensor:  # pylint: disable=invalid-name
+    def forward(self, x: nn.Tensor) -> Sequence[nn.Tensor]:  # pylint: disable=invalid-name
         """
         Forward method for multi linear layer.