Skip to content

Commit 9cc51d3

Browse files
committed
optimise GGML_OP_SUM
1 parent 3f750f8 commit 9cc51d3

File tree

2 files changed

+50
-7
lines changed

2 files changed

+50
-7
lines changed

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -866,12 +866,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
866866

867867
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
868868

869+
int nth = 32; // SIMD width
870+
871+
while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
872+
nth *= 2;
873+
}
874+
875+
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
876+
nth = std::min(nth, (int) n);
877+
878+
const int nsg = (nth + 31) / 32;
879+
869880
ggml_metal_encoder_set_pipeline(enc, pipeline);
870881
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
871882
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
872883
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
873884

874-
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
885+
ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
886+
887+
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
875888

876889
return 1;
877890
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 36 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32(
17271727
constant ggml_metal_kargs_sum & args,
17281728
device const float * src0,
17291729
device float * dst,
1730-
ushort tiitg[[thread_index_in_threadgroup]]) {
1730+
threadgroup float * shmem_f32 [[threadgroup(0)]],
1731+
uint3 tgpig[[threadgroup_position_in_grid]],
1732+
ushort3 tpitg[[thread_position_in_threadgroup]],
1733+
ushort sgitg[[simdgroup_index_in_threadgroup]],
1734+
ushort tiisg[[thread_index_in_simdgroup]],
1735+
ushort3 ntg[[threads_per_threadgroup]]) {
17311736

1732-
if (tiitg != 0) {
1737+
if (args.np == 0) {
17331738
return;
17341739
}
17351740

1736-
float acc = 0.0f;
1737-
for (ulong i = 0; i < args.np; ++i) {
1738-
acc += src0[i];
1741+
const uint nsg = (ntg.x + 31) / 32;
1742+
1743+
float sumf = 0;
1744+
1745+
for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
1746+
sumf += src0[i0];
17391747
}
17401748

1741-
dst[0] = acc;
1749+
sumf = simd_sum(sumf);
1750+
1751+
if (tiisg == 0) {
1752+
shmem_f32[sgitg] = sumf;
1753+
}
1754+
1755+
threadgroup_barrier(mem_flags::mem_threadgroup);
1756+
1757+
float total = 0;
1758+
1759+
if (sgitg == 0) {
1760+
float v = 0;
1761+
1762+
if (tpitg.x < nsg) {
1763+
v = shmem_f32[tpitg.x];
1764+
}
1765+
1766+
total = simd_sum(v);
1767+
1768+
if (tpitg.x == 0) {
1769+
dst[0] = total;
1770+
}
1771+
}
17421772
}
17431773

17441774
template <bool norm>

0 commit comments

Comments
 (0)