@@ -1650,11 +1650,12 @@ struct test_mul_mat : public test_case {
1650
1650
const int64_t m;
1651
1651
const int64_t n;
1652
1652
const int64_t k;
1653
- const std::array<int64_t , 2 > bs; // dims 3 and 4
1654
- const std::array<int64_t , 2 > nr; // repeat in dims 3 and 4
1653
+ const std::array<int64_t , 2 > bs; // dims 3 and 4
1654
+ const std::array<int64_t , 2 > nr; // repeat in dims 3 and 4
1655
+ const std::array<int64_t , 4 > per; // permutation of dimensions
1655
1656
1656
1657
std::string vars () override {
1657
- return VARS_TO_STR7 (type_a, type_b, m, n, k, bs, nr);
1658
+ return VARS_TO_STR8 (type_a, type_b, m, n, k, bs, nr, per );
1658
1659
}
1659
1660
1660
1661
double max_nmse_err () override {
@@ -1669,17 +1670,44 @@ struct test_mul_mat : public test_case {
1669
1670
test_mul_mat (ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
1670
1671
int64_t m = 32 , int64_t n = 32 , int64_t k = 32 ,
1671
1672
std::array<int64_t , 2 > bs = {10 , 10 },
1672
- std::array<int64_t , 2 > nr = {2 , 2 })
1673
- : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {}
1673
+ std::array<int64_t , 2 > nr = {2 , 2 },
1674
+ std::array<int64_t , 4 > per = {0 , 1 , 2 , 3 })
1675
+ : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}
1674
1676
1675
1677
ggml_tensor * build_graph (ggml_context * ctx) override {
1676
1678
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
1677
- ggml_tensor * a = ggml_new_tensor_4d (ctx, type_a, k, m, bs[0 ] , bs[1 ]);
1678
- ggml_tensor * b = ggml_new_tensor_4d (ctx, type_b, k, n, bs[0 ]*nr[0 ], bs[1 ]*nr[1 ]);
1679
- ggml_set_param (ctx, a);
1680
- ggml_set_param (ctx, b);
1681
- ggml_set_name (a, " a" );
1682
- ggml_set_name (b, " b" );
1679
+ ggml_tensor * a;
1680
+ ggml_tensor * b;
1681
+
1682
+ const int npermuted = (per[0 ] != 0 ) + (per[1 ] != 1 ) + (per[2 ] != 2 ) + (per[3 ] != 3 );
1683
+ if (npermuted > 0 ) {
1684
+ GGML_ASSERT (npermuted == 2 );
1685
+ GGML_ASSERT (!ggml_is_quantized (type_a) || per[0 ] == 0 );
1686
+ GGML_ASSERT (!ggml_is_quantized (type_b) || per[0 ] == 0 );
1687
+
1688
+ // Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k.
1689
+ const int64_t ne_a[4 ] = {k, m, bs[0 ], bs[1 ]};
1690
+ const int64_t ne_b[4 ] = {k, n, bs[0 ]*nr[0 ], bs[1 ]*nr[1 ]};
1691
+
1692
+ a = ggml_new_tensor_4d (ctx, type_a, ne_a[per[0 ]], ne_a[per[1 ]], ne_a[per[2 ]], ne_a[per[3 ]]);
1693
+ b = ggml_new_tensor_4d (ctx, type_b, ne_b[per[0 ]], ne_b[per[1 ]], ne_b[per[2 ]], ne_b[per[3 ]]);
1694
+ ggml_set_param (ctx, a);
1695
+ ggml_set_param (ctx, b);
1696
+ ggml_set_name (a, " a" );
1697
+ ggml_set_name (b, " b" );
1698
+
1699
+ a = ggml_permute (ctx, a, per[0 ], per[1 ], per[2 ], per[3 ]);
1700
+ b = ggml_permute (ctx, b, per[0 ], per[1 ], per[2 ], per[3 ]);
1701
+ ggml_set_name (a, " a_permuted" );
1702
+ ggml_set_name (b, " b_permuted" );
1703
+ } else {
1704
+ a = ggml_new_tensor_4d (ctx, type_a, k, m, bs[0 ], bs[1 ]);
1705
+ b = ggml_new_tensor_4d (ctx, type_b, k, n, bs[0 ]*nr[0 ], bs[1 ]*nr[1 ]);
1706
+ ggml_set_param (ctx, a);
1707
+ ggml_set_param (ctx, b);
1708
+ ggml_set_name (a, " a" );
1709
+ ggml_set_name (b, " b" );
1710
+ }
1683
1711
1684
1712
ggml_tensor * out = ggml_mul_mat (ctx, a, b);
1685
1713
ggml_set_name (out, " out" );
@@ -3478,13 +3506,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
3478
3506
#if 1
3479
3507
for (ggml_type type_a : base_types) {
3480
3508
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
3481
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , { 1 , 1 }, {1 , 1 }));
3482
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 1 }, {1 , 1 }));
3483
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 1 }, {2 , 1 }));
3484
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {1 , 1 }));
3485
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {2 , 1 }));
3486
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {1 , 2 }));
3487
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {2 , 2 }));
3509
+ // test cases without permutation
3510
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , { 1 , 1 }, {1 , 1 }));
3511
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 1 }, {1 , 1 }));
3512
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 1 }, {2 , 1 }));
3513
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {1 , 1 }));
3514
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {2 , 1 }));
3515
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {1 , 2 }));
3516
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {2 , 2 }));
3488
3517
3489
3518
test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , { 1 , 1 }, {1 , 1 }));
3490
3519
test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 1 }, {1 , 1 }));
@@ -3493,6 +3522,19 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
3493
3522
test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {2 , 1 }));
3494
3523
test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {1 , 2 }));
3495
3524
test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {2 , 2 }));
3525
+
3526
+ // test cases with permutation
3527
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {2 , 3 }, {1 , 1 }, {0 , 2 , 1 , 3 }));
3528
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {2 , 3 }, {1 , 1 }, {0 , 1 , 3 , 2 }));
3529
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {2 , 3 }, {1 , 1 }, {0 , 3 , 2 , 1 }));
3530
+
3531
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 8 , 256 , {2 , 3 }, {1 , 1 }, {0 , 2 , 1 , 3 }));
3532
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 8 , 256 , {2 , 3 }, {1 , 1 }, {0 , 1 , 3 , 2 }));
3533
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 8 , 256 , {2 , 3 }, {1 , 1 }, {0 , 3 , 2 , 1 }));
3534
+
3535
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {2 , 3 }, {1 , 1 }, {0 , 2 , 1 , 3 }));
3536
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {2 , 3 }, {1 , 1 }, {0 , 1 , 3 , 2 }));
3537
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {2 , 3 }, {1 , 1 }, {0 , 3 , 2 , 1 }));
3496
3538
}
3497
3539
}
3498
3540
for (ggml_type type_a : other_types) {
0 commit comments