@@ -1082,41 +1082,31 @@ def forward_qdq(self, input, *args, **kwargs):
1082
1082
output_cache = self .orig_mod (qinput , * args , ** kwargs )
1083
1083
return output_cache
1084
1084
1085
- # def forward_quant(self, input, *args, **kwargs):
1086
- # qinput = self.quant_input(input)
1087
- # output_cache = self.orig_mod(qinput, *args, **kwargs)
1088
- # return self.dequant_output(output_cache)
1085
def forward_quant(self, input, *args, **kwargs):
    """Run the wrapped module on a quantized input and dequantize the result.

    Args:
        input: Value handed to ``self.quant_input`` before the wrapped
            module is called. (The name shadows the builtin, but it is part
            of the public signature and is kept for caller compatibility.)
        *args: Extra positional arguments forwarded unchanged to
            ``self.orig_mod``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``self.orig_mod``.

    Returns:
        Whatever ``self.dequant_output`` returns for the wrapped module's
        output.
    """
    qinput = self.quant_input(input)
    output_cache = self.orig_mod(qinput, *args, **kwargs)
    # The wrapped module's output is passed through dequant_output before
    # being handed back to the caller.
    return self.dequant_output(output_cache)
1089
1089
1090
1090
def forward_measure(self, input, *args, **kwargs):
    """Run the wrapped module unquantized while recording its input and output.

    Args:
        input: Value passed both to ``measure_input`` (wrapped in a 1-tuple)
            and, unmodified, to ``self.orig_mod``.
        *args: Extra positional arguments forwarded unchanged to
            ``self.orig_mod``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``self.orig_mod``.

    Returns:
        The wrapped module's output, exactly as ``self.orig_mod`` returned it.
    """
    # Record the input against the configured input observers.
    measure_input((input,), self._mod_extra_config.inputs)
    output_cache = self.orig_mod(input, *args, **kwargs)
    # Record the output against the configured output observers.
    measure_output((output_cache,), self._mod_extra_config.outputs)
    return output_cache
1095
1095
1096
- # def fetch_from_cache(self, cache, blocks, permutations=None):
1097
- # # quant_cache = self.quant_input(cache)
1098
- # quant_cache = cache
1099
- # if permutations:
1100
- # output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
1101
- # for i in range(len(output_cache)):
1102
- # output_cache[i] = self.dequant_output(output_cache[i])
1103
- # return output_cache
1104
- # output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks)
1105
- # return self.dequant_output(output_cache)
1106
-
1107
- def forward_quant (self , input , * args , ** kwargs ):
1108
- qinput = self .quant_input (input )
1109
- return self .orig_mod (qinput , * args , ** kwargs )
1110
-
1111
def fetch_from_cache(self, cache, blocks, permutations=None):
    """Fetch blocks from the cache via the wrapped module and dequantize them.

    Args:
        cache: Cache object exposing a ``dtype`` attribute. It is passed
            through ``self.quant_input`` first when its dtype differs from
            ``self.lp_dtype``; otherwise it is used as-is.
        blocks: Block selector forwarded unchanged to
            ``self.orig_mod.fetch_from_cache``.
        permutations: Optional value forwarded to the wrapped module. When
            it is falsy (``None`` or empty) the two-argument form of the
            wrapped call is used.

    Returns:
        With truthy ``permutations``: the sequence returned by the wrapped
        module, with every element replaced in place by its dequantized
        counterpart. Otherwise: the dequantized single result.
    """
    # TODO: Remove this workaround in next release [SW-221595]
    if cache.dtype != self.lp_dtype:
        quant_cache = self.quant_input(cache)
    else:
        quant_cache = cache

    if permutations:
        output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
        # Overwrite each entry in place so the returned object is the same
        # container the wrapped module produced.
        for index, fetched in enumerate(output_cache):
            output_cache[index] = self.dequant_output(fetched)
        return output_cache

    output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks)
    return self.dequant_output(output_cache)
1119
-
1109
+
1120
1110
def extra_repr(self) -> str:
    """Return the fixed label used when this module is printed."""
    # Plain literal: the original f-string had no placeholders.
    return "PatchedVLLMKVCache"
1122
1112
0 commit comments