@@ -1232,8 +1232,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
1232
1232
std::cerr << " ggml_vulkan: Compiling shaders" ;
1233
1233
1234
1234
// mulmat
1235
- std::vector<uint32_t > l_warptile, m_warptile, s_warptile, l_warptile_mmq, m_warptile_mmq, s_warptile_mmq;
1236
- std::array<uint32_t , 3 > l_wg_denoms, m_wg_denoms, s_wg_denoms;
1235
+ std::vector<uint32_t > l_warptile, m_warptile, s_warptile,
1236
+ l_warptile_mmq, m_warptile_mmq, s_warptile_mmq;
1237
+ std::array<uint32_t , 3 > l_wg_denoms, m_wg_denoms, s_wg_denoms,
1238
+ l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms;
1237
1239
uint32_t l_align, m_align, s_align;
1238
1240
1239
1241
l_warptile = { 128 , 128 , 128 , 16 , device->subgroup_size * 2 , 64 , 2 , 4 , 4 , device->subgroup_size };
@@ -1244,14 +1246,48 @@ static void ggml_vk_load_shaders(vk_device& device) {
1244
1246
m_warptile_mmq = { 128 , 64 , 64 , 32 , device->subgroup_size , 32 , 2 , 4 , 2 , device->subgroup_size };
1245
1247
s_warptile_mmq = { std::max (device->subgroup_size , 16u ), 32 , 32 , 32 , 32 , 32 , 2 , 2 , 2 , device->subgroup_size };
1246
1248
1247
- l_wg_denoms = {128 , 128 , 1 };
1248
- m_wg_denoms = { 64 , 64 , 1 };
1249
- s_wg_denoms = { 32 , 32 , 1 };
1249
+ l_mmq_wg_denoms = l_wg_denoms = {128 , 128 , 1 };
1250
+ m_mmq_wg_denoms = m_wg_denoms = { 64 , 64 , 1 };
1251
+ s_mmq_wg_denoms = s_wg_denoms = { 32 , 32 , 1 };
1250
1252
1251
1253
l_align = 128 ;
1252
1254
m_align = 64 ;
1253
1255
s_align = 32 ;
1254
1256
1257
+ // Fall back to smaller sizes if there's not enough shared memory. Given the current shaders
1258
+ // and tile sizes, this should handle 16KB, 32KB, and 48KB+.
1259
+ // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
1260
+ // But with 32KB of shared memory the numbers happen to work out: when using the medium
1261
+ // size there's enough room for everything, and we assert this below.
1262
+ uint32_t shmem_needed = (l_warptile[1 ] + l_warptile[2 ]) * (l_warptile[3 ] + 1 ) * sizeof (float );
1263
+ if (shmem_needed > device->properties .limits .maxComputeSharedMemorySize ) {
1264
+ l_warptile = m_warptile;
1265
+ l_wg_denoms = m_wg_denoms;
1266
+ shmem_needed = (l_warptile[1 ] + l_warptile[2 ]) * (l_warptile[3 ] + 1 ) * sizeof (float );
1267
+ GGML_ASSERT (shmem_needed <= device->properties .limits .maxComputeSharedMemorySize );
1268
+ }
1269
+ if (device->properties .limits .maxComputeSharedMemorySize >= 32768 ) {
1270
+ // assert mul_mat_mat_id shaders will fit.
1271
+ GGML_ASSERT (shmem_needed + 3072 *4 <= device->properties .limits .maxComputeSharedMemorySize );
1272
+ }
1273
+
1274
+ shmem_needed = (l_warptile_mmq[1 ] + l_warptile_mmq[2 ]) * (l_warptile_mmq[3 ] + 1 ) * sizeof (float );
1275
+ if (shmem_needed > device->properties .limits .maxComputeSharedMemorySize ) {
1276
+ if (device->properties .limits .maxComputeSharedMemorySize == 32768 ) {
1277
+ l_warptile_mmq = m_warptile_mmq;
1278
+ l_mmq_wg_denoms = m_mmq_wg_denoms;
1279
+ } else {
1280
+ l_warptile_mmq = s_warptile_mmq;
1281
+ l_mmq_wg_denoms = s_mmq_wg_denoms;
1282
+ }
1283
+ shmem_needed = (l_warptile_mmq[1 ] + l_warptile_mmq[2 ]) * (l_warptile_mmq[3 ] + 1 ) * sizeof (float );
1284
+ GGML_ASSERT (shmem_needed <= device->properties .limits .maxComputeSharedMemorySize );
1285
+ }
1286
+ if (device->properties .limits .maxComputeSharedMemorySize >= 32768 ) {
1287
+ // assert mul_mat_mat_id shaders will fit.
1288
+ GGML_ASSERT (shmem_needed + 3072 *4 <= device->properties .limits .maxComputeSharedMemorySize );
1289
+ }
1290
+
1255
1291
device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1256
1292
device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
1257
1293
@@ -1299,35 +1335,38 @@ static void ggml_vk_load_shaders(vk_device& device) {
1299
1335
CREATE_MM (pipeline_matmul_f16.f32acc , matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 );
1300
1336
CREATE_MM (pipeline_matmul_f16_f32.f32acc , matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 );
1301
1337
1302
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc , matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1303
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc , matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1304
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc , matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1305
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc , matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1306
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc , matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1307
-
1308
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc , matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1309
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc , matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1310
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc , matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1311
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc , matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1312
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc , matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1313
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc , matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1314
-
1315
- CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1316
- CREATE_MM (pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1317
- CREATE_MM (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1318
-
1319
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1320
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1321
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1322
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1323
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1324
-
1325
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1326
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1327
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1328
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1329
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1330
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1338
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc , matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1339
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc , matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1340
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc , matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1341
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc , matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1342
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc , matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1343
+
1344
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc , matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1345
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc , matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1346
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc , matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1347
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc , matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1348
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc , matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1349
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc , matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1350
+
1351
+ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1352
+ if (device->properties .limits .maxComputeSharedMemorySize >= 32768 ) {
1353
+ CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1354
+ CREATE_MM (pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1355
+ CREATE_MM (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1356
+
1357
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1358
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1359
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1360
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1361
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1362
+
1363
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1364
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1365
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1366
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1367
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1368
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1369
+ }
1331
1370
#undef CREATE_MM
1332
1371
} else {
1333
1372
// Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -1344,35 +1383,38 @@ static void ggml_vk_load_shaders(vk_device& device) {
1344
1383
CREATE_MM (pipeline_matmul_f16.f32acc , matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 );
1345
1384
CREATE_MM (pipeline_matmul_f16_f32.f32acc , matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3 );
1346
1385
1347
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc , matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1348
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc , matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1349
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc , matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1350
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc , matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1351
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc , matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1352
-
1353
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc , matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1354
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc , matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1355
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc , matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1356
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc , matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1357
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc , matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1358
- CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc , matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1359
-
1360
- CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1361
- CREATE_MM (pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1362
- CREATE_MM (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1363
-
1364
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1365
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1366
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1367
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1368
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1369
-
1370
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1371
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1372
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1373
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1374
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1375
- CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1386
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc , matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1387
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc , matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1388
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc , matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1389
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc , matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1390
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc , matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1391
+
1392
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc , matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1393
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc , matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1394
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc , matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1395
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc , matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1396
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc , matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1397
+ CREATE_MM (pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc , matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3 );
1398
+
1399
+ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1400
+ if (device->properties .limits .maxComputeSharedMemorySize >= 32768 ) {
1401
+ CREATE_MM (pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1402
+ CREATE_MM (pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1403
+ CREATE_MM (pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4 );
1404
+
1405
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1406
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1407
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1408
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1409
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1410
+
1411
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1412
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1413
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1414
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1415
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1416
+ CREATE_MM (pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4 );
1417
+ }
1376
1418
#undef CREATE_MM
1377
1419
}
1378
1420
@@ -6541,6 +6583,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
6541
6583
case GGML_OP_MUL_MAT:
6542
6584
case GGML_OP_MUL_MAT_ID:
6543
6585
{
6586
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context ;
6587
+ if (op->op == GGML_OP_MUL_MAT_ID &&
6588
+ ggml_vk_get_device (ctx->device )->properties .limits .maxComputeSharedMemorySize < 32768 ) {
6589
+ // If there's not enough shared memory for row_ids and the result tile, fall back to CPU
6590
+ return false ;
6591
+ }
6544
6592
switch (op->src [0 ]->type ) {
6545
6593
case GGML_TYPE_F32:
6546
6594
case GGML_TYPE_F16:
0 commit comments