Skip to content

Commit 2356fb1

Browse files
CUDA: fix bad asserts for partial offload (#13337)
1 parent 764b856 commit 2356fb1

File tree

6 files changed

+21
-6
lines changed

6 files changed

+21
-6
lines changed

ggml/include/ggml.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,11 +673,15 @@ extern "C" {
673673
GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor);
674674
GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars
675675

676+
// returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation)
676677
GGML_API bool ggml_is_contiguous (const struct ggml_tensor * tensor);
677678
GGML_API bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
678679
GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
679680
GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
680681

682+
// returns whether the tensor elements are allocated as one contiguous block of memory (no gaps, but permutation ok)
683+
GGML_API bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor);
684+
681685
// true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
682686
GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
683687

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,6 +719,7 @@ void launch_fattn(
719719
size_t nb23 = V->nb[3];
720720

721721
if (need_f16_K && K->type != GGML_TYPE_F16) {
722+
GGML_ASSERT(ggml_is_contiguously_allocated(K));
722723
K_f16.alloc(ggml_nelements(K));
723724
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
724725
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
@@ -733,6 +734,7 @@ void launch_fattn(
733734
}
734735

735736
if (need_f16_V && V->type != GGML_TYPE_F16) {
737+
GGML_ASSERT(ggml_is_contiguously_allocated(V));
736738
V_f16.alloc(ggml_nelements(V));
737739
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
738740
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,6 +1536,8 @@ static void ggml_cuda_op_mul_mat(
15361536

15371537
// If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
15381538
if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
1539+
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
1540+
GGML_ASSERT(!src0->view_src);
15391541
const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
15401542
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
15411543
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
@@ -2067,10 +2069,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
20672069
}
20682070

20692071
ggml_tensor src0_slice = *src0;
2070-
src0_slice.ne[2] = 1;
2071-
src0_slice.nb[3] = src0_slice.nb[2];
2072-
src0_slice.data = (char *) src0->data + i02*nb02;
2073-
GGML_ASSERT(!ggml_cuda_should_use_mmq(src0->type, cc, ne11) || ne00 % MATRIX_ROW_PADDING == 0);
2072+
src0_slice.ne[2] = 1;
2073+
src0_slice.nb[3] = src0_slice.nb[2];
2074+
src0_slice.op = GGML_OP_VIEW;
2075+
src0_slice.view_src = dst->src[0]; // non-const pointer to src0
2076+
src0_slice.data = (char *) src0->data + i02*nb02;
20742077

20752078
ggml_tensor src1_slice;
20762079
memset(&src1_slice, 0, sizeof(src1_slice));

ggml/src/ggml-cuda/mmq.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ void ggml_cuda_mul_mat_q(
9191

9292
// If src0 is a temporary compute buffer, clear any potential padding.
9393
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
94-
GGML_ASSERT(ggml_is_contiguous(src0));
94+
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
95+
GGML_ASSERT(!src0->view_src);
9596
const size_t size_data = ggml_nbytes(src0);
9697
const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
9798
if (size_alloc > size_data) {

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,8 @@ void ggml_cuda_mul_mat_vec_q(
515515

516516
// If src0 is a temporary compute buffer, clear any potential padding.
517517
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
518-
GGML_ASSERT(ggml_is_contiguous(src0));
518+
GGML_ASSERT(ggml_is_contiguously_allocated(src0));
519+
GGML_ASSERT(!src0->view_src);
519520
const size_t size_data = ggml_nbytes(src0);
520521
const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
521522
if (size_alloc > size_data) {

ggml/src/ggml.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,10 @@ bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
12991299
return ggml_is_contiguous_n(tensor, 2);
13001300
}
13011301

1302+
bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) {
1303+
return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
1304+
}
1305+
13021306
bool ggml_is_permuted(const struct ggml_tensor * tensor) {
13031307
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
13041308

0 commit comments

Comments
 (0)