Skip to content

Commit 59fc1ec

Browse files
shani-f and CISC authored
sycl: add REPEAT_BACK operation support (#16734)
* SYCL repeat_back v1 — add core op + switch case * Implement repeat_back SYCL operation and minor fixes * Update ggml/src/ggml-sycl/repeat_back.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update ggml/src/ggml-sycl/repeat_back.hpp Co-authored-by: Sigbjørn Skjæret <[email protected]> * Update ggml/src/ggml-sycl/ggml-sycl.cpp Co-authored-by: Sigbjørn Skjæret <[email protected]> --------- Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 75d33b9 commit 59fc1ec

File tree

3 files changed

+77
-0
lines changed

3 files changed

+77
-0
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#include "ggml-sycl/set.hpp"
4949
#include "ggml-sycl/sycl_hw.hpp"
5050
#include "ggml-sycl/getrows.hpp"
51+
#include "ggml-sycl/repeat_back.hpp"
5152
#include "ggml-sycl/quantize.hpp"
5253
#include "ggml.h"
5354

@@ -2615,6 +2616,10 @@ catch (sycl::exception const &exc) {
26152616
std::exit(1);
26162617
}
26172618

2619+
// Backend entry point for GGML_OP_REPEAT_BACK.
// Creates a scope_op_debug_print object for this call (function name, dst, one
// source operand) and then hands dst to ggml_sycl_op_repeat_back.
static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // Lives until the end of this function, i.e. for the duration of the call below.
    scope_op_debug_print dbg_scope(__func__, dst, /*num_src=*/1);

    ggml_sycl_op_repeat_back(ctx, dst);
}
26182623

26192624
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
26202625
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
@@ -3679,6 +3684,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
36793684
case GGML_OP_REPEAT:
36803685
ggml_sycl_repeat(ctx, dst);
36813686
break;
3687+
case GGML_OP_REPEAT_BACK:
3688+
ggml_sycl_repeat_back(ctx, dst);
3689+
break;
36823690
case GGML_OP_GET_ROWS:
36833691
ggml_sycl_get_rows(ctx, dst);
36843692
break;
@@ -4516,6 +4524,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
45164524
ggml_type src0_type = op->src[0]->type;
45174525
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
45184526
}
4527+
case GGML_OP_REPEAT_BACK:
4528+
{
4529+
ggml_type src0_type = op->src[0]->type;
4530+
return src0_type == GGML_TYPE_F32;
4531+
}
45194532
case GGML_OP_DUP:
45204533
case GGML_OP_ARGMAX:
45214534
case GGML_OP_NONE:

ggml/src/ggml-sycl/repeat_back.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#include "repeat_back.hpp"
2+
3+
#include "common.hpp"
4+
5+
// REPEAT_BACK: reduces src0 back to the shape of dst by summing every repeated tile.
//
// For each dst element (i0, i1, i2, i3) the result is the sum of
//   src0[i0 + j0*ne0, i1 + j1*ne1, i2 + j2*ne2, i3 + j3*ne3]
// over all j_d in [0, ne0d / ne_d), where ne_d / ne0d are the extents of dst / src0
// along dimension d.
//
// ctx: backend context; the kernel is submitted to ctx.stream().
// dst: output tensor (F32); dst->src[0] is the input tensor (F32).
//
// Both tensors are addressed through their byte strides (nb[]), so the input does
// not need to be densely packed. The kernel is only enqueued here; this function
// does not wait for it to complete.
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];

    // Nothing to compute for an empty output. Bailing out here also keeps the
    // ne0X / neX divisions below from dividing by zero.
    const int64_t total = ne0 * ne1 * ne2 * ne3;
    if (total <= 0) {
        return;
    }

    const int64_t ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];

    // Number of tiles along each dimension that fold into one dst element.
    const int64_t nr0 = ne00 / ne0;
    const int64_t nr1 = ne01 / ne1;
    const int64_t nr2 = ne02 / ne2;
    const int64_t nr3 = ne03 / ne3;

    // Byte strides of the output and the input.
    const size_t nb0 = dst->nb[0], nb1 = dst->nb[1], nb2 = dst->nb[2], nb3 = dst->nb[3];
    const size_t nb00 = src0->nb[0], nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];

    const char * src0_base = (const char *) src0->data;
    char *       dst_base  = (char *) dst->data;

    // Launch geometry is computed in size_t so large tensors cannot overflow int.
    constexpr size_t BLOCK_SIZE = 256;
    const size_t     num_blocks = ((size_t) total + BLOCK_SIZE - 1) / BLOCK_SIZE;

    queue_ptr stream = ctx.stream();

    stream->parallel_for(
        sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
        [=](sycl::nd_item<1> item_ct1) {
            // One work-item per dst element; the padded tail of the last group exits.
            const int64_t i = (int64_t) item_ct1.get_global_linear_id();
            if (i >= total) {
                return;
            }

            // Unflatten the linear id into dst coordinates (dimension 0 fastest).
            const int64_t i0 = i % ne0;
            const int64_t i1 = (i / ne0) % ne1;
            const int64_t i2 = (i / (ne0 * ne1)) % ne2;
            const int64_t i3 = i / (ne0 * ne1 * ne2);

            float acc = 0.0f;

            for (int64_t j3 = 0; j3 < nr3; ++j3) {
                const char * p3 = src0_base + (i3 + j3 * ne3) * nb03;
                for (int64_t j2 = 0; j2 < nr2; ++j2) {
                    const char * p2 = p3 + (i2 + j2 * ne2) * nb02;
                    for (int64_t j1 = 0; j1 < nr1; ++j1) {
                        const char * p1 = p2 + (i1 + j1 * ne1) * nb01;
                        for (int64_t j0 = 0; j0 < nr0; ++j0) {
                            acc += *(const float *) (p1 + (i0 + j0 * ne0) * nb00);
                        }
                    }
                }
            }

            *(float *) (dst_base + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3) = acc;
        });
}

ggml/src/ggml-sycl/repeat_back.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef GGML_SYCL_REPEAT_BACK_HPP
2+
#define GGML_SYCL_REPEAT_BACK_HPP
3+
4+
#include "common.hpp"
5+
6+
// SYCL implementation of GGML_OP_REPEAT_BACK.
// Writes to each element of dst the sum of the elements of dst->src[0] that a
// repeat of dst would have placed at that position (one term per repeat tile).
// Both dst and dst->src[0] must be GGML_TYPE_F32 (asserted by the implementation).
// The kernel is submitted to the queue returned by ctx.stream().
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
7+
8+
#endif // GGML_SYCL_REPEAT_BACK_HPP

0 commit comments

Comments
 (0)