
Commit ea2b795

ggml : group all experts in a single ggml_mul_mat_id
cuda : improve mmid row copy

1 parent a307375
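
The gist of the change: `ggml_cuda_mul_mat_id` used to be evaluated once per expert, with a host-side `id` read from `dst->op_params` selecting one column of the `ids` matrix; after this commit a single call covers all `n_ids = ids->ne[0]` expert slots per token. The gather of src1 rows into a contiguous buffer also moves from one `cudaMemcpyAsync` per row to a device kernel that claims output slots with `atomicAdd`. Below is a minimal standalone sketch of that slot-claiming pattern (hypothetical names, not the ggml code):

#include <cuda_runtime.h>

// One block per candidate row: thread 0 claims the next free slot in the
// compacted buffer with atomicAdd, then the whole block copies the row bytes.
__global__ void compact_rows(const char * src, char * dst, const int * expert_of_row,
                             int target_expert, int * counter, int row_bytes) {
    const int row = blockIdx.x;
    if (expert_of_row[row] != target_expert) {
        return; // this row is routed to a different expert
    }

    __shared__ int slot;
    if (threadIdx.x == 0) {
        slot = atomicAdd(counter, 1); // unique position in the contiguous buffer
    }
    __syncthreads(); // every thread needs the claimed slot before copying

    const char * s = src + (size_t) row  * row_bytes;
    char       * d = dst + (size_t) slot * row_bytes;
    for (int i = threadIdx.x; i < row_bytes; i += blockDim.x) {
        d[i] = s[i];
    }
}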

File tree

7 files changed: +395, -152 lines

ggml-cuda.cu
Lines changed: 175 additions & 41 deletions

@@ -1231,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
-        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1241,7 +1241,7 @@ static void ggml_cuda_op_mul_mat_cublas(
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
 
-        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
         if (src1->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1250,7 +1250,7 @@ static void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(), row_diff*src1_ncols);
+        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
 
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
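
Note on the `ctx.pool()` to `ctx.pool(id)` changes above: the cuBLAS temporaries are now taken from the pool of the specific device `id` being operated on rather than the context's current default. A rough sketch of what such an overload could look like (the real `ggml_backend_cuda_context` internals are not shown in this diff, so the shape below is an assumption):

// Assumed shape of the per-device pool accessors; illustrative only.
struct cuda_context_sketch {
    static constexpr int MAX_DEVICES = 16;
    int device = 0;                            // current device
    ggml_cuda_pool * pools[MAX_DEVICES] = {};  // one pool per GPU, created lazily

    ggml_cuda_pool & pool(int id) { return *pools[id];   } // new: explicit device
    ggml_cuda_pool & pool()       { return pool(device); } // old call sites
};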
@@ -1960,20 +1960,84 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     }
 }
 
+struct mmid_row_mapping {
+    int64_t i1;
+    int64_t i2;
+};
+
+static __global__ void k_copy_src1_to_contiguous(const char * src1_original, char * src1_contiguous,
+                                                 int * cur_src1_row, mmid_row_mapping * row_mapping,
+                                                 const char * ids_dev, int64_t i02, int64_t ids_nb1, int64_t ids_nb0,
+                                                 int64_t ids_ne1, int64_t n_ids,
+                                                 int64_t ne11,
+                                                 size_t nb11, size_t nb12) {
+    int64_t iid1 = blockIdx.x;
+    int64_t id = blockIdx.y;
+
+    if (iid1 >= ids_ne1 || id >= n_ids) {
+        return;
+    }
+
+    const int32_t row_id_i = *(const int32_t *) (ids_dev + iid1*ids_nb1 + id*ids_nb0);
+
+    if (row_id_i != i02) {
+        return;
+    }
+
+    const int64_t i11 = id % ne11;
+    const int64_t i12 = iid1;
+
+    __shared__ int src1_row;
+    if (threadIdx.x == 0) {
+        src1_row = atomicAdd(cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
+    }
+    __syncthreads();
+
+    const char * src1_row_original = src1_original + i11*nb11 + i12*nb12;
+    char * src1_row_contiguous = src1_contiguous + src1_row*nb11;
+
+    for (int i = threadIdx.x; i < nb11; i += blockDim.x) {
+        src1_row_contiguous[i] = src1_row_original[i];
+    }
+}
+
+static __global__ void k_copy_dst_from_contiguous(char * dst_original, const char * dst_contiguous,
+                                                  const mmid_row_mapping * row_mapping,
+                                                  int64_t n_rows,
+                                                  int64_t nb1, int64_t nb2) {
+    int64_t i = blockIdx.x;
+
+    if (i >= n_rows) {
+        return;
+    }
+
+    const int64_t i1 = row_mapping[i].i1;
+    const int64_t i2 = row_mapping[i].i2;
+
+    const char * dst_row_contiguous = dst_contiguous + i*nb1;
+    char * dst_row_original = dst_original + i1*nb1 + i2*nb2;
+
+    for (int j = threadIdx.x; j < nb1; j += blockDim.x) {
+        dst_row_original[j] = dst_row_contiguous[j];
+    }
+}
+
+//#define MMID_MEMCPY
+
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * ids = dst->src[2];
 
+    GGML_TENSOR_BINARY_OP_LOCALS
+
     GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
 
     cudaStream_t stream = ctx.stream();
 
-    const size_t nb11 = src1->nb[1];
-    const size_t nb1 = dst->nb[1];
-
-    const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = src0->ne[2];
+    const int64_t n_as = ne02;
+    const int64_t n_ids = ids->ne[0];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
@@ -1982,27 +2046,47 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row = *dst;
+    ggml_tensor dst_row = *dst;
 
     char * src0_original = (char *) src0->data;
     char * src1_original = (char *) src1->data;
     char * dst_original = (char *) dst->data;
 
     src0_row.ne[2] = 1;
     src0_row.ne[3] = 1;
-    src0_row.nb[3] = src0->nb[2];
+    src0_row.nb[3] = nb02;
 
-    if (src1->ne[1] == 1) {
-        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
 
-            GGML_ASSERT(row_id >= 0 && row_id < n_as);
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
 
-            src0_row.data = src0_original + row_id*src0->nb[2];
-            src1_row.data = src1_original + i01*src1->nb[1];
-            dst_row.data = dst_original + i01*dst->nb[1];
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
+
+                src0_row.data = src0_original + i02*nb02;
+                src1_row.data = src1_original + i11*nb11 + i12*nb12;
+                dst_row.data = dst_original + i1*nb1 + i2*nb2;
+
+                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+            }
         }
     } else {
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@@ -2011,55 +2095,104 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         src1_row.data = src1_contiguous.get();
         dst_row.data = dst_contiguous.get();
 
-        for (int32_t row_id = 0; row_id < n_as; ++row_id) {
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
 
-                if (row_id_i != row_id) {
-                    continue;
-                }
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                    if (row_id_i != i02) {
+                        continue;
+                    }
 
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
+                    GGML_ASSERT(i02 >= 0 && i02 < n_as);
 
-                CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
-                                        nb11, cudaMemcpyDeviceToDevice, stream));
-                num_src1_rows++;
+#ifdef MMID_MEMCPY
+                    const int64_t i11 = id % ne11;
+                    const int64_t i12 = iid1;
+                    CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11,
+                                               src1_original + i11*nb11 + i12*nb12,
+                                               nb11, cudaMemcpyDeviceToDevice, stream));
+#endif
+                    num_src1_rows++;
+                }
             }
 
             if (num_src1_rows == 0) {
                 continue;
             }
 
-            src0_row.data = src0_original + row_id*src0->nb[2];
+#ifndef MMID_MEMCPY
+            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
 
-            src1_row.ne[1] = num_src1_rows;
-            dst_row.ne[1] = num_src1_rows;
+            {
+                dim3 block_dims(std::min((uint)nb11, 1024u));
+                dim3 grid_dims(ids->ne[1], n_ids);
+                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        src1_original, src1_contiguous.get(),
+                        dev_cur_src1_row.get(), dev_row_mapping.get(),
+                        ids_dev, i02, ids->nb[1], ids->nb[0],
+                        ids->ne[1], n_ids,
+                        ne11,
+                        nb11, nb12);
+                CUDA_CHECK(cudaGetLastError());
+            }
+#endif
+
+            src0_row.data = src0_original + i02*nb02;
 
+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1 == sizeof(float)*ne0);
+
+            src1_row.ne[1] = num_src1_rows;
             src1_row.nb[1] = nb11;
             src1_row.nb[2] = num_src1_rows*nb11;
             src1_row.nb[3] = num_src1_rows*nb11;
 
+            dst_row.ne[1] = num_src1_rows;
             dst_row.nb[1] = nb1;
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
             ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
 
+#ifndef MMID_MEMCPY
+            {
+                dim3 block_dims(std::min((uint)nb1, 1024u));
+                dim3 grid_dims(num_src1_rows);
+                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        dst_original, dst_contiguous.get(),
+                        dev_row_mapping.get(),
+                        num_src1_rows, nb1, nb2);
+                CUDA_CHECK(cudaGetLastError());
+            }
+#endif
+
+#ifdef MMID_MEMCPY
             num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-                if (row_id_i != row_id) {
-                    continue;
-                }
+                    if (row_id_i != i02) {
+                        continue;
+                    }
 
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
+                    GGML_ASSERT(i02 >= 0 && i02 < n_as);
 
-                CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
-                                        nb1, cudaMemcpyDeviceToDevice, stream));
-                num_src1_rows++;
+                    const int64_t i1 = id;
+                    const int64_t i2 = iid1;
+
+                    CUDA_CHECK(cudaMemcpyAsync(dst_original + i1*nb1 + i2*nb2,
+                                               dst_contiguous.get() + num_src1_rows*nb1,
+                                               nb1, cudaMemcpyDeviceToDevice, stream));
+                    num_src1_rows++;
+                }
             }
+#endif
         }
     }
 }
@@ -2487,7 +2620,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
     const int min_batch_size = 32;
 
-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
 
     GGML_UNUSED(backend);
 }
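
For a sense of the new launch geometry (the sizes below are hypothetical, not taken from the diff): per expert i02, the gather kernel runs an `ids->ne[1] x n_ids` grid, one block per (token, expert-slot) pair; each block either exits immediately or copies one row of `nb11` bytes with at most 1024 threads, one byte per thread per loop iteration.

#include <algorithm>
#include <cstdio>

int main() {
    // Hypothetical MoE shapes: 32 tokens, 2 experts used per token,
    // src1 rows of 4096 floats.
    const unsigned n_tokens = 32;                    // ids->ne[1]
    const unsigned n_ids    = 2;                     // ids->ne[0]
    const unsigned nb11     = 4096 * sizeof(float);  // bytes per src1 row

    // Mirrors the launch in the diff: block = min(row bytes, 1024) threads.
    const unsigned block_x = std::min(nb11, 1024u);
    const unsigned steps   = (nb11 + block_x - 1) / block_x;

    printf("gather grid per expert: %u x %u blocks\n", n_tokens, n_ids);
    printf("threads per block: %u, copy iterations per row: %u\n", block_x, steps);
    return 0;
}

The `ggml_backend_cuda_offload_op` change fits the same picture: with the experts grouped into one node, the token batch of a MUL_MAT_ID sits in `op->ne[2]` rather than `op->ne[1]` (ne[1] now holds the per-token expert count), which is presumably why the threshold gains the second clause.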

ggml-cuda/convert.cu
Lines changed: 2 additions & 0 deletions

@@ -45,6 +45,8 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
         vals[ix] = x0[ix];
     }
 
+    __syncthreads();
+
 #pragma unroll
     for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
         if (need_check && i0 + iy + 2*threadIdx.x >= k) {
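
The added `__syncthreads()` fixes a read-after-write hazard on the kernel's shared `vals` buffer: the loop above it stages values into shared memory cooperatively, and the unrolled loop below reads entries that other threads wrote, so a barrier must separate the two phases. A minimal illustration of the pattern (not the ggml kernel):

// Each thread stages one element, then reads an element staged by a
// *different* thread; without the barrier the read can race the write.
// Assumes the kernel is launched with blockDim.x == 256.
__global__ void staged_copy(const float * in, float * out, int n) {
    __shared__ float vals[256];
    const int i = blockIdx.x * blockDim.x + threadIdx.x;

    if (i < n) {
        vals[threadIdx.x] = in[i];
    }
    __syncthreads(); // required: readers below depend on other threads' writes

    const int j  = blockDim.x - 1 - threadIdx.x; // another thread's element
    const int gj = blockIdx.x * blockDim.x + j;
    if (gj < n) {
        out[gj] = vals[j];
    }
}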
