Skip to content

Commit 0832de7

Browse files
authored
[SYCL] add conv support (#8688)
1 parent 6eeaeba commit 0832de7

File tree

5 files changed

+134
-0
lines changed

5 files changed

+134
-0
lines changed

ggml/src/ggml-sycl.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3981,6 +3981,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
39813981
ggml_sycl_func_t func;
39823982

39833983
switch (tensor->op) {
3984+
case GGML_OP_CONV_TRANSPOSE_1D:
3985+
func = ggml_sycl_op_conv_transpose_1d;
3986+
break;
39843987
case GGML_OP_REPEAT:
39853988
func = ggml_sycl_repeat;
39863989
break;
@@ -5090,6 +5093,15 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back
50905093

50915094
GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
50925095
switch (op->op) {
5096+
case GGML_OP_CONV_TRANSPOSE_1D:
5097+
{
5098+
ggml_type src0_type = op->src[0]->type;
5099+
ggml_type src1_type = op->src[1]->type;
5100+
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
5101+
return true;
5102+
}
5103+
return false;
5104+
} break;
50935105
case GGML_OP_UNARY:
50945106
switch (ggml_get_unary_op(op)) {
50955107
case GGML_UNARY_OP_GELU:

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "concat.hpp"
1717
#include "common.hpp"
18+
#include "conv.hpp"
1819
#include "convert.hpp"
1920
#include "dequantize.hpp"
2021
#include "dmmv.hpp"

ggml/src/ggml-sycl/conv.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
//
2+
// MIT license
3+
// Copyright (C) 2024 Intel Corporation
4+
// SPDX-License-Identifier: MIT
5+
//
6+
7+
//
8+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9+
// See https://llvm.org/LICENSE.txt for license information.
10+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11+
//
12+
13+
#include "conv.hpp"
14+
15+
static void conv_transpose_1d_kernel(
16+
const int s0, const int output_size,
17+
const int src0_ne0, const int src0_ne1, const int src0_ne2,
18+
const int src1_ne0, const int dst_ne0,
19+
const float * src0, const float * src1, float * dst,
20+
const sycl::nd_item<3> &item_ct1) {
21+
int global_index = item_ct1.get_local_id(2) +
22+
item_ct1.get_group(2) * item_ct1.get_local_range(2);
23+
if (global_index >= output_size) {
24+
return;
25+
}
26+
27+
int out_index = global_index / dst_ne0;
28+
29+
float accumulator = 0;
30+
31+
for (int c = 0; c < src0_ne2; c++) {
32+
int idx = global_index % dst_ne0;
33+
34+
int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
35+
int input_offset = src1_ne0 * c;
36+
37+
for (int i = 0; i < src1_ne0; i++) {
38+
if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
39+
continue;
40+
}
41+
int weight_idx = idx - i*s0;
42+
43+
float kernel_weight = src0[kernel_offset + weight_idx];
44+
float input_value = src1[input_offset+i];
45+
46+
accumulator += kernel_weight * input_value;
47+
}
48+
}
49+
dst[global_index] = accumulator;
50+
}
51+
52+
static void conv_transpose_1d_f32_f32_sycl(
53+
const int s0, const int output_size,
54+
const int src0_ne0, const int src0_ne1, const int src0_ne2,
55+
const int src1_ne0, const int dst_ne0,
56+
const float *src0, const float *src1, float *dst,
57+
const queue_ptr& stream) {
58+
59+
const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
60+
const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
61+
const sycl::range<3> block_nums(1, 1, num_blocks);
62+
stream->parallel_for(
63+
sycl::nd_range<3>(
64+
block_nums * block_dims, block_dims),
65+
[=](sycl::nd_item<3> item_ct1) {
66+
conv_transpose_1d_kernel(
67+
s0, output_size,
68+
src0_ne0, src0_ne1, src0_ne2,
69+
src1_ne0, dst_ne0,
70+
src0, src1, dst, item_ct1);
71+
});
72+
}
73+
74+
// SYCL implementation of GGML_OP_CONV_TRANSPOSE_1D for F32 tensors.
//
//   ctx   backend context; supplies the queue the kernel is submitted to
//   src0  first operand, read by the kernel as the convolution weights
//   src1  second operand, read by the kernel as the input values
//   dst   destination tensor; its first op_params entry is passed on as s0
//
// All three tensors must be F32, and src0/src1 must be contiguous because the
// kernel addresses them with flat indices computed from their extents.
void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                    const ggml_tensor *src1, ggml_tensor *dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    // src1 is reinterpreted as float below, so its type has to be checked too.
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;

    float * dst_d = (float *)dst->data;
    dpct::queue_ptr stream = ctx.stream();

    const int32_t * opts = (const int32_t *)dst->op_params;

    const int s0 = opts[0];

    const int64_t output_size = ggml_nelements(dst);
    // The launcher and kernel index with `int`; refuse sizes that would be
    // silently truncated instead of computing a wrong result.
    GGML_ASSERT(output_size <= INT32_MAX);

    conv_transpose_1d_f32_f32_sycl(s0, static_cast<int>(output_size),
        static_cast<int>(src0->ne[0]), static_cast<int>(src0->ne[1]), static_cast<int>(src0->ne[2]),
        static_cast<int>(src1->ne[0]), static_cast<int>(dst->ne[0]),
        src0_d, src1_d, dst_d, stream);
}
99+

ggml/src/ggml-sycl/conv.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//
2+
// MIT license
3+
// Copyright (C) 2024 Intel Corporation
4+
// SPDX-License-Identifier: MIT
5+
//
6+
7+
//
8+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
9+
// See https://llvm.org/LICENSE.txt for license information.
10+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
11+
//
12+
13+
#ifndef GGML_SYCL_CONV_HPP
14+
#define GGML_SYCL_CONV_HPP
15+
16+
#include "common.hpp"
17+
18+
// SYCL backend entry point dispatched for GGML_OP_CONV_TRANSPOSE_1D.
// Reads src0 and src1 as F32 data and writes the result into dst.
void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
19+
const ggml_tensor *src1, ggml_tensor *dst);
20+
21+
#endif // GGML_SYCL_CONV_HPP

ggml/src/ggml-sycl/presets.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#define SYCL_ACC_BLOCK_SIZE 256
4242
#define SYCL_IM2COL_BLOCK_SIZE 256
4343
#define SYCL_POOL2D_BLOCK_SIZE 256
44+
#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
4445

4546
// dmmv = dequantize_mul_mat_vec
4647
#ifndef GGML_SYCL_DMMV_X

0 commit comments

Comments
 (0)