Skip to content

CUDA: fix MMQ for non-contiguous src0, add tests #10021

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1151,8 +1151,8 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {

GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
char * src_ptr = (char *) src->data;
char * dst_ptr = (char *) dst;
const char * src_ptr = (const char *) src->data;
char * dst_ptr = (char *) dst;

const int64_t ne0 = src->ne[0];
const int64_t nb0 = src->nb[0];
Expand All @@ -1162,7 +1162,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
const enum ggml_type type = src->type;
const int64_t ts = ggml_type_size(type);
const int64_t bs = ggml_blck_size(type);
int64_t i1_diff = i1_high - i1_low;
const int64_t i1_diff = i1_high - i1_low;

const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
if (nb0 == ts && nb1 == ts*ne0/bs) {
Expand Down Expand Up @@ -1479,13 +1479,17 @@ static void ggml_cuda_op_mul_mat(
if (src0_is_contiguous) {
dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
} else {
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0));
// If src0 is not contiguous it will be copied to a temporary buffer, it may then be necessary to clear padding.
const size_t nbytes_data = ggml_nbytes(src0);
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
}

// If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared:
// If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
}

Expand Down
4 changes: 1 addition & 3 deletions ggml/src/ggml-cuda/mmq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@ void ggml_cuda_op_mul_mat_q(

const int64_t ne00 = src0->ne[0];

const int64_t nb01 = src0->nb[1];

const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
GGML_ASSERT(ne10 % QK8_1 == 0);

const int64_t ne0 = dst->ne[0];

const int64_t row_diff = row_high - row_low;
const int64_t stride00 = nb01 / ggml_type_size(src0->type);
const int64_t stride00 = ne00 / ggml_blck_size(src0->type);

int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -3464,7 +3464,7 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {

size_t ggml_nbytes(const struct ggml_tensor * tensor) {
size_t nbytes;
size_t blck_size = ggml_blck_size(tensor->type);
const size_t blck_size = ggml_blck_size(tensor->type);
if (blck_size == 1) {
nbytes = ggml_type_size(tensor->type);
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
Expand Down
78 changes: 60 additions & 18 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1650,11 +1650,12 @@ struct test_mul_mat : public test_case {
const int64_t m;
const int64_t n;
const int64_t k;
const std::array<int64_t, 2> bs; // dims 3 and 4
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
const std::array<int64_t, 2> bs; // dims 3 and 4
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
const std::array<int64_t, 4> per; // permutation of dimensions

std::string vars() override {
return VARS_TO_STR7(type_a, type_b, m, n, k, bs, nr);
return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per);
}

double max_nmse_err() override {
Expand All @@ -1669,17 +1670,44 @@ struct test_mul_mat : public test_case {
test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int64_t m = 32, int64_t n = 32, int64_t k = 32,
std::array<int64_t, 2> bs = {10, 10},
std::array<int64_t, 2> nr = {2, 2})
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {}
std::array<int64_t, 2> nr = {2, 2},
std::array<int64_t, 4> per = {0, 1, 2, 3})
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0] , bs[1]);
ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
ggml_set_param(ctx, a);
ggml_set_param(ctx, b);
ggml_set_name(a, "a");
ggml_set_name(b, "b");
ggml_tensor * a;
ggml_tensor * b;

const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
if (npermuted > 0) {
GGML_ASSERT(npermuted == 2);
GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);

// Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k.
const int64_t ne_a[4] = {k, m, bs[0], bs[1]};
const int64_t ne_b[4] = {k, n, bs[0]*nr[0], bs[1]*nr[1]};

a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
ggml_set_param(ctx, a);
ggml_set_param(ctx, b);
ggml_set_name(a, "a");
ggml_set_name(b, "b");

a = ggml_permute(ctx, a, per[0], per[1], per[2], per[3]);
b = ggml_permute(ctx, b, per[0], per[1], per[2], per[3]);
ggml_set_name(a, "a_permuted");
ggml_set_name(b, "b_permuted");
} else {
a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
ggml_set_param(ctx, a);
ggml_set_param(ctx, b);
ggml_set_name(a, "a");
ggml_set_name(b, "b");
}

ggml_tensor * out = ggml_mul_mat(ctx, a, b);
ggml_set_name(out, "out");
Expand Down Expand Up @@ -3442,13 +3470,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
#if 1
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
// test cases without permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));

test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1}));
Expand All @@ -3457,6 +3486,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));

// test cases with permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));

test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));

test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
}
}
for (ggml_type type_a : other_types) {
Expand Down
Loading