Skip to content

Commit 70f23e5

Browse files
JohannesGaesslerarthw
authored andcommitted
CUDA: fix MMQ for non-contiguous src0, add tests (ggml-org#10021)
* CUDA: fix MMQ for non-contiguous src0, add tests * revise test code
1 parent 9267d7f commit 70f23e5

File tree

4 files changed

+73
-29
lines changed

4 files changed

+73
-29
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,8 +1151,8 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
11511151
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
11521152

11531153
GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
1154-
char * src_ptr = (char *) src->data;
1155-
char * dst_ptr = (char *) dst;
1154+
const char * src_ptr = (const char *) src->data;
1155+
char * dst_ptr = (char *) dst;
11561156

11571157
const int64_t ne0 = src->ne[0];
11581158
const int64_t nb0 = src->nb[0];
@@ -1162,7 +1162,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
11621162
const enum ggml_type type = src->type;
11631163
const int64_t ts = ggml_type_size(type);
11641164
const int64_t bs = ggml_blck_size(type);
1165-
int64_t i1_diff = i1_high - i1_low;
1165+
const int64_t i1_diff = i1_high - i1_low;
11661166

11671167
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
11681168
if (nb0 == ts && nb1 == ts*ne0/bs) {
@@ -1479,13 +1479,17 @@ static void ggml_cuda_op_mul_mat(
14791479
if (src0_is_contiguous) {
14801480
dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
14811481
} else {
1482-
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0));
1482+
// If src0 is not contiguous it will be copied to a temporary buffer, it may then be necessary to clear padding.
1483+
const size_t nbytes_data = ggml_nbytes(src0);
1484+
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
1485+
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
1486+
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
14831487
}
14841488

1485-
// If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared:
1489+
// If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
14861490
if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
1487-
const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
1488-
const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
1491+
const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
1492+
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
14891493
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
14901494
}
14911495

ggml/src/ggml-cuda/mmq.cu

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,14 @@ void ggml_cuda_op_mul_mat_q(
88

99
const int64_t ne00 = src0->ne[0];
1010

11-
const int64_t nb01 = src0->nb[1];
12-
1311
const int64_t ne10 = src1->ne[0];
1412
const int64_t ne11 = src1->ne[1];
1513
GGML_ASSERT(ne10 % QK8_1 == 0);
1614

1715
const int64_t ne0 = dst->ne[0];
1816

1917
const int64_t row_diff = row_high - row_low;
20-
const int64_t stride00 = nb01 / ggml_type_size(src0->type);
18+
const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
2119

2220
int id = ggml_cuda_get_device();
2321
const int compute_capability = ggml_cuda_info().devices[id].cc;

ggml/src/ggml.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3464,7 +3464,7 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
34643464

34653465
size_t ggml_nbytes(const struct ggml_tensor * tensor) {
34663466
size_t nbytes;
3467-
size_t blck_size = ggml_blck_size(tensor->type);
3467+
const size_t blck_size = ggml_blck_size(tensor->type);
34683468
if (blck_size == 1) {
34693469
nbytes = ggml_type_size(tensor->type);
34703470
for (int i = 0; i < GGML_MAX_DIMS; ++i) {

tests/test-backend-ops.cpp

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,11 +1650,12 @@ struct test_mul_mat : public test_case {
16501650
const int64_t m;
16511651
const int64_t n;
16521652
const int64_t k;
1653-
const std::array<int64_t, 2> bs; // dims 3 and 4
1654-
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
1653+
const std::array<int64_t, 2> bs; // dims 3 and 4
1654+
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
1655+
const std::array<int64_t, 4> per; // permutation of dimensions
16551656

16561657
std::string vars() override {
1657-
return VARS_TO_STR7(type_a, type_b, m, n, k, bs, nr);
1658+
return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per);
16581659
}
16591660

16601661
double max_nmse_err() override {
@@ -1669,17 +1670,44 @@ struct test_mul_mat : public test_case {
16691670
test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
16701671
int64_t m = 32, int64_t n = 32, int64_t k = 32,
16711672
std::array<int64_t, 2> bs = {10, 10},
1672-
std::array<int64_t, 2> nr = {2, 2})
1673-
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {}
1673+
std::array<int64_t, 2> nr = {2, 2},
1674+
std::array<int64_t, 4> per = {0, 1, 2, 3})
1675+
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}
16741676

16751677
ggml_tensor * build_graph(ggml_context * ctx) override {
16761678
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
1677-
ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0] , bs[1]);
1678-
ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
1679-
ggml_set_param(ctx, a);
1680-
ggml_set_param(ctx, b);
1681-
ggml_set_name(a, "a");
1682-
ggml_set_name(b, "b");
1679+
ggml_tensor * a;
1680+
ggml_tensor * b;
1681+
1682+
const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
1683+
if (npermuted > 0) {
1684+
GGML_ASSERT(npermuted == 2);
1685+
GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
1686+
GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);
1687+
1688+
// Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k.
1689+
const int64_t ne_a[4] = {k, m, bs[0], bs[1]};
1690+
const int64_t ne_b[4] = {k, n, bs[0]*nr[0], bs[1]*nr[1]};
1691+
1692+
a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
1693+
b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
1694+
ggml_set_param(ctx, a);
1695+
ggml_set_param(ctx, b);
1696+
ggml_set_name(a, "a");
1697+
ggml_set_name(b, "b");
1698+
1699+
a = ggml_permute(ctx, a, per[0], per[1], per[2], per[3]);
1700+
b = ggml_permute(ctx, b, per[0], per[1], per[2], per[3]);
1701+
ggml_set_name(a, "a_permuted");
1702+
ggml_set_name(b, "b_permuted");
1703+
} else {
1704+
a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
1705+
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
1706+
ggml_set_param(ctx, a);
1707+
ggml_set_param(ctx, b);
1708+
ggml_set_name(a, "a");
1709+
ggml_set_name(b, "b");
1710+
}
16831711

16841712
ggml_tensor * out = ggml_mul_mat(ctx, a, b);
16851713
ggml_set_name(out, "out");
@@ -3478,13 +3506,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
34783506
#if 1
34793507
for (ggml_type type_a : base_types) {
34803508
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
3481-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
3482-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
3483-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
3484-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
3485-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
3486-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
3487-
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
3509+
// test cases without permutation
3510+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
3511+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
3512+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
3513+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
3514+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
3515+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
3516+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
34883517

34893518
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1}));
34903519
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1}));
@@ -3493,6 +3522,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
34933522
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
34943523
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
34953524
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
3525+
3526+
// test cases with permutation
3527+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
3528+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
3529+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
3530+
3531+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
3532+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
3533+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
3534+
3535+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
3536+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
3537+
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
34963538
}
34973539
}
34983540
for (ggml_type type_a : other_types) {

0 commit comments

Comments
 (0)