Skip to content

Commit 35f2f86

Browse files
committed
Add GGML_OP_SSM_CONV, GGML_OP_SSM_SCAN to supported ops for CUDA backend + test case for each op
1 parent 677ad0a commit 35f2f86

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

ggml-cuda.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2908,6 +2908,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
29082908
case GGML_OP_ARANGE:
29092909
case GGML_OP_TIMESTEP_EMBEDDING:
29102910
case GGML_OP_LEAKY_RELU:
2911+
case GGML_OP_SSM_CONV:
2912+
case GGML_OP_SSM_SCAN:
29112913
return true;
29122914
case GGML_OP_FLASH_ATTN_EXT:
29132915
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)

tests/test-backend-ops.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1559,6 +1559,76 @@ struct test_leaky_relu : public test_case {
15591559
}
15601560
};
15611561

1562+
// GGML_OP_SSM_CONV
1563+
struct test_ssm_conv : public test_case {
1564+
const ggml_type type;
1565+
1566+
std::string vars() override {
1567+
return VARS_TO_STR4(type, 3, 1536, 4);
1568+
}
1569+
1570+
test_ssm_conv(ggml_type type = GGML_TYPE_F32)
1571+
: type(type) {}
1572+
1573+
ggml_tensor * build_graph(ggml_context * ctx) override {
1574+
ggml_tensor * s = ggml_new_tensor_3d(ctx, type, 3, 1536, 1);
1575+
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, 1536, 1);
1576+
ggml_tensor * c = ggml_new_tensor_2d(ctx, type, 4, 1536);
1577+
ggml_tensor * sq = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 1, 1);
1578+
ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c, sq);
1579+
return out;
1580+
}
1581+
1582+
void initialize_tensors(ggml_context * ctx) override {
1583+
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
1584+
if (t->type == GGML_TYPE_I32) {
1585+
std::vector<int> data(1);
1586+
data[0] = 0;
1587+
ggml_backend_tensor_set(t, data.data(), 0, 1 * sizeof(int));
1588+
} else {
1589+
init_tensor_uniform(t);
1590+
}
1591+
}
1592+
}
1593+
};
1594+
1595+
// GGML_OP_SSM_SCAN
1596+
struct test_ssm_scan : public test_case {
1597+
const ggml_type type;
1598+
1599+
std::string vars() override {
1600+
return VARS_TO_STR4(type, 16, 1536, 2);
1601+
}
1602+
1603+
test_ssm_scan(ggml_type type = GGML_TYPE_F32)
1604+
: type(type) {}
1605+
1606+
ggml_tensor * build_graph(ggml_context * ctx) override {
1607+
ggml_tensor * s = ggml_new_tensor_3d(ctx, type, 16, 1536, 1);
1608+
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, 1536, 2);
1609+
ggml_tensor * dt = ggml_new_tensor_2d(ctx, type, 1536, 2);
1610+
ggml_tensor * A = ggml_new_tensor_2d(ctx, type, 16, 1536);
1611+
ggml_tensor * B = ggml_new_tensor_2d(ctx, type, 16, 2);
1612+
ggml_tensor * C = ggml_new_tensor_2d(ctx, type, 16, 2);
1613+
ggml_tensor * sq = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 1, 2);
1614+
ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, sq);
1615+
return out;
1616+
}
1617+
1618+
void initialize_tensors(ggml_context * ctx) override {
1619+
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
1620+
if (t->type == GGML_TYPE_I32) {
1621+
std::vector<int> data(2);
1622+
data[0] = 0;
1623+
data[1] = 0;
1624+
ggml_backend_tensor_set(t, data.data(), 0, 2 * sizeof(int));
1625+
} else {
1626+
init_tensor_uniform(t);
1627+
}
1628+
}
1629+
}
1630+
};
1631+
15621632
// GGML_OP_FLASH_ATTN_EXT
15631633
struct test_flash_attn_ext : public test_case {
15641634
const int64_t hs; // head size
@@ -2284,6 +2354,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
22842354
test_cases.emplace_back(new test_arange());
22852355
test_cases.emplace_back(new test_timestep_embedding());
22862356
test_cases.emplace_back(new test_leaky_relu());
2357+
test_cases.emplace_back(new test_ssm_conv());
2358+
test_cases.emplace_back(new test_ssm_scan());
22872359

22882360
for (int hs : { 64, 80, 128, 256, }) {
22892361
for (bool mask : { true, false } ) {

0 commit comments

Comments
 (0)