Commit 7fdca33

ggml : full ALiBi support

1 parent d11afd6 · commit 7fdca33

10 files changed: +82 additions, −680 deletions
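In short (a summary inferred from the diff itself — the commit carries no description): the standalone GGML_OP_ALIBI operator is dropped from the CUDA and Metal backends, and ALiBi is instead folded into the soft_max kernels, which now compute softmax(x*scale + slope*mask) with a per-head slope (slope stays 1.0f when max_bias == 0.0f).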

ggml-cuda.cu

Lines changed: 0 additions & 5 deletions

@@ -4,7 +4,6 @@
 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
-#include "ggml-cuda/alibi.cuh"
 #include "ggml-cuda/arange.cuh"
 #include "ggml-cuda/argsort.cuh"
 #include "ggml-cuda/binbcast.cuh"
@@ -2277,9 +2276,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ROPE:
             ggml_cuda_op_rope(ctx, dst);
             break;
-        case GGML_OP_ALIBI:
-            ggml_cuda_op_alibi(ctx, dst);
-            break;
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
@@ -2829,7 +2825,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_ROPE:
-        case GGML_OP_ALIBI:
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:

ggml-cuda/alibi.cu

Lines changed: 0 additions & 63 deletions
This file was deleted.

ggml-cuda/alibi.cuh

Lines changed: 0 additions & 5 deletions
This file was deleted.
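For reference, a minimal C sketch (the helper name is illustrative, not part of this commit) of the per-head slope scheme that the deleted ALiBi kernels used and that the fused soft_max path now receives via the m0/m1/n_head_log2 arguments in the ggml-metal.m hunks below:

```c
#include <math.h>

// ALiBi slope for head h, following the m0/m1/n_heads_log2_floor formulas
// visible in the deleted Metal host code and in kernel_alibi_f32 below:
// the first n_head_log2 heads take successive powers of m0; the remaining
// heads take odd powers of m1.
static float alibi_slope(int h, int n_head, float max_bias) {
    const int   n_head_log2 = 1 << (int) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -max_bias          / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
    return h < n_head_log2 ? powf(m0, (float)(h + 1))
                           : powf(m1, (float)(2*(h - n_head_log2) + 1));
}
```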

ggml-metal.m

Lines changed: 10 additions & 62 deletions

@@ -169,7 +169,6 @@
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
-    GGML_METAL_KERNEL_TYPE_ALIBI_F32,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -623,7 +622,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -759,7 +757,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
     case GGML_OP_GROUP_NORM:
         return ctx->support_simdgroup_reduction;
     case GGML_OP_NORM:
-    case GGML_OP_ALIBI:
    case GGML_OP_ROPE:
    case GGML_OP_IM2COL:
        return true;
@@ -1357,13 +1354,12 @@ static enum ggml_status ggml_metal_graph_compute(
         case GGML_OP_SOFT_MAX:
             {
                 GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
-                GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);

                 int nth = 32; // SIMD width

                 id<MTLComputePipelineState> pipeline = nil;

-                const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+                const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

                 if (ne00%4 == 0) {
                     while (nth < ne00/4 && nth < 256) {
@@ -1407,20 +1403,15 @@ static enum ggml_status ggml_metal_graph_compute(
                 } else {
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                 }
-                if (id_src2) {
-                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-                } else {
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
-                }
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
-                [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
-                [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
-                [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
-                [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
-                [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+                [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
+                [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
+                [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
+                [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

                 [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -2225,49 +2216,6 @@ static enum ggml_status ggml_metal_graph_compute(

                 [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
-        case GGML_OP_ALIBI:
-            {
-                GGML_ASSERT((src0t == GGML_TYPE_F32));
-
-                const int nth = MIN(1024, ne00);
-
-                //const int n_past = ((int32_t *) dst->op_params)[0];
-                const int n_head = ((int32_t *) dst->op_params)[1];
-
-                float max_bias;
-                memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
-                const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-                const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-                const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
-                [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
-                [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
-                [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
-                [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
-                [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
-                [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
-                [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
-                [encoder setBytes:&m0   length:sizeof(   float) atIndex:18];
-                [encoder setBytes:&m1   length:sizeof(   float) atIndex:19];
-                [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
         case GGML_OP_ROPE:
             {
                 GGML_ASSERT(ne10 == ne02);
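With the standalone op gone, callers reach ALiBi through the fused softmax. A hedged usage sketch in C — the post-change ggml_soft_max_ext signature is assumed from the kernel arguments above (scale and max_bias passed alongside the mask), not shown in this diff:

```c
#include "ggml.h"

// Sketch: attention softmax with optional ALiBi via the extended softmax op.
// kq      : attention logits
// kq_mask : additive mask (may be NULL)
// max_bias == 0.0f -> plain masked softmax; > 0.0f -> per-head ALiBi slopes
//                     derived inside the kernel from max_bias and n_head.
struct ggml_tensor * soft_max_with_alibi(
        struct ggml_context * ctx,
        struct ggml_tensor  * kq,
        struct ggml_tensor  * kq_mask,
        float                 kq_scale,
        float                 max_bias) {
    return ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, max_bias);
}
```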

ggml-metal.metal

Lines changed: 6 additions & 64 deletions

@@ -356,7 +356,6 @@ template<typename T>
 kernel void kernel_soft_max(
         device const  char * src0,
         device const  char * src1,
-        device const  char * src2,
         device        char * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
@@ -378,10 +377,9 @@ kernel void kernel_soft_max(

     device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
     device const     T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
-    device const     T * ppos  = src2 != src0 ? (device const T *) src2            : nullptr;
     device       float * pdst  = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

-    float slope = 0.0f;
+    float slope = 1.0f;

     // ALiBi
     if (max_bias > 0.0f) {
@@ -397,7 +395,7 @@ kernel void kernel_soft_max(
     float lmax = -INFINITY;

     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
+        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
     }

     // find the max value in the block
@@ -422,7 +420,7 @@ kernel void kernel_soft_max(
     // parallel sum
     float lsum = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
         lsum += exp_psrc0;
         pdst[i00] = exp_psrc0;
     }
@@ -461,7 +459,6 @@ template<typename T>
 kernel void kernel_soft_max_4(
         device const  char * src0,
         device const  char * src1,
-        device const  char * src2,
         device        char * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
@@ -483,10 +480,9 @@ kernel void kernel_soft_max_4(

     device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
     device const      T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
-    device const      T * ppos  = src2 != src0 ? (device const T *) src2              : nullptr;
     device       float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;

-    float slope = 0.0f;
+    float slope = 1.0f;

     if (max_bias > 0.0f) {
         const int64_t h = i02;
@@ -501,7 +497,7 @@ kernel void kernel_soft_max_4(
     float4 lmax4 = -INFINITY;

     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
     }

     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -527,7 +523,7 @@ kernel void kernel_soft_max_4(
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
@@ -1595,60 +1591,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
     }
 }

-kernel void kernel_alibi_f32(
-        device const float * src0,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        constant     float & m0,
-        constant     float & m1,
-        constant       int & n_heads_log2_floor,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
-    const int64_t k = i3*ne3 + i2;
-
-    float m_k;
-    if (k < n_heads_log2_floor) {
-        m_k = pow(m0, k + 1);
-    } else {
-        m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
-    }
-
-    device       char * dst_row = (device       char *) dst  + i3*nb3   + i2*nb2   + i1*nb1;
-    device const char * src_row = (device       char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
-    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        const float src_v = *(device float *)(src_row + i00*nb00);
-        device float * dst_v = (device float *)(dst_row + i00*nb0);
-        *dst_v = i00 * m_k + src_v;
-    }
-}
-
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
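As a scalar cross-check of the rewritten kernels, a minimal C sketch of what one row of kernel_soft_max now computes (the helper is illustrative, not from the commit; the real kernel parallelizes this across a threadgroup and reduces with SIMD ops):

```c
#include <math.h>
#include <stddef.h>

// One row of the fused softmax: y = softmax(x*scale + slope*mask).
// mask may be NULL; slope is 1.0f unless ALiBi is active (max_bias > 0),
// in which case it is the per-head ALiBi slope.
static void soft_max_row(const float * x, const float * mask, float * y,
                         int n, float scale, float slope) {
    // pass 1: running max for numerical stability
    float max_val = -INFINITY;
    for (int i = 0; i < n; i++) {
        const float v = x[i]*scale + (mask ? slope*mask[i] : 0.0f);
        max_val = fmaxf(max_val, v);
    }
    // pass 2: exponentiate and accumulate the normalizer
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        y[i] = expf((x[i]*scale + (mask ? slope*mask[i] : 0.0f)) - max_val);
        sum += y[i];
    }
    // pass 3: normalize
    const float inv_sum = 1.0f/sum;
    for (int i = 0; i < n; i++) {
        y[i] *= inv_sum;
    }
}
```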
