Skip to content

Commit eded411

Browse files
committed
merge
2 parents 51d0018 + 352c7a0 commit eded411

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

src/cpp/src/continuous_batching/cache_manager.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ class CacheManager {
4646
m_device = execution_devices[0];
4747
// set block_size depending on device
4848
const size_t cpu_block_size = 32, gpu_block_size = 16, gpu_block_size_xattn = 256;
49-
bool has_xattention = false;
5049

5150
if (all_gpu_device) {
5251
m_context = m_request.get_compiled_model().get_context();
@@ -61,11 +60,6 @@ class CacheManager {
6160
if (name.find("key_cache.") == 0) {
6261
pshape = input.get_partial_shape();
6362
m_block_size_in_bytes += pshape[1].get_length() * pshape[2].get_length() * pshape[3].get_length() * cache_precision.size();
64-
if (pshape[2].get_length() == 256 && pshape[3].get_length() != 16) {
65-
// use xattention layout
66-
// TODO: better check condition ?
67-
has_xattention = true;
68-
}
6963
m_key_shapes.push_back(pshape);
7064
m_key_precisions.push_back(cache_precision);
7165
break;
@@ -79,6 +73,13 @@ class CacheManager {
7973
}
8074
}
8175
}
76+
77+
bool has_xattention = false;
78+
if ((m_key_shapes[0][2].get_length() == m_value_shapes[0][2].get_length()) &&
79+
(m_key_shapes[0][3].get_length() == m_value_shapes[0][3].get_length()) &&
80+
(m_key_shapes[0][2].get_length() == gpu_block_size_xattn)) {
81+
has_xattention = true;
82+
}
8283
m_block_size = all_gpu_device ? ( has_xattention ? gpu_block_size_xattn : gpu_block_size ) : cpu_block_size;
8384
m_num_decoder_layers = m_value_precisions.size();
8485
OPENVINO_ASSERT(m_num_decoder_layers == m_key_precisions.size(), "Invalid case: a different number of K and V caches in a LLM model");

0 commit comments

Comments
 (0)