
Commit b90d6ed

ikawrakow and Iwan Kawrakow authored
Fix SER (CUDA) (#416)
* Fixing SER bugs

* Cleanup

* This seems to fix it.

* This seems to work

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 1374062 commit b90d6ed
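
What is being fixed, as far as the diff shows: with SER (smart expert reduction), some experts of a MoE model are skipped at inference time, so entries of the routing ids tensor can fall outside [0, n_as), and the output rows belonging to those skipped slots are never written by any per-expert matrix multiplication. Previously the mul_mat_vec_q kernel was expected to zero such rows itself; this commit detects the SER case on the host instead and clears the destination with cudaMemsetAsync on the compute stream before launching the kernels. The standalone CUDA sketch below illustrates that pattern only; the kernel and every name in it are hypothetical, not code from this repository.

    // Sketch (not from the patch): zero the destination with cudaMemsetAsync
    // on the compute stream, so rows whose expert id is negative can simply
    // be left untouched by the kernel instead of being cleared inside it.
    #include <cstdio>
    #include <cuda_runtime.h>

    // Hypothetical toy kernel: writes row `blockIdx.x` only when its id is >= 0.
    __global__ void write_selected_rows(float * dst, const int * ids, int ncols) {
        int row = blockIdx.x;
        if (ids[row] < 0) return;              // skipped row: stays zero
        for (int j = threadIdx.x; j < ncols; j += blockDim.x) {
            dst[row*ncols + j] = 1.0f;         // stand-in for the real matvec result
        }
    }

    int main() {
        const int nrows = 4, ncols = 8;
        const int h_ids[nrows] = {0, -1, 2, -1}; // -1 marks rows no expert produces

        float * d_dst; int * d_ids;
        cudaStream_t stream;
        cudaStreamCreate(&stream);
        cudaMalloc(&d_dst, nrows*ncols*sizeof(float));
        cudaMalloc(&d_ids, nrows*sizeof(int));
        cudaMemcpyAsync(d_ids, h_ids, sizeof(h_ids), cudaMemcpyHostToDevice, stream);

        // The fix's pattern: clear dst on the same stream *before* the kernel.
        cudaMemsetAsync(d_dst, 0, nrows*ncols*sizeof(float), stream);
        write_selected_rows<<<nrows, 32, 0, stream>>>(d_dst, d_ids, ncols);

        float h_dst[nrows*ncols];
        cudaMemcpyAsync(h_dst, d_dst, sizeof(h_dst), cudaMemcpyDeviceToHost, stream);
        cudaStreamSynchronize(stream);
        for (int i = 0; i < nrows; ++i) printf("row %d: %g ...\n", i, h_dst[i*ncols]);

        cudaFree(d_dst); cudaFree(d_ids); cudaStreamDestroy(stream);
        return 0;
    }

Because the memset and the kernel are enqueued on the same stream, stream ordering guarantees the clear completes before the kernel runs, so no extra synchronization is needed.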

File tree: 2 files changed (+38, −19 lines)

  ggml/src/ggml-cuda.cu
  ggml/src/ggml-cuda/mmvq.cu


ggml/src/ggml-cuda.cu

Lines changed: 23 additions & 5 deletions
@@ -2203,7 +2203,7 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin
     }
 }
 
-static inline void prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
+static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
         const ggml_tensor * ids, std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
         ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {
 
@@ -2220,10 +2220,12 @@ static inline void prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n
     moe_counts.resize(n_as, 0);
     cum_moe_counts.resize(n_as + 1);
 
+    bool is_ser = false;
     for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
         for (int64_t id = 0; id < n_ids; id++) {
             const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
             if (row_id_i >= 0 && row_id_i < n_as) ++moe_counts[row_id_i];
+            else is_ser = true;
         }
     }
     cum_moe_counts[0] = 0;
@@ -2244,16 +2246,20 @@ static inline void prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n
 
     for (int i = 0; i < (int)n_as; ++i) cum_moe_counts[i] -= moe_counts[i];
 
-    CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(), cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaMemcpyAsync(dev_row_mapping.get(), rmapping.data(),
+                cum_moe_counts[n_as]*sizeof(mmid_row_mapping), cudaMemcpyHostToDevice, stream));
     CUDA_CHECK(cudaStreamSynchronize(stream));
 
+    return is_ser;
 }
 
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * ids = dst->src[2];
 
+    CUDA_CHECK(cudaMemsetAsync((char *)dst->data, 0, ggml_nbytes(dst), ctx.stream()));
+
     if (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1 &&
         ggml_is_quantized(src0->type) &&
         ggml_backend_buffer_is_cuda(src0->buffer) &&
@@ -2361,7 +2367,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool());
     std::vector<int> moe_counts, cum_moe_counts;
-    prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
+    bool is_ser = prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
+    if (is_ser) {
+        CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
+    }
 
     ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
     ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
@@ -2519,13 +2528,16 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
         auto local_src0 = *next->src[0];
         local_src0.ne[2] = local_src0.ne[3] = 1;
 
+        CUDA_CHECK(cudaMemsetAsync(next->data, 0, ggml_nbytes(next), stream));
+
         ggml_cuda_op_mul_mat_vec_q_id(ctx, &local_src0, &local_src1, ids, &local_next,
                 (const char *)next->src[0]->data, nullptr, dst_quantized.get(), (float *)next->data,
                 0, next->src[0]->ne[1], 1, dst_padded_col_size, stream);
         CUDA_CHECK(cudaGetLastError());
 
         return true;
     } else {
+        CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
         ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
                 (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
         CUDA_CHECK(cudaGetLastError());
@@ -2534,7 +2546,6 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
         }
     }
 
-
     GGML_TENSOR_BINARY_OP_LOCALS
 
     GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0_1->buffer) && "mul_mat_id does not support split buffers");
@@ -2662,7 +2673,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
     ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool());
     std::vector<int> moe_counts, cum_moe_counts;
 
-    prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
+    bool is_ser = prepare_row_mappigs(ctx, n_as, n_ids, ids, moe_counts, cum_moe_counts, dev_row_mapping);
+    if (is_ser) {
+        if (fuse_down) {
+            CUDA_CHECK(cudaMemsetAsync(next->data, 0, ggml_nbytes(next), stream));
+        } else {
+            CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
+        }
+    }
 
     for (int64_t i02 = 0; i02 < n_as; i02++) {
         int64_t num_src1_rows = moe_counts[i02];
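
Read in isolation, the change above makes prepare_row_mappigs report whether any routed expert id fell outside [0, n_as); its callers then decide what to clear (dst, or next when the down projection is fused). Below is a minimal host-side sketch of just that detection, simplified to a flat id array; the detect_ser helper is hypothetical, and the real function also builds the row mappings and uploads them to the device.

    // Hypothetical, simplified sketch of the SER detection added in this commit.
    #include <cstdint>
    #include <vector>

    static bool detect_ser(const std::vector<int32_t>& ids, int64_t n_as,
                           std::vector<int>& moe_counts) {
        moe_counts.assign(n_as, 0);
        bool is_ser = false;
        for (int32_t id : ids) {
            if (id >= 0 && id < n_as) ++moe_counts[id]; // row routed to expert `id`
            else is_ser = true;                         // out-of-range id: slot skipped by SER
        }
        return is_ser; // caller must zero the output if this is true
    }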

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 15 additions & 14 deletions
@@ -150,20 +150,21 @@ static __global__ void mul_mat_vec_q(
     char * cdst = (char *)dst + i2*nb2;
     int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2;
     if (i02 < 0) {
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
-        constexpr int rows_per_cuda_block = 1;
-#else
-        constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
-        const int row0 = rows_per_cuda_block*blockIdx.x;
-        if (threadIdx.y == 0) {
-            dst = (float *)cdst;
-            for (int j = 0; j < ncols_y; ++j) {
-                if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
-                    dst[j*nrows_dst + row0 + threadIdx.x] = 0;
-                }
-            }
-        }
+        // We clear the buffer via cudaMemset instead
+//#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+//        constexpr int rows_per_cuda_block = 1;
+//#else
+//        constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
+//#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
+//        const int row0 = rows_per_cuda_block*blockIdx.x;
+//        if (threadIdx.y == 0) {
+//            dst = (float *)cdst;
+//            for (int j = 0; j < ncols_y; ++j) {
+//                if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
+//                    dst[j*nrows_dst + row0 + threadIdx.x] = 0;
+//                }
+//            }
+//        }
         return;
     }
     const char * cx = (const char *)vx + i02*nb02;
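
With the destination pre-cleared on the host, the kernel can now simply return when the expert id is negative; the zeroing code it used to run survives above only as a comment. The toy program below (a hypothetical repro, not taken from this repository) shows the failure mode being avoided: entries that a kernel skips and nobody pre-clears keep whatever bytes the recycled allocation contained.

    // Hypothetical repro of the bug class: skipped entries read back as garbage
    // unless the buffer is cleared before the kernel runs.
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void fill_skipping(float * dst, int n, bool skip_odd) {
        int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i >= n) return;
        if (skip_odd && (i & 1)) return;   // mimics the `if (i02 < 0) return;` path
        dst[i] = 2.0f;
    }

    int main() {
        const int n = 8;
        float *d, h[n];
        cudaMalloc(&d, n*sizeof(float));

        // Poison the buffer to stand in for recycled pool memory.
        cudaMemset(d, 0xFF, n*sizeof(float));
        fill_skipping<<<1, n>>>(d, n, true);
        cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
        printf("without pre-clear: %g %g ...\n", h[0], h[1]); // h[1] is garbage (NaN)

        // With the fix: zero first, so skipped entries read back as 0.
        cudaMemset(d, 0, n*sizeof(float));
        fill_skipping<<<1, n>>>(d, n, true);
        cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
        printf("with pre-clear:    %g %g ...\n", h[0], h[1]);

        cudaFree(d);
        return 0;
    }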
