
Commit 0d56246

slaren and ggerganov authored

ggml : group all experts in a single ggml_mul_mat_id (#6505)

* ggml : group all experts in a single ggml_mul_mat_id
  cuda : improve mmid row copy

* cuda : fix bin bcast with non-cont src0

* test-backend-ops : only run all mul mat tests for base types

* llama : disable moe offloading with SYCL

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 03c0946 commit 0d56246
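
Note: after this commit a single ggml_mul_mat_id evaluates every selected expert at once: src0 stacks all expert weights along ne[2], the ids tensor holds the top-k expert index per token, and src1 carries one input row per (expert slot, token). The snippet below is a hedged CPU sketch of those semantics on plain row-major arrays; the function and variable names are illustrative and not part of the commit, and the real op also supports broadcasting a single src1 row across slots.

    #include <cstdint>
    #include <vector>

    // as  : [n_expert][n_ff][n_embd]           all expert weights merged in one 3d tensor (src0)
    // b   : [n_tokens][n_expert_used][n_embd]  one input row per (expert slot, token)     (src1)
    // ids : [n_tokens][n_expert_used]          top-k expert index per token               (src[2])
    // out : [n_tokens][n_expert_used][n_ff]
    static void mul_mat_id_ref(const std::vector<float> & as, const std::vector<float> & b,
                               const std::vector<int32_t> & ids, std::vector<float> & out,
                               int64_t n_embd, int64_t n_ff, int64_t n_expert,
                               int64_t n_expert_used, int64_t n_tokens) {
        out.assign(n_tokens*n_expert_used*n_ff, 0.0f);
        for (int64_t t = 0; t < n_tokens; ++t) {
            for (int64_t s = 0; s < n_expert_used; ++s) {
                const int32_t ex = ids[t*n_expert_used + s];      // expert selected by this slot
                if (ex < 0 || ex >= n_expert) continue;           // sanity check, mirrors the asserts in the diff
                const float * w = &as [ex*n_ff*n_embd];           // that expert's weight matrix
                const float * x = &b  [(t*n_expert_used + s)*n_embd];
                float       * y = &out[(t*n_expert_used + s)*n_ff];
                for (int64_t r = 0; r < n_ff; ++r) {
                    float acc = 0.0f;
                    for (int64_t c = 0; c < n_embd; ++c) {
                        acc += w[r*n_embd + c]*x[c];
                    }
                    y[r] = acc;
                }
            }
        }
    }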

File tree

12 files changed: +971 -821 lines changed


examples/imatrix/imatrix.cpp

Lines changed: 36 additions & 21 deletions
@@ -44,7 +44,7 @@ class IMatrixCollector {
     std::mutex m_mutex;
     int m_last_call = 0;
     std::vector<float> m_src1_data;
-    std::vector<int> m_ids; // the expert ids from ggml_mul_mat_id
+    std::vector<char> m_ids; // the expert ids from ggml_mul_mat_id
     //
     void save_imatrix(const char * file_name) const;
     void keep_imatrix(int ncall) const;
@@ -81,6 +81,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
     if (ask) {
         if (t->op == GGML_OP_MUL_MAT_ID) return true; // collect all indirect matrix multiplications
         if (t->op != GGML_OP_MUL_MAT) return false;
+        // why are small batches ignored (<16 tokens)?
         if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false;
         if (!(wname.substr(0, 4) == "blk." || (m_params.collect_output_weight && wname == "output.weight"))) return false;
         return true;
@@ -101,14 +102,19 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
     // this has been adapted to the new format of storing merged experts in a single 3d tensor
     // ref: https://github.com/ggerganov/llama.cpp/pull/6387
     if (t->op == GGML_OP_MUL_MAT_ID) {
-        const int idx = ((int32_t *) t->op_params)[0];
+        // ids -> [n_experts_used, n_tokens]
+        // src1 -> [cols, n_expert_used, n_tokens]
         const ggml_tensor * ids = t->src[2];
         const int n_as = src0->ne[2];
+        const int n_ids = ids->ne[0];

         // the top-k selected expert ids are stored in the ids tensor
         // for simplicity, always copy ids to host, because it is small
-        GGML_ASSERT(ids->ne[1] == src1->ne[1]);
-        m_ids.resize(ggml_nbytes(ids)/sizeof(int));
+        // take into account that ids is not contiguous!
+
+        GGML_ASSERT(ids->ne[1] == src1->ne[2]);
+
+        m_ids.resize(ggml_nbytes(ids));
         ggml_backend_tensor_get(ids, m_ids.data(), 0, ggml_nbytes(ids));

         auto & e = m_stats[wname];
@@ -118,26 +124,35 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
         // using the following line, we can correct for that if needed by replacing the line above with:
         //if (idx == t->src[0]->ne[0] - 1) ++e.ncall;

+        if (e.values.empty()) {
+            e.values.resize(src1->ne[0]*n_as, 0);
+        }
+        else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
+            fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
+            exit(1); //GGML_ASSERT(false);
+        }
+        if (m_params.verbosity > 1) {
+            printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
+        }
         // loop over all possible experts, regardless if they are used or not in the batch
         for (int ex = 0; ex < n_as; ++ex) {
             size_t e_start = ex*src1->ne[0];
-            if (e.values.empty()) {
-                e.values.resize(src1->ne[0]*n_as, 0);
-            }
-            else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
-                fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
-                exit(1); //GGML_ASSERT(false);
-            }
-            if (m_params.verbosity > 1) {
-                printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
-            }
-            for (int row = 0; row < (int)src1->ne[1]; ++row) {
-                const int excur = m_ids[row*n_as + idx];
-                GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check
-                if (excur != ex) continue;
-                const float * x = data + row * src1->ne[0];
-                for (int j = 0; j < (int)src1->ne[0]; ++j) {
-                    e.values[e_start + j] += x[j]*x[j];
+
+            for (int idx = 0; idx < n_ids; ++idx) {
+                for (int row = 0; row < (int)src1->ne[2]; ++row) {
+                    const int excur = *(const int32_t *) (m_ids.data() + row*ids->nb[1] + idx*ids->nb[0]);
+
+                    GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check
+
+                    if (excur != ex) continue;
+
+                    const int64_t i11 = idx % src1->ne[1];
+                    const int64_t i12 = row;
+                    const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]);
+
+                    for (int j = 0; j < (int)src1->ne[0]; ++j) {
+                        e.values[e_start + j] += x[j]*x[j];
+                    }
                 }
             }
             if (e.ncall > m_last_call) {
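
The key change above is that the expert ids are now kept as raw bytes and indexed with the tensor's byte strides, because the ids tensor handed to the collector may not be contiguous. A minimal sketch of that lookup follows; the helper name is hypothetical, ids->nb[0]/nb[1] are the standard ggml byte strides, and ids_host stands for the host copy filled by ggml_backend_tensor_get.

    #include <cstdint>
    #include <vector>
    #include "ggml.h"

    // idx = expert slot (0..n_expert_used-1), row = token index (0..n_tokens-1)
    static int32_t expert_id_at(const std::vector<char> & ids_host, const ggml_tensor * ids,
                                int64_t idx, int64_t row) {
        // byte-stride addressing works whether or not ids is contiguous
        return *(const int32_t *) (ids_host.data() + row*ids->nb[1] + idx*ids->nb[0]);
    }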

ggml-cuda.cu

Lines changed: 134 additions & 45 deletions
@@ -1231,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas(

     if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
-        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1241,7 +1241,7 @@ static void ggml_cuda_op_mul_mat_cublas(
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();

-        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
         if (src1->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1250,7 +1250,7 @@ static void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(), row_diff*src1_ncols);
+        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);

         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
@@ -1960,20 +1960,73 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     }
 }

+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
+                                                 int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
+                                                 const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+                                                 int64_t ne11, int64_t ne10,
+                                                 size_t nb11, size_t nb12) {
+    int32_t iid1 = blockIdx.x;
+    int32_t id = blockIdx.y;
+
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+    if (row_id_i != i02) {
+        return;
+    }
+
+    const int64_t i11 = id % ne11;
+    const int64_t i12 = iid1;
+
+    __shared__ int src1_row;
+    if (threadIdx.x == 0) {
+        src1_row = atomicAdd(cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
+    }
+    __syncthreads();
+
+    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
+        src1_row_contiguous[i] = src1_row_original[i];
+    }
+}
+
+static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
+                                                  const mmid_row_mapping * __restrict__ row_mapping,
+                                                  int64_t ne0,
+                                                  size_t nb1, size_t nb2) {
+    int32_t i = blockIdx.x;
+
+    const int32_t i1 = row_mapping[i].i1;
+    const int32_t i2 = row_mapping[i].i2;
+
+    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+    for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
+        dst_row_original[j] = dst_row_contiguous[j];
+    }
+}
+
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * ids = dst->src[2];

+    GGML_TENSOR_BINARY_OP_LOCALS
+
     GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");

     cudaStream_t stream = ctx.stream();

-    const size_t nb11 = src1->nb[1];
-    const size_t nb1 = dst->nb[1];
-
-    const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = src0->ne[2];
+    const int64_t n_as = ne02;
+    const int64_t n_ids = ids->ne[0];

     std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
@@ -1982,27 +2035,47 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *

     ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row = *dst;
+    ggml_tensor dst_row  = *dst;

     char * src0_original = (char *) src0->data;
     char * src1_original = (char *) src1->data;
     char * dst_original = (char *) dst->data;

     src0_row.ne[2] = 1;
     src0_row.ne[3] = 1;
-    src0_row.nb[3] = src0->nb[2];
+    src0_row.nb[3] = nb02;

-    if (src1->ne[1] == 1) {
-        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;

-            GGML_ASSERT(row_id >= 0 && row_id < n_as);
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;

-            src0_row.data = src0_original + row_id*src0->nb[2];
-            src1_row.data = src1_original + i01*src1->nb[1];
-            dst_row.data = dst_original + i01*dst->nb[1];
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

-            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
+
+                src0_row.data = src0_original + i02*nb02;
+                src1_row.data = src1_original + i11*nb11 + i12*nb12;
+                dst_row.data = dst_original + i1*nb1 + i2*nb2;
+
+                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+            }
         }
     } else {
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@@ -2011,54 +2084,69 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         src1_row.data = src1_contiguous.get();
         dst_row.data = dst_contiguous.get();

-        for (int32_t row_id = 0; row_id < n_as; ++row_id) {
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);

-                if (row_id_i != row_id) {
-                    continue;
-                }
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
+                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);

-                CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
-                                        nb11, cudaMemcpyDeviceToDevice, stream));
-                num_src1_rows++;
+                    if (row_id_i != i02) {
+                        continue;
+                    }
+
+                    num_src1_rows++;
+                }
             }

             if (num_src1_rows == 0) {
                 continue;
             }

-            src0_row.data = src0_original + row_id*src0->nb[2];
+            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));

-            src1_row.ne[1] = num_src1_rows;
-            dst_row.ne[1] = num_src1_rows;
+            {
+                dim3 block_dims(std::min((unsigned int)ne10, 768u));
+                dim3 grid_dims(ids->ne[1], n_ids);
+                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                    src1_original, src1_contiguous.get(),
+                    dev_cur_src1_row.get(), dev_row_mapping.get(),
+                    ids_dev, i02, ids->nb[1], ids->nb[0],
+                    ne11, ne10,
+                    nb11, nb12);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            src0_row.data = src0_original + i02*nb02;

+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1 == sizeof(float)*ne0);
+
+            src1_row.ne[1] = num_src1_rows;
             src1_row.nb[1] = nb11;
             src1_row.nb[2] = num_src1_rows*nb11;
             src1_row.nb[3] = num_src1_rows*nb11;

+            dst_row.ne[1] = num_src1_rows;
             dst_row.nb[1] = nb1;
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;

             ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);

-            num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
-
-                if (row_id_i != row_id) {
-                    continue;
-                }
-
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
-
-                CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
-                                        nb1, cudaMemcpyDeviceToDevice, stream));
-                num_src1_rows++;
+            {
+                dim3 block_dims(std::min((unsigned int)ne0, 768u));
+                dim3 grid_dims(num_src1_rows);
+                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                    dst_original, dst_contiguous.get(),
+                    dev_row_mapping.get(),
+                    ne0,
+                    nb1, nb2);
+                CUDA_CHECK(cudaGetLastError());
+            }
         }
     }
 }
@@ -2487,7 +2575,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
     const int min_batch_size = 32;

-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);

     GGML_UNUSED(backend);
 }
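
In the batched path above (ne12 > 1), the per-row cudaMemcpyAsync calls are replaced by two kernels: k_copy_src1_to_contiguous gathers every src1 row that selected the current expert into a contiguous buffer and records its (slot, token) origin through an atomic counter, and k_copy_dst_from_contiguous scatters the matmul results back using that mapping. Below is a CPU analogue of the row-mapping step, assuming a contiguous row-major ids array; names are illustrative, and on the GPU the gather order depends on the atomic counter, which is exactly why the mapping has to be recorded.

    #include <cstdint>
    #include <vector>

    struct row_mapping_entry { int32_t i1; int32_t i2; };  // i1 = expert slot (id), i2 = token (iid1)

    // Collect every (slot, token) pair whose selected expert equals i02. The result size
    // corresponds to num_src1_rows in the CUDA code, and the entry order is the order of
    // rows in the contiguous src1 buffer.
    static std::vector<row_mapping_entry> build_row_mapping(const std::vector<int32_t> & ids, // [n_tokens*n_ids]
                                                            int64_t n_ids, int64_t n_tokens, int32_t i02) {
        std::vector<row_mapping_entry> mapping;
        for (int64_t iid1 = 0; iid1 < n_tokens; ++iid1) {
            for (int64_t id = 0; id < n_ids; ++id) {
                if (ids[iid1*n_ids + id] == i02) {
                    mapping.push_back({(int32_t) id, (int32_t) iid1});
                }
            }
        }
        return mapping;
    }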
