Skip to content

Commit 108dfde

Browse files
CUDA: fix logic for clearing padding with -ngl 0
1 parent 5215b91 commit 108dfde

File tree

6 files changed

+33
-6
lines changed

6 files changed

+33
-6
lines changed

ggml/include/ggml-backend.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ extern "C" {
3838
GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size);
3939
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
4040
GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft);
41-
GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
41+
GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
4242
GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
4343
GGML_API ggml_backend_dev_t ggml_backend_buft_get_device (ggml_backend_buffer_type_t buft);
4444

@@ -59,7 +59,7 @@ extern "C" {
5959
GGML_API enum ggml_status ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
6060
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
6161
GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer);
62-
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
62+
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor);
6363
GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
6464
GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer);
6565
GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);

ggml/src/ggml-backend.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
5656
return SIZE_MAX;
5757
}
5858

59-
size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
59+
size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
6060
// get_alloc_size is optional, defaults to ggml_nbytes
6161
if (buft->iface.get_alloc_size) {
6262
size_t size = buft->iface.get_alloc_size(buft, tensor);
@@ -152,7 +152,7 @@ size_t ggml_backend_buffer_get_max_size(ggml_backend_buffer_t buffer) {
152152
return ggml_backend_buft_get_max_size(ggml_backend_buffer_get_type(buffer));
153153
}
154154

155-
size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
155+
size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor) {
156156
return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor);
157157
}
158158

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,8 +555,8 @@ static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer
555555

556556
if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
557557
// initialize padding to 0 to avoid possible NaN values
558-
size_t original_size = ggml_nbytes(tensor);
559-
size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
558+
const size_t original_size = ggml_nbytes(tensor);
559+
const size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
560560

561561
if (padded_size > original_size) {
562562
ggml_cuda_set_device(ctx->device);
@@ -679,6 +679,7 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t
679679

680680
if (ggml_is_quantized(tensor->type)) {
681681
if (ne0 % MATRIX_ROW_PADDING != 0) {
682+
GGML_ASSERT(tensor->nb[0] == ggml_element_size(tensor));
682683
size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
683684
}
684685
}
@@ -800,6 +801,7 @@ static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buff
800801

801802
static enum ggml_status ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
802803
GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
804+
GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
803805

804806
ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
805807
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
@@ -851,6 +853,7 @@ static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buff
851853
// split tensors must always be set in their entirety at once
852854
GGML_ASSERT(offset == 0);
853855
GGML_ASSERT(size == ggml_nbytes(tensor));
856+
GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
854857

855858
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
856859

@@ -889,6 +892,7 @@ static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buff
889892
// split tensors must always be set in their entirety at once
890893
GGML_ASSERT(offset == 0);
891894
GGML_ASSERT(size == ggml_nbytes(tensor));
895+
GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
892896

893897
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
894898

@@ -970,6 +974,7 @@ static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buf
970974

971975
static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
972976
ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
977+
GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
973978

974979
size_t total_size = 0;
975980

@@ -2065,6 +2070,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
20652070
src0_slice.ne[2] = 1;
20662071
src0_slice.nb[3] = src0_slice.nb[2];
20672072
src0_slice.data = (char *) src0->data + i02*nb02;
2073+
GGML_ASSERT(!ggml_cuda_should_use_mmq(src0->type, cc, ne11) || ne00 % MATRIX_ROW_PADDING == 0);
20682074

20692075
ggml_tensor src1_slice;
20702076
memset(&src1_slice, 0, sizeof(src1_slice));

ggml/src/ggml-cuda/mmq.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,16 @@ void ggml_cuda_mul_mat_q(
8989
const float * src1_d = (const float *) src1->data;
9090
float * dst_d = (float *) dst->data;
9191

92+
// If src0 is a temporary compute buffer, clear any potential padding.
93+
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
94+
GGML_ASSERT(ggml_is_contiguous(src0));
95+
const size_t size_data = ggml_nbytes(src0);
96+
const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
97+
if (size_alloc > size_data) {
98+
CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
99+
}
100+
}
101+
92102
const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
93103

94104
const int64_t s01 = src0->nb[1] / ts_src0;

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,16 @@ void ggml_cuda_mul_mat_vec_q(
513513
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
514514
float * dst_d = (float *) dst->data;
515515

516+
// If src0 is a temporary compute buffer, clear any potential padding.
517+
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
518+
GGML_ASSERT(ggml_is_contiguous(src0));
519+
const size_t size_data = ggml_nbytes(src0);
520+
const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
521+
if (size_alloc > size_data) {
522+
CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
523+
}
524+
}
525+
516526
const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
517527
ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1);
518528
{

ggml/src/ggml-cuda/quantize.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ void quantize_mmq_q8_1_cuda(
163163
const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
164164
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
165165
const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
166+
GGML_ASSERT(ne00 % 4 == 0);
166167
GGML_ASSERT(ne0 % (4*QK8_1) == 0);
167168

168169
const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);

0 commit comments

Comments (0)