Skip to content

Commit c1079ff

Browse files
qnixsynapsepockers21
authored and
pockers21
committed
SYCL: Refactor and enable FP16 in binary broadcast OPs (ggml-org#12975)
* SYCL: refactor move to a separate file * Fix binbcast * Remove duplicates * fix include formatting * fix typo
1 parent 5ffc9a2 commit c1079ff

File tree

7 files changed

+393
-372
lines changed

7 files changed

+393
-372
lines changed

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef GGML_SYCL_BACKEND_HPP
1414
#define GGML_SYCL_BACKEND_HPP
1515

16+
#include "binbcast.hpp"
1617
#include "concat.hpp"
1718
#include "common.hpp"
1819
#include "conv.hpp"

ggml/src/ggml-sycl/binbcast.cpp

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
#include "binbcast.hpp"
2+
3+
#include <cstddef>
4+
#include <cstdint>
5+
#include <sycl/sycl.hpp>
6+
7+
#include "ggml.h"
8+
9+
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
10+
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
11+
int ne0, int ne1, int ne2, int ne3,
12+
int ne10, int ne11, int ne12, int ne13,
13+
/*int s0, */ int s1, int s2, int s3,
14+
/*int s00,*/ int s01, int s02, int s03,
15+
/*int s10,*/ int s11, int s12, int s13,
16+
const sycl::nd_item<3> &item_ct1) {
17+
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
18+
item_ct1.get_local_id(2);
19+
const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
20+
item_ct1.get_local_id(1));
21+
const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
22+
item_ct1.get_local_id(0)) /
23+
ne3;
24+
const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
25+
item_ct1.get_local_id(0)) %
26+
ne3;
27+
28+
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
29+
return;
30+
}
31+
32+
const int i11 = i1 % ne11;
33+
const int i12 = i2 % ne12;
34+
const int i13 = i3 % ne13;
35+
36+
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
37+
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
38+
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
39+
40+
const src0_t * src0_row = src0 + i_src0;
41+
const src1_t * src1_row = src1 + i_src1;
42+
dst_t * dst_row = dst + i_dst;
43+
44+
for (int i0 = i0s; i0 < ne0;
45+
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
46+
const int i10 = i0 % ne10;
47+
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
48+
}
49+
}
50+
51+
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
52+
static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
53+
int ne0, int ne1, int ne2, int ne3,
54+
int ne10, int ne11, int ne12, int ne13,
55+
/*int s0, */ int s1, int s2, int s3,
56+
/*int s00,*/ int s01, int s02, int s03,
57+
/*int s10,*/ int s11, int s12, int s13,
58+
const sycl::nd_item<3> &item_ct1) {
59+
60+
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
61+
item_ct1.get_local_id(2);
62+
63+
const int i3 = i/(ne2*ne1*ne0);
64+
const int i2 = (i/(ne1*ne0)) % ne2;
65+
const int i1 = (i/ne0) % ne1;
66+
const int i0 = i % ne0;
67+
68+
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
69+
return;
70+
}
71+
72+
const int i11 = i1 % ne11;
73+
const int i12 = i2 % ne12;
74+
const int i13 = i3 % ne13;
75+
76+
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
77+
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
78+
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
79+
80+
const src0_t * src0_row = src0 + i_src0;
81+
const src1_t * src1_row = src1 + i_src1;
82+
dst_t * dst_row = dst + i_dst;
83+
84+
const int i10 = i0 % ne10;
85+
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
86+
}
87+
88+
89+
template<float (*bin_op)(const float, const float)>
90+
struct bin_bcast_sycl {
91+
template <typename src0_t, typename src1_t, typename dst_t>
92+
void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
93+
const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
94+
const int64_t ne12, const int64_t ne13, const int64_t ne0, const int64_t ne1, const int64_t ne2,
95+
const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03,
96+
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
97+
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
98+
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
99+
int nr0 = ne10 / ne0;
100+
int nr1 = ne11/ne1;
101+
int nr2 = ne12/ne2;
102+
int nr3 = ne13/ne3;
103+
104+
int nr[4] = { nr0, nr1, nr2, nr3 };
105+
106+
// collapse dimensions until first broadcast dimension
107+
int64_t cne[] = {ne0, ne1, ne2, ne3};
108+
int64_t cne0[] = {ne00, ne01, ne02, ne03};
109+
int64_t cne1[] = {ne10, ne11, ne12, ne13};
110+
size_t cnb[] = {nb0, nb1, nb2, nb3};
111+
size_t cnb0[] = {nb00, nb01, nb02, nb03};
112+
size_t cnb1[] = {nb10, nb11, nb12, nb13};
113+
auto collapse = [](int64_t cne[]) {
114+
cne[0] *= cne[1];
115+
cne[1] = cne[2];
116+
cne[2] = cne[3];
117+
cne[3] = 1;
118+
};
119+
120+
auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
121+
cnb[1] *= cne[1];
122+
cnb[2] *= cne[2];
123+
cnb[3] *= cne[3];
124+
};
125+
126+
if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
127+
for (int i = 0; i < 4; i++) {
128+
if (nr[i] != 1) {
129+
break;
130+
}
131+
if (i > 0) {
132+
collapse_nb(cnb, cne);
133+
collapse_nb(cnb0, cne0);
134+
collapse_nb(cnb1, cne1);
135+
collapse(cne);
136+
collapse(cne0);
137+
collapse(cne1);
138+
}
139+
}
140+
}
141+
{
142+
int64_t ne0 = cne[0];
143+
int64_t ne1 = cne[1];
144+
int64_t ne2 = cne[2];
145+
int64_t ne3 = cne[3];
146+
147+
int64_t ne10 = cne1[0];
148+
int64_t ne11 = cne1[1];
149+
int64_t ne12 = cne1[2];
150+
int64_t ne13 = cne1[3];
151+
152+
size_t nb0 = cnb[0];
153+
size_t nb1 = cnb[1];
154+
size_t nb2 = cnb[2];
155+
size_t nb3 = cnb[3];
156+
157+
size_t nb00 = cnb0[0];
158+
size_t nb01 = cnb0[1];
159+
size_t nb02 = cnb0[2];
160+
size_t nb03 = cnb0[3];
161+
162+
size_t nb10 = cnb1[0];
163+
size_t nb11 = cnb1[1];
164+
size_t nb12 = cnb1[2];
165+
size_t nb13 = cnb1[3];
166+
167+
size_t s0 = nb0 / sizeof(dst_t);
168+
size_t s1 = nb1 / sizeof(dst_t);
169+
size_t s2 = nb2 / sizeof(dst_t);
170+
size_t s3 = nb3 / sizeof(dst_t);
171+
172+
size_t s10 = nb10 / sizeof(src1_t);
173+
size_t s11 = nb11 / sizeof(src1_t);
174+
size_t s12 = nb12 / sizeof(src1_t);
175+
size_t s13 = nb13 / sizeof(src1_t);
176+
177+
size_t s00 = nb00 / sizeof(src0_t);
178+
size_t s01 = nb01 / sizeof(src0_t);
179+
size_t s02 = nb02 / sizeof(src0_t);
180+
size_t s03 = nb03 / sizeof(src0_t);
181+
182+
GGML_UNUSED(s00);
183+
184+
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
185+
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
186+
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
187+
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
188+
189+
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
190+
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
191+
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
192+
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
193+
194+
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
195+
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
196+
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
197+
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
198+
199+
GGML_ASSERT(s0 == 1);
200+
GGML_ASSERT(s10 == 1);
201+
202+
const int block_size = 128;
203+
204+
int64_t hne0 = std::max(ne0/2LL, 1LL);
205+
206+
sycl::range<3> block_dims(1, 1, 1);
207+
block_dims[2] = std::min<unsigned int>(hne0, block_size);
208+
block_dims[1] = std::min<unsigned int>(
209+
ne1, block_size / (unsigned int)block_dims[2]);
210+
block_dims[0] = std::min(
211+
std::min<unsigned int>(
212+
ne2 * ne3, block_size / (unsigned int)block_dims[2] /
213+
(unsigned int)block_dims[1]),
214+
64U);
215+
216+
sycl::range<3> block_nums(
217+
(ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
218+
(ne1 + block_dims[1] - 1) / block_dims[1],
219+
(hne0 + block_dims[2] - 1) / block_dims[2]);
220+
221+
if (block_nums[0] > 65535) {
222+
// this is the maximum number of blocks in z direction, fallback to 1D grid kernel
223+
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
224+
{
225+
dpct::has_capability_or_fail(stream->get_device(),
226+
{sycl::aspect::fp16});
227+
228+
stream->parallel_for(
229+
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
230+
sycl::range<3>(1, 1, block_size),
231+
sycl::range<3>(1, 1, block_size)),
232+
[=](sycl::nd_item<3> item_ct1) {
233+
k_bin_bcast_unravel<bin_op>(
234+
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
235+
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
236+
s03, s11, s12, s13, item_ct1);
237+
});
238+
}
239+
} else {
240+
/*
241+
DPCT1049:16: The work-group size passed to the SYCL kernel may
242+
exceed the limit. To get the device limit, query
243+
info::device::max_work_group_size. Adjust the work-group size if
244+
needed.
245+
*/
246+
dpct::has_capability_or_fail(stream->get_device(),
247+
{sycl::aspect::fp16});
248+
249+
stream->parallel_for(
250+
sycl::nd_range<3>(block_nums * block_dims, block_dims),
251+
[=](sycl::nd_item<3> item_ct1) {
252+
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
253+
ne2, ne3, ne10, ne11, ne12, ne13,
254+
s1, s2, s3, s01, s02, s03, s11, s12, s13,
255+
item_ct1);
256+
});
257+
}
258+
}
259+
}
260+
};
261+
262+
template <class op>
263+
inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
264+
ggml_tensor * dst) {
265+
dpct::queue_ptr main_stream = ctx.stream();
266+
GGML_TENSOR_BINARY_OP_LOCALS
267+
268+
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
269+
op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10,
270+
ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3,
271+
ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
272+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
273+
op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01,
274+
ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13,
275+
nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst),
276+
main_stream);
277+
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
278+
op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02,
279+
ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1,
280+
nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
281+
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
282+
op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03,
283+
ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
284+
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
285+
} else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
286+
op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03,
287+
ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
288+
nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
289+
} else {
290+
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
291+
ggml_type_name(src0->type), ggml_type_name(src1->type));
292+
GGML_ABORT("fatal error");
293+
}
294+
}
295+
296+
inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
297+
298+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, dst->src[0], dst->src[1], dst);
299+
}
300+
301+
inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
302+
303+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
304+
}
305+
306+
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
307+
308+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
309+
}
310+
311+
inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
312+
313+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, dst->src[0], dst->src[1], dst);
314+
}
315+
316+
inline void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
317+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst);
318+
}
319+
320+
321+
void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
322+
GGML_SYCL_DEBUG("call %s\n", __func__);
323+
ggml_sycl_op_add(ctx, dst);
324+
GGML_SYCL_DEBUG("call %s done\n", __func__);
325+
}
326+
327+
void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
328+
GGML_SYCL_DEBUG("call %s\n", __func__);
329+
ggml_sycl_op_sub(ctx, dst);
330+
GGML_SYCL_DEBUG("call %s done\n", __func__);
331+
}
332+
333+
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
334+
GGML_SYCL_DEBUG("call %s\n", __func__);
335+
ggml_sycl_op_mul(ctx, dst);
336+
GGML_SYCL_DEBUG("call %s done\n", __func__);
337+
}
338+
339+
void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
340+
GGML_SYCL_DEBUG("call %s\n", __func__);
341+
ggml_sycl_op_div(ctx, dst);
342+
GGML_SYCL_DEBUG("call %s done\n", __func__);
343+
}
344+
345+
void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
346+
GGML_SYCL_DEBUG("call %s\n", __func__);
347+
ggml_sycl_op_repeat(ctx, dst);
348+
GGML_SYCL_DEBUG("call %s done\n", __func__);
349+
}
350+

ggml/src/ggml-sycl/binbcast.hpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#ifndef GGML_SYCL_BINBCAST_HPP
2+
#define GGML_SYCL_BINBCAST_HPP
3+
#include "common.hpp"
4+
5+
6+
static __dpct_inline__ float op_repeat(const float a, const float b) {
7+
return b;
8+
GGML_UNUSED(a);
9+
}
10+
11+
static __dpct_inline__ float op_add(const float a, const float b) {
12+
return a + b;
13+
}
14+
15+
static __dpct_inline__ float op_sub(const float a, const float b) {
16+
return a - b;
17+
}
18+
19+
static __dpct_inline__ float op_mul(const float a, const float b) {
20+
return a * b;
21+
}
22+
23+
static __dpct_inline__ float op_div(const float a, const float b) {
24+
return a / b;
25+
}
26+
27+
void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
28+
29+
void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
30+
31+
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
32+
33+
void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
34+
35+
void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
36+
37+
38+
#endif //GGML_SYCL_BINBCAST_HPP
39+

0 commit comments

Comments
 (0)