ggml : sync (abort callback, mul / add broadcast, fix alibi) #2183

Merged
merged 1 commit on Jul 11, 2023
115 changes: 88 additions & 27 deletions ggml-cuda.cu
@@ -239,13 +239,13 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
 };

-static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
+static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;

-    if (i >= k) {
+    if (i >= kx) {
         return;
     }
-    dst[i] = x[i] + y[i];
+    dst[i] = x[i] + y[i%ky];
 }

 static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
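Note (not part of the diff): the new kx/ky pair implements a simple modulo broadcast — every element i of x picks up y[i % ky], so a src1 with ky elements is repeated across a src0 with kx elements without materializing copies. A minimal CPU sketch of the same rule, with illustrative names not taken from the PR:

    // CPU reference for the y[i % ky] broadcast used by add_f32 (illustrative, not PR code)
    #include <vector>

    static void add_broadcast_ref(const std::vector<float> & x, const std::vector<float> & y, std::vector<float> & dst) {
        const size_t kx = x.size();   // total element count of x / dst
        const size_t ky = y.size();   // broadcast period of y
        dst.resize(kx);
        for (size_t i = 0; i < kx; ++i) {
            dst[i] = x[i] + y[i % ky];
        }
    }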
@@ -275,16 +275,46 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }

+static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    const float eps = 1e-5f;
+
+    float mean = 0.0f;
+    float var = 0.0f;
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        const float xi = x[row*ncols + col];
+        mean += xi;
+        var += xi * xi;
+    }
+
+    // sum up partial sums
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
+        var += __shfl_xor_sync(0xffffffff, var, mask, 32);
+    }
+
+    mean /= ncols;
+    var = var / ncols - mean * mean;
+    const float inv_var = rsqrtf(var + eps);
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
+    }
+}
+
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;

-    const float eps = 1e-6;
+    const float eps = 1e-6f;

     float tmp = 0.0f; // partial sum for thread in warp

-    for (int i = 0; i < ncols; i += WARP_SIZE) {
-        const int col = i + tid;
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
         const float xi = x[row*ncols + col];
         tmp += xi * xi;
     }
@@ -296,10 +326,9 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
     }

     const float mean = tmp / ncols;
-    const float scale = 1.0f / sqrtf(mean + eps);
+    const float scale = rsqrtf(mean + eps);

-    for (int i = 0; i < ncols; i += WARP_SIZE) {
-        const int col = i + tid;
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
         dst[row*ncols + col] = scale * x[row*ncols + col];
     }
 }
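Note (not part of the diff): both norm kernels assign one warp per row. Each of the 32 lanes accumulates a strided partial sum (of x and x*x for norm_f32, of x*x for rms_norm_f32), the __shfl_xor_sync loop performs a butterfly reduction so every lane ends up with the full row sums, and the row is then rescaled. A single-threaded CPU sketch of the per-row math that norm_f32 computes (E[x^2] - E[x]^2 variance, eps = 1e-5f); names are illustrative:

    // CPU reference for one row of norm_f32 (illustrative, not PR code)
    #include <cmath>

    static void norm_row_ref(const float * x, float * dst, int ncols) {
        const float eps = 1e-5f;
        float mean = 0.0f;
        float var  = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            mean += x[col];
            var  += x[col] * x[col];
        }
        mean /= ncols;
        var = var / ncols - mean * mean;            // E[x^2] - E[x]^2
        const float inv_var = 1.0f / std::sqrt(var + eps);
        for (int col = 0; col < ncols; ++col) {
            dst[col] = (x[col] - mean) * inv_var;   // same result one warp produces per row
        }
    }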
@@ -1689,9 +1718,9 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
     dst[i] = scale * x[i];
 }

-static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }

 static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
@@ -1709,6 +1738,12 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }

+static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     const dim3 block_dims(WARP_SIZE, 1, 1);
@@ -2239,14 +2274,16 @@ inline void ggml_cuda_op_add(
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i != nullptr);

-    const int64_t ne0 = src0->ne[0];
+    const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;

+    const int64_t ne10 = src1->ne[0];
+
     // compute
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
     } else {
         GGML_ASSERT(false);
     }
@@ -2268,20 +2305,11 @@ inline void ggml_cuda_op_mul(
     GGML_ASSERT(dst_ddf_i != nullptr);

     const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;

     const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-
-    for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
-        const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
-
-        float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
-        float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
-        float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
-
-        // compute
-        mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
-    }
+
+    mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);

     (void) dst;
     (void) src0_ddq_i;
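Note (not part of the diff): ggml_cuda_op_mul no longer launches one kernel per row. The whole i01_diff x ne00 slab goes to a single mul_f32_cuda call, and the kernel's i % ky indexing reproduces the old per-row reuse of src1, provided src1 is one contiguous row of ne10 elements and ne10 divides ne00. A small CPU check of that equivalence, under exactly those assumptions (names are illustrative):

    // Checks that flattened i % ne10 indexing matches the old row-by-row broadcast
    // (illustrative, not PR code; assumes a single src1 row and ne00 % ne10 == 0)
    #include <cassert>
    #include <vector>

    static void check_flattened_broadcast(const std::vector<float> & src0, const std::vector<float> & src1, int ne00, int nrows) {
        const int ne10 = (int) src1.size();
        assert(ne00 % ne10 == 0);
        std::vector<float> per_row(ne00 * nrows), flat(ne00 * nrows);
        for (int i01 = 0; i01 < nrows; ++i01) {           // old formulation: host loop over rows
            for (int j = 0; j < ne00; ++j) {
                per_row[i01*ne00 + j] = src0[i01*ne00 + j] * src1[j % ne10];
            }
        }
        for (int i = 0; i < ne00*nrows; ++i) {            // new formulation: one flattened pass
            flat[i] = src0[i] * src1[i % ne10];
        }
        assert(per_row == flat);                          // identical operations, so exact equality holds
    }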
@@ -2310,6 +2338,28 @@ inline void ggml_cuda_op_silu(
     (void) i1;
 }

+inline void ggml_cuda_op_norm(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_rms_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2930,6 +2980,11 @@ void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
 }

+void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_norm, true, true);
+}
+
 void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
@@ -3160,7 +3215,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         }


-        cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+        CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));

         extra->data_device[id] = buf;
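Note (not part of the diff): the host-to-device copy is now routed through CUDA_CHECK, so a failed cudaMemcpy aborts with a message instead of being silently ignored. The macro itself is defined elsewhere in ggml-cuda.cu; a typical shape of such a wrapper (illustrative only, the real definition may differ) is:

    // Illustrative error-checking wrapper, not the actual CUDA_CHECK from ggml-cuda.cu
    #include <cstdio>
    #include <cstdlib>
    #include <cuda_runtime.h>

    #define CUDA_CHECK_EXAMPLE(call)                                              \
        do {                                                                      \
            cudaError_t err_ = (call);                                            \
            if (err_ != cudaSuccess) {                                            \
                fprintf(stderr, "CUDA error %s at %s:%d\n",                       \
                        cudaGetErrorString(err_), __FILE__, __LINE__);            \
                exit(1);                                                          \
            }                                                                     \
        } while (0)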

@@ -3322,6 +3377,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             }
             func = ggml_cuda_silu;
             break;
+        case GGML_OP_NORM:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_norm;
+            break;
         case GGML_OP_RMS_NORM:
             if (!any_on_device) {
                 return false;