Commit f5ef34e

justcho5 authored and ggerganov committed
feat: implemented sigmoid function (ggml/806)
* added sigmoid function
* implemented metal kernel for sigmoid
* implemented cuda kernel for sigmoid
* added sigmoid unary op and incremented count
1 parent ef0d5e3 commit f5ef34e
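
Taken together, the seven files wire one new operator through the whole stack: the logistic sigmoid, σ(x) = 1 / (1 + exp(-x)), applied element-wise to F32 tensors on the CPU, CUDA, and Metal backends, with GGML_UNARY_OP_COUNT bumped from 12 to 13.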

File tree: 7 files changed, +136 -1 lines changed


ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
@@ -2204,6 +2204,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_UNARY_OP_RELU:
             ggml_cuda_op_relu(ctx, dst);
             break;
+        case GGML_UNARY_OP_SIGMOID:
+            ggml_cuda_op_sigmoid(ctx, dst);
+            break;
         case GGML_UNARY_OP_HARDSIGMOID:
             ggml_cuda_op_hardsigmoid(ctx, dst);
             break;
@@ -2716,6 +2719,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_UNARY_OP_GELU:
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_RELU:
+        case GGML_UNARY_OP_SIGMOID:
         case GGML_UNARY_OP_HARDSIGMOID:
         case GGML_UNARY_OP_HARDSWISH:
         case GGML_UNARY_OP_GELU_QUICK:

ggml-cuda/unary.cu

Lines changed: 26 additions & 0 deletions
@@ -48,6 +48,15 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }
 
+static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = 1.0f / (1.0f + expf(-x[i]));
+}
+
 static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
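
A note on the formula as written: 1.0f / (1.0f + expf(-x)) is well-behaved over the whole float range without extra clamping, since for large negative x the expf(-x) term overflows to +inf and the quotient becomes 0, while for large positive x it underflows to 0 and the result saturates at 1.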

@@ -108,6 +117,11 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
+    sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
     hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
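
The wrapper uses the usual ceiling-division grid sizing. As a worked example: with k = 1000 elements and CUDA_SIGMOID_BLOCK_SIZE = 256, (1000 + 255) / 256 = 4 blocks are launched, i.e. 1024 threads, and the i >= k guard in the kernel masks off the 24 surplus ones.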
@@ -188,6 +202,18 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
 
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;

ggml-cuda/unary.cuh

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,7 @@
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_SIGMOID_BLOCK_SIZE 256
 #define CUDA_HARDSIGMOID_BLOCK_SIZE 256
 #define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
@@ -18,6 +19,8 @@ void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml-metal.m

Lines changed: 15 additions & 0 deletions
@@ -40,6 +40,7 @@
     GGML_METAL_KERNEL_TYPE_CLAMP,
     GGML_METAL_KERNEL_TYPE_TANH,
     GGML_METAL_KERNEL_TYPE_RELU,
+    GGML_METAL_KERNEL_TYPE_SIGMOID,
     GGML_METAL_KERNEL_TYPE_GELU,
     GGML_METAL_KERNEL_TYPE_GELU_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
@@ -493,6 +494,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
@@ -730,6 +732,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
@@ -1237,6 +1240,18 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         const int64_t n = ggml_nelements(dst);
 
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_UNARY_OP_SIGMOID:
+                    {
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_UNARY_OP_GELU:
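
The new case follows the template of the surrounding unary ops: bind the SIGMOID pipeline plus the source and destination buffers, then dispatch n = ggml_nelements(dst) threadgroups of a single thread each.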

ggml-metal.metal

Lines changed: 7 additions & 0 deletions
@@ -229,6 +229,13 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }
 
+kernel void kernel_sigmoid(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
+}
+
 kernel void kernel_tanh(
         device const float * src0,
         device       float * dst,
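
Unlike the CUDA kernel, kernel_sigmoid needs no bounds check: the host side dispatches exactly one thread per element, so tpig is always in range.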

ggml.c

Lines changed: 72 additions & 1 deletion
@@ -1949,6 +1949,7 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
+inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
 // TODO: optimize performance
 inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
 inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
@@ -2329,14 +2330,15 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "TANH",
     "ELU",
     "RELU",
+    "SIGMOID",
     "GELU",
     "GELU_QUICK",
     "SILU",
     "HARDSWISH",
     "HARDSIGMOID",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12");
+static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -4561,6 +4563,20 @@ struct ggml_tensor * ggml_leaky_relu(
     return result;
 }
 
+// ggml_sigmoid
+
+struct ggml_tensor * ggml_sigmoid(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
+}
+
+struct ggml_tensor * ggml_sigmoid_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
+}
+
 // ggml_gelu
 
 struct ggml_tensor * ggml_gelu(
@@ -10852,6 +10868,52 @@ static void ggml_compute_forward_relu(
     }
 }
 
+// ggml_compute_forward_sigmoid
+
+static void ggml_compute_forward_sigmoid_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sigmoid_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sigmoid(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sigmoid_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_gelu
 
 static void ggml_compute_forward_gelu_f32(
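
As with the neighboring cheap unary ops, the CPU forward pass is single-threaded (note the assert(params->ith == 0), and the op is grouped with the single-task cases in ggml_get_n_tasks below); it applies ggml_vec_sigmoid_f32 row by row, stepping through memory via the nb[1] stride so only individual rows need to be contiguous.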
@@ -16617,6 +16679,10 @@ static void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_relu(params, dst);
             } break;
+        case GGML_UNARY_OP_SIGMOID:
+            {
+                ggml_compute_forward_sigmoid(params, dst);
+            } break;
         case GGML_UNARY_OP_GELU:
             {
                 ggml_compute_forward_gelu(params, dst);
@@ -18601,6 +18667,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         zero_table);
                 }
             } break;
+        case GGML_UNARY_OP_SIGMOID:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_UNARY_OP_GELU:
             {
                 GGML_ASSERT(false); // TODO: not implemented
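
The backward pass is deliberately stubbed out here. Should it be implemented later, the natural route is the identity dσ/dx = σ(x) * (1 - σ(x)), which would let the gradient be computed from the already-materialized forward output.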
@@ -19130,6 +19200,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
         case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_ELU:
         case GGML_UNARY_OP_RELU:
+        case GGML_UNARY_OP_SIGMOID:
         case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
         case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
             {

ggml.h

Lines changed: 9 additions & 0 deletions
@@ -519,6 +519,7 @@ extern "C" {
         GGML_UNARY_OP_TANH,
         GGML_UNARY_OP_ELU,
         GGML_UNARY_OP_RELU,
+        GGML_UNARY_OP_SIGMOID,
         GGML_UNARY_OP_GELU,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
@@ -1073,6 +1074,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    GGML_API struct ggml_tensor * ggml_sigmoid(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_sigmoid_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
     GGML_API struct ggml_tensor * ggml_gelu(
             struct ggml_context * ctx,
             struct ggml_tensor * a);
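
With the public declarations in place, an end-to-end CPU usage sketch might look like the following. This program is not part of the commit; the context size and thread count are arbitrary illustration choices, and it assumes the graph helpers (ggml_new_graph, ggml_build_forward_expand, ggml_graph_compute_with_ctx) already present in the same tree:

#include <stdio.h>
#include "ggml.h"

int main(void) {
    // small scratch context; 16 MB is an arbitrary size for this sketch
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 4-element F32 input tensor
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    float * xd = (float *) x->data;
    xd[0] = -2.0f; xd[1] = -1.0f; xd[2] = 0.0f; xd[3] = 1.0f;

    // y = sigmoid(x), evaluated on the CPU backend
    struct ggml_tensor * y = ggml_sigmoid(ctx, x);
    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    for (int i = 0; i < 4; ++i) {
        printf("sigmoid(%5.2f) = %.6f\n", xd[i], ((float *) y->data)[i]);
    }

    ggml_free(ctx);
    return 0;
}

Compiled against this revision, it should print sigmoid( 0.00) = 0.500000 and sigmoid( 1.00) as roughly 0.731059, i.e. 1 / (1 + e^(-1)).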
