
Commit fcf3031

Update PatchedVLLMKVCache for deepseek performance (#2165)
1 parent eb5f04d commit fcf3031

File tree: 1 file changed (+11, -21 lines)


neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py

Lines changed: 11 additions & 21 deletions
@@ -1082,41 +1082,31 @@ def forward_qdq(self, input, *args, **kwargs):
         output_cache = self.orig_mod(qinput, *args, **kwargs)
         return output_cache
 
-    # def forward_quant(self, input, *args, **kwargs):
-    #     qinput = self.quant_input(input)
-    #     output_cache = self.orig_mod(qinput, *args, **kwargs)
-    #     return self.dequant_output(output_cache)
+    def forward_quant(self, input, *args, **kwargs):
+        qinput = self.quant_input(input)
+        output_cache = self.orig_mod(qinput, *args, **kwargs)
+        return self.dequant_output(output_cache)
 
     def forward_measure(self, input, *args, **kwargs):
         measure_input((input, ), self._mod_extra_config.inputs)
         output_cache = self.orig_mod(input, *args, **kwargs)
         measure_output((output_cache, ), self._mod_extra_config.outputs)
         return output_cache
 
-    # def fetch_from_cache(self, cache, blocks, permutations=None):
-    #     # quant_cache = self.quant_input(cache)
-    #     quant_cache = cache
-    #     if permutations:
-    #         output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
-    #         for i in range(len(output_cache)):
-    #             output_cache[i] = self.dequant_output(output_cache[i])
-    #         return output_cache
-    #     output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks)
-    #     return self.dequant_output(output_cache)
-
-    def forward_quant(self, input, *args, **kwargs):
-        qinput = self.quant_input(input)
-        return self.orig_mod(qinput, *args, **kwargs)
-
-    def fetch_from_cache(self, quant_cache, blocks, permutations=None):
+    def fetch_from_cache(self, cache, blocks, permutations=None):
+        # TODO: Remove this workaround in next release [SW-221595]
+        if cache.dtype != self.lp_dtype:
+            quant_cache = self.quant_input(cache)
+        else:
+            quant_cache = cache
         if permutations:
             output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks, permutations)
             for i in range(len(output_cache)):
                 output_cache[i] = self.dequant_output(output_cache[i])
             return output_cache
         output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks)
         return self.dequant_output(output_cache)
-
+
     def extra_repr(self) -> str:
         return f"PatchedVLLMKVCache"

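For readers skimming the diff, the sketch below is a minimal, self-contained illustration of the behaviour this commit restores: forward_quant quantizes the incoming key/value tensor, runs the wrapped cache op, and dequantizes the result, while fetch_from_cache only quantizes the cache when its dtype differs from the configured low-precision dtype (the [SW-221595] workaround). Everything here is a mock, not the real neural_compressor code: _ToyKVCache, the scale-based quant_input/dequant_output, and bfloat16 standing in for lp_dtype are illustrative assumptions, and the permutations branch is omitted for brevity.

# Minimal sketch (assumptions: _ToyKVCache, the scale-based quantizer and
# bfloat16-as-lp_dtype are stand-ins, not the real fp8_quant helpers).
import torch


class _ToyKVCache(torch.nn.Module):
    """Stand-in for the wrapped vLLM KV-cache module (self.orig_mod)."""

    def forward(self, qinput, *args, **kwargs):
        # Pretend the quantized entries were written to the paged cache.
        return qinput

    def fetch_from_cache(self, quant_cache, blocks):
        # Gather the requested blocks from the (already quantized) cache.
        return quant_cache[blocks]


class ToyPatchedVLLMKVCache(torch.nn.Module):
    def __init__(self, lp_dtype=torch.bfloat16, scale=0.5):
        super().__init__()
        self.orig_mod = _ToyKVCache()
        self.lp_dtype = lp_dtype   # low-precision cache dtype (FP8 in the real module)
        self.scale = scale

    def quant_input(self, x):      # mock of the configured input quantizer
        return (x / self.scale).to(self.lp_dtype)

    def dequant_output(self, x):   # mock of the configured output dequantizer
        return x.to(torch.float32) * self.scale

    def forward_quant(self, input, *args, **kwargs):
        # Restored path: quantize, run the original cache op, dequantize the result.
        qinput = self.quant_input(input)
        output_cache = self.orig_mod(qinput, *args, **kwargs)
        return self.dequant_output(output_cache)

    def fetch_from_cache(self, cache, blocks, permutations=None):
        # Dtype-guarded workaround: only quantize when the cache is not
        # already stored in the low-precision dtype.
        if cache.dtype != self.lp_dtype:
            quant_cache = self.quant_input(cache)
        else:
            quant_cache = cache
        output_cache = self.orig_mod.fetch_from_cache(quant_cache, blocks)
        return self.dequant_output(output_cache)


kv = ToyPatchedVLLMKVCache()
cache_hp = torch.randn(8, 4)            # high-precision cache -> quantized once
cache_lp = kv.quant_input(cache_hp)     # already low-precision -> quantization skipped
blocks = torch.tensor([0, 3])
print(kv.fetch_from_cache(cache_hp, blocks).dtype)   # torch.float32 after dequant
print(kv.fetch_from_cache(cache_lp, blocks).dtype)   # torch.float32 after dequant

The point of the dtype guard is simply that a cache tensor which already arrives in lp_dtype is passed through untouched instead of being quantized a second time; the rest of the patched fetch_from_cache follows the pre-existing logic.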