@@ -1645,18 +1645,26 @@ struct test_leaky_relu : public test_case {
1645
1645
// GGML_OP_SSM_CONV
1646
1646
struct test_ssm_conv : public test_case {
1647
1647
const ggml_type type;
1648
+ const int64_t d_conv;
1649
+ const int64_t d_inner;
1650
+ const int64_t n_seq_tokens;
1651
+ const int64_t n_seqs;
1648
1652
1649
1653
std::string vars () override {
1650
- return VARS_TO_STR4 (type, 3 , 1536 , 4 );
1654
+ return VARS_TO_STR5 (type, d_conv, d_inner, n_seq_tokens, n_seqs );
1651
1655
}
1652
1656
1653
- test_ssm_conv (ggml_type type = GGML_TYPE_F32)
1654
- : type(type) {}
1657
+ test_ssm_conv (ggml_type type = GGML_TYPE_F32,
1658
+ int64_t d_conv = 4 ,
1659
+ int64_t d_inner = 1536 ,
1660
+ int64_t n_seq_tokens = 7 ,
1661
+ int64_t n_seqs = 2 )
1662
+ : type(type), d_conv(d_conv), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
1655
1663
1656
1664
ggml_tensor * build_graph (ggml_context * ctx) override {
1657
- ggml_tensor * s = ggml_new_tensor_3d (ctx, type, 3 , 1536 , 1 );
1658
- ggml_tensor * x = ggml_new_tensor_2d (ctx, type, 1536 , 1 );
1659
- ggml_tensor * c = ggml_new_tensor_2d (ctx, type, 4 , 1536 );
1665
+ ggml_tensor * s = ggml_new_tensor_3d (ctx, type, d_conv - 1 , d_inner, n_seqs );
1666
+ ggml_tensor * x = ggml_new_tensor_3d (ctx, type, d_inner, n_seq_tokens, n_seqs );
1667
+ ggml_tensor * c = ggml_new_tensor_2d (ctx, type, d_conv, d_inner );
1660
1668
ggml_tensor * out = ggml_ssm_conv (ctx, s, x, c);
1661
1669
return out;
1662
1670
}
@@ -1665,21 +1673,29 @@ struct test_ssm_conv : public test_case {
1665
1673
// GGML_OP_SSM_SCAN
1666
1674
struct test_ssm_scan : public test_case {
1667
1675
const ggml_type type;
1676
+ const int64_t d_state;
1677
+ const int64_t d_inner;
1678
+ const int64_t n_seq_tokens;
1679
+ const int64_t n_seqs;
1668
1680
1669
1681
std::string vars () override {
1670
- return VARS_TO_STR4 (type, 16 , 1536 , 2 );
1682
+ return VARS_TO_STR5 (type, d_state, d_inner, n_seq_tokens, n_seqs );
1671
1683
}
1672
1684
1673
- test_ssm_scan (ggml_type type = GGML_TYPE_F32)
1674
- : type(type) {}
1685
+ test_ssm_scan (ggml_type type = GGML_TYPE_F32,
1686
+ int64_t d_state = 16 ,
1687
+ int64_t d_inner = 1536 ,
1688
+ int64_t n_seq_tokens = 7 ,
1689
+ int64_t n_seqs = 2 )
1690
+ : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
1675
1691
1676
1692
ggml_tensor * build_graph (ggml_context * ctx) override {
1677
- ggml_tensor * s = ggml_new_tensor_3d (ctx, type, 16 , 1536 , 1 );
1678
- ggml_tensor * x = ggml_new_tensor_2d (ctx, type, 1536 , 2 );
1679
- ggml_tensor * dt = ggml_new_tensor_2d (ctx, type, 1536 , 2 );
1680
- ggml_tensor * A = ggml_new_tensor_2d (ctx, type, 16 , 1536 );
1681
- ggml_tensor * B = ggml_new_tensor_2d (ctx, type, 16 , 2 );
1682
- ggml_tensor * C = ggml_new_tensor_2d (ctx, type, 16 , 2 );
1693
+ ggml_tensor * s = ggml_new_tensor_3d (ctx, type, d_state, d_inner, n_seqs );
1694
+ ggml_tensor * x = ggml_new_tensor_3d (ctx, type, d_inner, n_seq_tokens, n_seqs );
1695
+ ggml_tensor * dt = ggml_new_tensor_3d (ctx, type, d_inner, n_seq_tokens, n_seqs );
1696
+ ggml_tensor * A = ggml_new_tensor_2d (ctx, type, d_state, d_inner );
1697
+ ggml_tensor * B = ggml_new_tensor_3d (ctx, type, d_state, n_seq_tokens, n_seqs );
1698
+ ggml_tensor * C = ggml_new_tensor_3d (ctx, type, d_state, n_seq_tokens, n_seqs );
1683
1699
ggml_tensor * out = ggml_ssm_scan (ctx, s, x, dt, A, B, C);
1684
1700
return out;
1685
1701
}
0 commit comments