Skip to content

Commit 0b174ab

Browse files
committed
ggml: CUDA unary op EXP
Signed-off-by: Molly Sophia <[email protected]>
1 parent 7820364 commit 0b174ab

File tree

3 files changed

+35
-0
lines changed

3 files changed

+35
-0
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2225,6 +2225,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
22252225
case GGML_UNARY_OP_HARDSWISH:
22262226
ggml_cuda_op_hardswish(ctx, dst);
22272227
break;
2228+
case GGML_UNARY_OP_EXP:
2229+
ggml_cuda_op_exp(ctx, dst);
2230+
break;
22282231
default:
22292232
return false;
22302233
}
@@ -2769,6 +2772,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
27692772
case GGML_UNARY_OP_HARDSWISH:
27702773
case GGML_UNARY_OP_GELU_QUICK:
27712774
case GGML_UNARY_OP_TANH:
2775+
case GGML_UNARY_OP_EXP:
27722776
return ggml_is_contiguous(op->src[0]);
27732777
default:
27742778
return false;

ggml/src/ggml-cuda/unary.cu

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -85,6 +85,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k)
8585
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
8686
}
8787

88+
// Element-wise exponential kernel: dst[idx] = e^x[idx] for idx in [0, k).
// 1D launch; threads whose global index falls past k do nothing, so any
// grid that covers at least k threads is valid.
static __global__ void exp_f32(const float * x, float * dst, const int k) {
    const int idx = blockIdx.x*blockDim.x + threadIdx.x;

    if (idx < k) {
        dst[idx] = expf(x[idx]);
    }
}
96+
8897
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
8998
const int i = blockDim.x*blockIdx.x + threadIdx.x;
9099
if (i >= k) {
@@ -174,6 +183,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt
174183
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
175184
}
176185

186+
// Host-side launcher for exp_f32: queues the kernel on `stream` with one
// thread per element, rounded up to whole blocks of CUDA_EXP_BLOCK_SIZE.
static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int n_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE; // ceil-div
    exp_f32<<<n_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
190+
177191
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
178192
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
179193
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
@@ -325,6 +339,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
325339
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
326340
}
327341

342+
// CUDA backend implementation of the EXP unary op: dst = exp(src0),
// element-wise in F32. Input must be contiguous; the kernel is queued
// asynchronously on the context's stream (same contract as the other
// unary ops in this file).
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    const float * src_data = (const float *) src0->data;
    float       * dst_data = (float       *) dst->data;

    // NOTE(review): the launcher takes an int element count while
    // ggml_nelements presumably returns a wider type — assumes fewer
    // than 2^31 elements, like the sibling unary ops here. TODO confirm.
    exp_f32_cuda(src_data, dst_data, ggml_nelements(src0), stream);
}
355+
328356
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
329357
const ggml_tensor * src0 = dst->src[0];
330358
const float * src0_d = (const float *)src0->data;

ggml/src/ggml-cuda/unary.cuh

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
#define CUDA_RELU_BLOCK_SIZE 256
88
#define CUDA_SIGMOID_BLOCK_SIZE 256
99
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
10+
#define CUDA_EXP_BLOCK_SIZE 256
1011
#define CUDA_HARDSWISH_BLOCK_SIZE 256
1112
#define CUDA_SQR_BLOCK_SIZE 256
1213
#define CUDA_SQRT_BLOCK_SIZE 256
@@ -29,6 +30,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
2930

3031
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
3132

33+
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
34+
3235
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
3336

3437
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

0 commit comments

Comments
 (0)