|
169 | 169 | GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
170 | 170 | GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
171 | 171 | GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
172 |
| - GGML_METAL_KERNEL_TYPE_ALIBI_F32, |
173 | 172 | GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
174 | 173 | GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
175 | 174 | GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
@@ -623,7 +622,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
623 | 622 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
624 | 623 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
625 | 624 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
626 |
| - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); |
627 | 625 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
628 | 626 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
629 | 627 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
@@ -759,7 +757,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
759 | 757 | case GGML_OP_GROUP_NORM:
|
760 | 758 | return ctx->support_simdgroup_reduction;
|
761 | 759 | case GGML_OP_NORM:
|
762 |
| - case GGML_OP_ALIBI: |
763 | 760 | case GGML_OP_ROPE:
|
764 | 761 | case GGML_OP_IM2COL:
|
765 | 762 | return true;
|
@@ -1357,13 +1354,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
1357 | 1354 | case GGML_OP_SOFT_MAX:
|
1358 | 1355 | {
|
1359 | 1356 | GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
1360 |
| - GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); |
1361 | 1357 |
|
1362 | 1358 | int nth = 32; // SIMD width
|
1363 | 1359 |
|
1364 | 1360 | id<MTLComputePipelineState> pipeline = nil;
|
1365 | 1361 |
|
1366 |
| - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16); |
| 1362 | + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); |
1367 | 1363 |
|
1368 | 1364 | if (ne00%4 == 0) {
|
1369 | 1365 | while (nth < ne00/4 && nth < 256) {
|
@@ -1407,20 +1403,15 @@ static enum ggml_status ggml_metal_graph_compute(
|
1407 | 1403 | } else {
|
1408 | 1404 | [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
1409 | 1405 | }
|
1410 |
| - if (id_src2) { |
1411 |
| - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; |
1412 |
| - } else { |
1413 |
| - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; |
1414 |
| - } |
1415 |
| - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; |
1416 |
| - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4]; |
1417 |
| - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5]; |
1418 |
| - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6]; |
1419 |
| - [encoder setBytes:&scale length:sizeof(scale) atIndex:7]; |
1420 |
| - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8]; |
1421 |
| - [encoder setBytes:&m0 length:sizeof(m0) atIndex:9]; |
1422 |
| - [encoder setBytes:&m1 length:sizeof(m1) atIndex:10]; |
1423 |
| - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11]; |
| 1406 | + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; |
| 1407 | + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; |
| 1408 | + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; |
| 1409 | + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; |
| 1410 | + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; |
| 1411 | + [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; |
| 1412 | + [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; |
| 1413 | + [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; |
| 1414 | + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; |
1424 | 1415 | [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
1425 | 1416 |
|
1426 | 1417 | [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
@@ -2225,49 +2216,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
2225 | 2216 |
|
2226 | 2217 | [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
2227 | 2218 | } break;
|
2228 |
| - case GGML_OP_ALIBI: |
2229 |
| - { |
2230 |
| - GGML_ASSERT((src0t == GGML_TYPE_F32)); |
2231 |
| - |
2232 |
| - const int nth = MIN(1024, ne00); |
2233 |
| - |
2234 |
| - //const int n_past = ((int32_t *) dst->op_params)[0]; |
2235 |
| - const int n_head = ((int32_t *) dst->op_params)[1]; |
2236 |
| - |
2237 |
| - float max_bias; |
2238 |
| - memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); |
2239 |
| - |
2240 |
| - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); |
2241 |
| - const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); |
2242 |
| - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); |
2243 |
| - |
2244 |
| - id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline; |
2245 |
| - |
2246 |
| - [encoder setComputePipelineState:pipeline]; |
2247 |
| - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; |
2248 |
| - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; |
2249 |
| - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; |
2250 |
| - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; |
2251 |
| - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; |
2252 |
| - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; |
2253 |
| - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; |
2254 |
| - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; |
2255 |
| - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; |
2256 |
| - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; |
2257 |
| - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; |
2258 |
| - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; |
2259 |
| - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; |
2260 |
| - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; |
2261 |
| - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; |
2262 |
| - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; |
2263 |
| - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; |
2264 |
| - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; |
2265 |
| - [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; |
2266 |
| - [encoder setBytes:&m1 length:sizeof( float) atIndex:19]; |
2267 |
| - [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20]; |
2268 |
| - |
2269 |
| - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; |
2270 |
| - } break; |
2271 | 2219 | case GGML_OP_ROPE:
|
2272 | 2220 | {
|
2273 | 2221 | GGML_ASSERT(ne10 == ne02);
|
|
0 commit comments