Skip to content

Commit 64fbd32

Browse files
committed
Add patch to test cases provided by @compilade; test for ssm_conv fails
1 parent 25f9e65 commit 64fbd32

File tree

1 file changed

+31
-15
lines changed

1 file changed

+31
-15
lines changed

tests/test-backend-ops.cpp

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1645,18 +1645,26 @@ struct test_leaky_relu : public test_case {
16451645
// GGML_OP_SSM_CONV
16461646
struct test_ssm_conv : public test_case {
16471647
const ggml_type type;
1648+
const int64_t d_conv;
1649+
const int64_t d_inner;
1650+
const int64_t n_seq_tokens;
1651+
const int64_t n_seqs;
16481652

16491653
std::string vars() override {
1650-
return VARS_TO_STR4(type, 3, 1536, 4);
1654+
return VARS_TO_STR5(type, d_conv, d_inner, n_seq_tokens, n_seqs);
16511655
}
16521656

1653-
test_ssm_conv(ggml_type type = GGML_TYPE_F32)
1654-
: type(type) {}
1657+
test_ssm_conv(ggml_type type = GGML_TYPE_F32,
1658+
int64_t d_conv = 4,
1659+
int64_t d_inner = 1536,
1660+
int64_t n_seq_tokens = 7,
1661+
int64_t n_seqs = 2)
1662+
: type(type), d_conv(d_conv), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
16551663

16561664
ggml_tensor * build_graph(ggml_context * ctx) override {
1657-
ggml_tensor * s = ggml_new_tensor_3d(ctx, type, 3, 1536, 1);
1658-
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, 1536, 1);
1659-
ggml_tensor * c = ggml_new_tensor_2d(ctx, type, 4, 1536);
1665+
ggml_tensor * s = ggml_new_tensor_3d(ctx, type, d_conv - 1, d_inner, n_seqs);
1666+
ggml_tensor * x = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs);
1667+
ggml_tensor * c = ggml_new_tensor_2d(ctx, type, d_conv, d_inner);
16601668
ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c);
16611669
return out;
16621670
}
@@ -1665,21 +1673,29 @@ struct test_ssm_conv : public test_case {
16651673
// GGML_OP_SSM_SCAN
16661674
struct test_ssm_scan : public test_case {
16671675
const ggml_type type;
1676+
const int64_t d_state;
1677+
const int64_t d_inner;
1678+
const int64_t n_seq_tokens;
1679+
const int64_t n_seqs;
16681680

16691681
std::string vars() override {
1670-
return VARS_TO_STR4(type, 16, 1536, 2);
1682+
return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
16711683
}
16721684

1673-
test_ssm_scan(ggml_type type = GGML_TYPE_F32)
1674-
: type(type) {}
1685+
test_ssm_scan(ggml_type type = GGML_TYPE_F32,
1686+
int64_t d_state = 16,
1687+
int64_t d_inner = 1536,
1688+
int64_t n_seq_tokens = 7,
1689+
int64_t n_seqs = 2)
1690+
: type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
16751691

16761692
ggml_tensor * build_graph(ggml_context * ctx) override {
1677-
ggml_tensor * s = ggml_new_tensor_3d(ctx, type, 16, 1536, 1);
1678-
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, 1536, 2);
1679-
ggml_tensor * dt = ggml_new_tensor_2d(ctx, type, 1536, 2);
1680-
ggml_tensor * A = ggml_new_tensor_2d(ctx, type, 16, 1536);
1681-
ggml_tensor * B = ggml_new_tensor_2d(ctx, type, 16, 2);
1682-
ggml_tensor * C = ggml_new_tensor_2d(ctx, type, 16, 2);
1693+
ggml_tensor * s = ggml_new_tensor_3d(ctx, type, d_state, d_inner, n_seqs);
1694+
ggml_tensor * x = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs);
1695+
ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs);
1696+
ggml_tensor * A = ggml_new_tensor_2d(ctx, type, d_state, d_inner);
1697+
ggml_tensor * B = ggml_new_tensor_3d(ctx, type, d_state, n_seq_tokens, n_seqs);
1698+
ggml_tensor * C = ggml_new_tensor_3d(ctx, type, d_state, n_seq_tokens, n_seqs);
16831699
ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
16841700
return out;
16851701
}

0 commit comments

Comments
 (0)