@@ -46,7 +46,6 @@ class CacheManager {
         m_device = execution_devices[0];
         // set block_size depending on device
         const size_t cpu_block_size = 32, gpu_block_size = 16, gpu_block_size_xattn = 256;
-        bool has_xattention = false;
 
         if (all_gpu_device) {
             m_context = m_request.get_compiled_model().get_context();
@@ -61,11 +60,6 @@ class CacheManager {
                 if (name.find("key_cache.") == 0) {
                     pshape = input.get_partial_shape();
                     m_block_size_in_bytes += pshape[1].get_length() * pshape[2].get_length() * pshape[3].get_length() * cache_precision.size();
-                    if (pshape[2].get_length() == 256 && pshape[3].get_length() != 16) {
-                        // use xattention layout
-                        // TODO: better check condition ?
-                        has_xattention = true;
-                    }
                     m_key_shapes.push_back(pshape);
                     m_key_precisions.push_back(cache_precision);
                     break;
@@ -79,7 +73,15 @@
                 }
             }
         }
+
+        bool has_xattention = false;
+        if ((m_key_shapes[0][2].get_length() == m_value_shapes[0][2].get_length()) &&
+            (m_key_shapes[0][3].get_length() == m_value_shapes[0][3].get_length()) &&
+            (m_key_shapes[0][2].get_length() == gpu_block_size_xattn)) {
+            has_xattention = true;
+        }
         m_block_size = all_gpu_device ? (has_xattention ? gpu_block_size_xattn : gpu_block_size) : cpu_block_size;
+        printf("has_xattention is %d, m_block_size is %zu.\n", has_xattention, m_block_size);
         m_num_decoder_layers = m_value_precisions.size();
         OPENVINO_ASSERT(m_num_decoder_layers == m_key_precisions.size(), "Invalid case: a different number of K and V caches in a LLM model");
     }
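For reference, the relocated detection reduces to a shape comparison that only becomes possible once both the key and value cache shapes have been collected. A minimal standalone sketch of that check, assuming the KV-cache layout [num_blocks, num_heads, block_size, head_size]; Shape is a hypothetical stand-in for ov::PartialShape and the values in main are illustrative:

#include <cstdio>
#include <vector>

// Hypothetical stand-in for the ov::PartialShape entries kept in
// m_key_shapes / m_value_shapes.
using Shape = std::vector<long>;

// Mirrors the moved check: the xattention layout is assumed when the key and
// value caches agree on dims 2 and 3, and dim 2 equals gpu_block_size_xattn.
static bool detect_xattention(const Shape& key0, const Shape& value0,
                              long gpu_block_size_xattn = 256) {
    return key0[2] == value0[2] &&
           key0[3] == value0[3] &&
           key0[2] == gpu_block_size_xattn;
}

int main() {
    Shape key{0, 8, 256, 64};    // dim 2 matches the xattention block size
    Shape value{0, 8, 256, 64};
    std::printf("has_xattention is %d\n", detect_xattention(key, value) ? 1 : 0);
    return 0;
}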