Skip to content

Commit a4b9665

Browse files
committed
SYCL repeat_back v1 — add core op + switch case
1 parent 7a50cf3 commit a4b9665

File tree

3 files changed

+85
-0
lines changed

3 files changed

+85
-0
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "ggml-sycl/set_rows.hpp"
4545
#include "ggml-sycl/sycl_hw.hpp"
4646
#include "ggml-sycl/getrows.hpp"
47+
#include "ggml-sycl/repeat_back.hpp"
4748
#include "ggml-sycl/quantize.hpp"
4849
#include "ggml.h"
4950

@@ -2597,6 +2598,10 @@ catch (sycl::exception const &exc) {
25972598
std::exit(1);
25982599
}
25992600

2601+
static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2602+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
2603+
ggml_sycl_op_repeat_back(ctx, dst);
2604+
}
26002605

26012606
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
26022607
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
@@ -3616,6 +3621,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
36163621
case GGML_OP_REPEAT:
36173622
ggml_sycl_repeat(ctx, dst);
36183623
break;
3624+
case GGML_OP_REPEAT_BACK:
3625+
ggml_sycl_repeat_back(ctx, dst);
3626+
break;
36193627
case GGML_OP_GET_ROWS:
36203628
ggml_sycl_get_rows(ctx, dst);
36213629
break;
@@ -4405,11 +4413,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
44054413
}
44064414
return false;
44074415
}
4416+
44084417
case GGML_OP_CONCAT:
44094418
{
44104419
ggml_type src0_type = op->src[0]->type;
44114420
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
44124421
}
4422+
case GGML_OP_REPEAT_BACK:
4423+
{
4424+
ggml_type src0_type = op->src[0]->type;
4425+
return src0_type == GGML_TYPE_F32;
4426+
}
44134427
case GGML_OP_DUP:
44144428
case GGML_OP_ARGMAX:
44154429
case GGML_OP_NONE:

ggml/src/ggml-sycl/repeat_back.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//בס"ד
2+
#include "repeat_back.hpp"
3+
4+
#include "common.hpp"
5+
6+
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
7+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
8+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
9+
10+
const float * src0_dd = (const float *) dst->src[0]->data;
11+
float * dst_dd = (float *) dst->data;
12+
13+
const int64_t ne0 = dst->ne[0];
14+
const int64_t ne1 = dst->ne[1];
15+
const int64_t ne2 = dst->ne[2];
16+
const int64_t ne3 = dst->ne[3];
17+
const int64_t ne00 = dst->src[0]->ne[0];
18+
const int64_t ne01 = dst->src[0]->ne[1];
19+
const int64_t ne02 = dst->src[0]->ne[2];
20+
const int64_t ne03 = dst->src[0]->ne[3];
21+
22+
const int nr0 = (int) (ne00 / ne0);
23+
const int nr1 = (int) (ne01 / ne1);
24+
const int nr2 = (int) (ne02 / ne2);
25+
const int nr3 = (int) (ne03 / ne3);
26+
27+
const size_t total = ne0 * ne1 * ne2 * ne3;
28+
const int BLOCK_SIZE = 256;
29+
const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
30+
31+
queue_ptr stream = ctx.stream();
32+
stream->memset(dst_dd, 0, total * sizeof(float));
33+
34+
stream->parallel_for(
35+
sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
36+
[=](sycl::nd_item<1> item_ct1) {
37+
const size_t i = item_ct1.get_global_linear_id();
38+
if (i >= total) {
39+
return;
40+
}
41+
42+
const int i0 = i % ne0;
43+
const int i1 = (i / ne0) % ne1;
44+
const int i2 = (i / (ne0 * ne1)) % ne2;
45+
const int i3 = i / (ne0 * ne1 * ne2);
46+
47+
float acc = 0.0f;
48+
49+
for (int j3 = 0; j3 < nr3; ++j3) {
50+
for (int j2 = 0; j2 < nr2; ++j2) {
51+
for (int j1 = 0; j1 < nr1; ++j1) {
52+
for (int j0 = 0; j0 < nr0; ++j0) {
53+
acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 +
54+
(i3 + j3 * ne3) * ne00 * ne01 * ne02];
55+
}
56+
}
57+
}
58+
}
59+
60+
dst_dd[i] = acc;
61+
});
62+
}

ggml/src/ggml-sycl/repeat_back.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//בס"ד
2+
#ifndef GGML_SYCL_REPEAT_BACK_HPP
3+
#define GGML_SYCL_REPEAT_BACK_HPP
4+
5+
#include "common.hpp"
6+
7+
void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8+
9+
#endif // GGML_SYCL_REPEAT_BACK_HPP

0 commit comments

Comments
 (0)