@@ -1001,6 +1001,45 @@ struct test_ssm_scan : public test_case {
1001
1001
}
1002
1002
};
1003
1003
1004
+ // GGML_OP_RWKV_WKV
1005
+ struct test_rwkv_wkv : public test_case {
1006
+ const ggml_type type;
1007
+
1008
+ const int64_t head_count;
1009
+ const int64_t head_size;
1010
+ const int64_t n_seq_tokens;
1011
+ const int64_t n_seqs;
1012
+
1013
+ std::string vars () override {
1014
+ return VARS_TO_STR5 (type, head_count, head_size, n_seq_tokens, n_seqs);
1015
+ }
1016
+
1017
+ test_rwkv_wkv (ggml_type type = GGML_TYPE_F32,
1018
+ int64_t head_count = 32 , int64_t head_size = 64 , int64_t n_seq_tokens = 32 , int64_t n_seqs = 32 )
1019
+ : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
1020
+
1021
+ ggml_tensor * build_graph (ggml_context * ctx) override {
1022
+ const int64_t n_tokens = n_seq_tokens * n_seqs;
1023
+ // ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, n_seqs, 1 }.data());
1024
+ // ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
1025
+ // ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
1026
+ // ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, 1 , 1 }.data());
1027
+ // ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
1028
+ // ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
1029
+ // ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
1030
+
1031
+ ggml_tensor * r = ggml_new_tensor (ctx, type, 4 , std::vector<int64_t >{ 1 , head_size, head_count, n_tokens }.data ());
1032
+ ggml_tensor * k = ggml_new_tensor (ctx, type, 4 , std::vector<int64_t >{ head_size, 1 , head_count, n_tokens }.data ());
1033
+ ggml_tensor * v = ggml_new_tensor (ctx, type, 4 , std::vector<int64_t >{ 1 , head_size, head_count, n_tokens }.data ());
1034
+ ggml_tensor * tf = ggml_new_tensor (ctx, type, 2 , std::vector<int64_t >{ head_size, head_count }.data ());
1035
+ ggml_tensor * td = ggml_new_tensor (ctx, type, 4 , std::vector<int64_t >{ 1 , head_size, head_count, n_tokens }.data ());
1036
+ ggml_tensor * s = ggml_new_tensor (ctx, type, 2 , std::vector<int64_t >{ head_size * head_size * head_count, n_seqs }.data ());
1037
+ ggml_tensor * out = ggml_rwkv_wkv (ctx, k, v, r, tf, td, s);
1038
+
1039
+ return out;
1040
+ }
1041
+ };
1042
+
1004
1043
// GGML_OP_MUL_MAT
1005
1044
struct test_mul_mat : public test_case {
1006
1045
const ggml_type type_a;
@@ -2371,6 +2410,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
2371
2410
2372
2411
test_cases.emplace_back (new test_ssm_scan (GGML_TYPE_F32, 16 , 1024 , 32 , 4 ));
2373
2412
2413
+ test_cases.emplace_back (new test_rwkv_wkv (GGML_TYPE_F32, 32 , 64 , 1 , 1 ));
2414
+ test_cases.emplace_back (new test_rwkv_wkv (GGML_TYPE_F32, 32 , 64 , 32 , 1 ));
2415
+ test_cases.emplace_back (new test_rwkv_wkv (GGML_TYPE_F32, 32 , 64 , 32 , 4 ));
2416
+ test_cases.emplace_back (new test_rwkv_wkv (GGML_TYPE_F32, 32 , 64 , 128 , 4 ));
2417
+
2374
2418
#if 1
2375
2419
for (ggml_type type_a : base_types) {
2376
2420
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
0 commit comments