diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index e91dedf14a1cb..80145b064d4bb 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -425,102 +425,119 @@ extern "C" {
         GGML_FTYPE_MOSTLY_BF16     = 24, // except 1d tensors
     };
 
+    // Helper macro for the X-Macro list of operations
+    // Format: XX(op_name, n_src_value)
+    #define GGML_OP_LIST(XX) \
+        XX(GGML_OP_NONE, 0) \
+        XX(GGML_OP_DUP, 1) \
+        XX(GGML_OP_ADD, 2) \
+        XX(GGML_OP_ADD1, 2) \
+        XX(GGML_OP_ACC, 2) \
+        XX(GGML_OP_SUB, 2) \
+        XX(GGML_OP_MUL, 2) \
+        XX(GGML_OP_DIV, 2) \
+        XX(GGML_OP_SQR, 1) \
+        XX(GGML_OP_SQRT, 1) \
+        XX(GGML_OP_LOG, 1) \
+        XX(GGML_OP_SIN, 1) \
+        XX(GGML_OP_COS, 1) \
+        XX(GGML_OP_SUM, 1) \
+        XX(GGML_OP_SUM_ROWS, 1) \
+        XX(GGML_OP_MEAN, 1) \
+        XX(GGML_OP_ARGMAX, 1) \
+        XX(GGML_OP_COUNT_EQUAL, 2) \
+        XX(GGML_OP_REPEAT, 2) \
+        XX(GGML_OP_REPEAT_BACK, 2) \
+        XX(GGML_OP_CONCAT, 2) \
+        XX(GGML_OP_SILU_BACK, 2) \
+        XX(GGML_OP_NORM, 1) \
+        XX(GGML_OP_RMS_NORM, 1) \
+        XX(GGML_OP_RMS_NORM_BACK, 2) \
+        XX(GGML_OP_GROUP_NORM, 1) \
+        XX(GGML_OP_L2_NORM, 1) \
+        XX(GGML_OP_MUL_MAT, 2) \
+        XX(GGML_OP_MUL_MAT_ID, 3) \
+        XX(GGML_OP_OUT_PROD, 2) \
+        XX(GGML_OP_SCALE, 1) \
+        XX(GGML_OP_SET, 2) \
+        XX(GGML_OP_CPY, 2) \
+        XX(GGML_OP_CONT, 1) \
+        XX(GGML_OP_RESHAPE, 1) \
+        XX(GGML_OP_VIEW, 1) \
+        XX(GGML_OP_PERMUTE, 1) \
+        XX(GGML_OP_TRANSPOSE, 1) \
+        XX(GGML_OP_GET_ROWS, 2) \
+        XX(GGML_OP_GET_ROWS_BACK, 3) \
+        XX(GGML_OP_DIAG, 1) \
+        XX(GGML_OP_DIAG_MASK_INF, 1) \
+        XX(GGML_OP_DIAG_MASK_ZERO, 1) \
+        XX(GGML_OP_SOFT_MAX, 2) \
+        XX(GGML_OP_SOFT_MAX_BACK, 2) \
+        XX(GGML_OP_ROPE, 3) \
+        XX(GGML_OP_ROPE_BACK, 3) \
+        XX(GGML_OP_CLAMP, 1) \
+        XX(GGML_OP_CONV_TRANSPOSE_1D, 2) \
+        XX(GGML_OP_IM2COL, 2) \
+        XX(GGML_OP_IM2COL_BACK, 2) \
+        XX(GGML_OP_CONV_2D_DW, 2) \
+        XX(GGML_OP_CONV_TRANSPOSE_2D, 2) \
+        XX(GGML_OP_POOL_1D, 1) \
+        XX(GGML_OP_POOL_2D, 1) \
+        XX(GGML_OP_POOL_2D_BACK, 2) \
+        XX(GGML_OP_UPSCALE, 1) \
+        XX(GGML_OP_PAD, 1) \
+        XX(GGML_OP_PAD_REFLECT_1D, 1) \
+        XX(GGML_OP_ARANGE, 0) \
+        XX(GGML_OP_TIMESTEP_EMBEDDING, 1) \
+        XX(GGML_OP_ARGSORT, 1) \
+        XX(GGML_OP_LEAKY_RELU, 1) \
+        XX(GGML_OP_FLASH_ATTN_EXT, 4) \
+        XX(GGML_OP_FLASH_ATTN_BACK, 4) \
+        XX(GGML_OP_SSM_CONV, 2) \
+        XX(GGML_OP_SSM_SCAN, 6) \
+        XX(GGML_OP_WIN_PART, 1) \
+        XX(GGML_OP_WIN_UNPART, 1) \
+        XX(GGML_OP_GET_REL_POS, 1) \
+        XX(GGML_OP_ADD_REL_POS, 3) \
+        XX(GGML_OP_RWKV_WKV6, 6) \
+        XX(GGML_OP_GATED_LINEAR_ATTN, 5) \
+        XX(GGML_OP_RWKV_WKV7, 7) \
+        XX(GGML_OP_UNARY, 1) \
+        XX(GGML_OP_MAP_CUSTOM1, 1) \
+        XX(GGML_OP_MAP_CUSTOM2, 2) \
+        XX(GGML_OP_MAP_CUSTOM3, 3) \
+        XX(GGML_OP_CUSTOM, -1) \
+        XX(GGML_OP_CROSS_ENTROPY_LOSS, 2) \
+        XX(GGML_OP_CROSS_ENTROPY_LOSS_BACK, 3) \
+        XX(GGML_OP_OPT_STEP_ADAMW, 5)
+
     // available tensor operations:
     enum ggml_op {
-        GGML_OP_NONE = 0,
-
-        GGML_OP_DUP,
-        GGML_OP_ADD,
-        GGML_OP_ADD1,
-        GGML_OP_ACC,
-        GGML_OP_SUB,
-        GGML_OP_MUL,
-        GGML_OP_DIV,
-        GGML_OP_SQR,
-        GGML_OP_SQRT,
-        GGML_OP_LOG,
-        GGML_OP_SIN,
-        GGML_OP_COS,
-        GGML_OP_SUM,
-        GGML_OP_SUM_ROWS,
-        GGML_OP_MEAN,
-        GGML_OP_ARGMAX,
-        GGML_OP_COUNT_EQUAL,
-        GGML_OP_REPEAT,
-        GGML_OP_REPEAT_BACK,
-        GGML_OP_CONCAT,
-        GGML_OP_SILU_BACK,
-        GGML_OP_NORM, // normalize
-        GGML_OP_RMS_NORM,
-        GGML_OP_RMS_NORM_BACK,
-        GGML_OP_GROUP_NORM,
-        GGML_OP_L2_NORM,
-
-        GGML_OP_MUL_MAT,
-        GGML_OP_MUL_MAT_ID,
-        GGML_OP_OUT_PROD,
-
-        GGML_OP_SCALE,
-        GGML_OP_SET,
-        GGML_OP_CPY,
-        GGML_OP_CONT,
-        GGML_OP_RESHAPE,
-        GGML_OP_VIEW,
-        GGML_OP_PERMUTE,
-        GGML_OP_TRANSPOSE,
-        GGML_OP_GET_ROWS,
-        GGML_OP_GET_ROWS_BACK,
-        GGML_OP_DIAG,
-        GGML_OP_DIAG_MASK_INF,
-        GGML_OP_DIAG_MASK_ZERO,
-        GGML_OP_SOFT_MAX,
-        GGML_OP_SOFT_MAX_BACK,
-        GGML_OP_ROPE,
-        GGML_OP_ROPE_BACK,
-        GGML_OP_CLAMP,
-        GGML_OP_CONV_TRANSPOSE_1D,
-        GGML_OP_IM2COL,
-        GGML_OP_IM2COL_BACK,
-        GGML_OP_CONV_2D_DW,
-        GGML_OP_CONV_TRANSPOSE_2D,
-        GGML_OP_POOL_1D,
-        GGML_OP_POOL_2D,
-        GGML_OP_POOL_2D_BACK,
-        GGML_OP_UPSCALE, // nearest interpolate
-        GGML_OP_PAD,
-        GGML_OP_PAD_REFLECT_1D,
-        GGML_OP_ARANGE,
-        GGML_OP_TIMESTEP_EMBEDDING,
-        GGML_OP_ARGSORT,
-        GGML_OP_LEAKY_RELU,
-
-        GGML_OP_FLASH_ATTN_EXT,
-        GGML_OP_FLASH_ATTN_BACK,
-        GGML_OP_SSM_CONV,
-        GGML_OP_SSM_SCAN,
-        GGML_OP_WIN_PART,
-        GGML_OP_WIN_UNPART,
-        GGML_OP_GET_REL_POS,
-        GGML_OP_ADD_REL_POS,
-        GGML_OP_RWKV_WKV6,
-        GGML_OP_GATED_LINEAR_ATTN,
-        GGML_OP_RWKV_WKV7,
-
-        GGML_OP_UNARY,
-
-        GGML_OP_MAP_CUSTOM1,
-        GGML_OP_MAP_CUSTOM2,
-        GGML_OP_MAP_CUSTOM3,
-
-        GGML_OP_CUSTOM,
-
-        GGML_OP_CROSS_ENTROPY_LOSS,
-        GGML_OP_CROSS_ENTROPY_LOSS_BACK,
-        GGML_OP_OPT_STEP_ADAMW,
-
+        #define GGML_OP_ENUM_MEMBER(op_name, ...) op_name,
+        GGML_OP_LIST(GGML_OP_ENUM_MEMBER)
+        #undef GGML_OP_ENUM_MEMBER
         GGML_OP_COUNT,
     };
 
+    // metadata for ggml_op
+    typedef struct {
+        int n_src; // number of arguments; -1 means "variable number of inputs"
+    } ggml_op_metadata_t;
+
+    static const ggml_op_metadata_t GGML_OP_METADATA[GGML_OP_COUNT] = {
+        #define GGML_OP_METADATA_ENTRY(op_name, n_src_val) { n_src_val }, // positional: entry order == enum order (both come from GGML_OP_LIST); array designators `[op] = {...}` are C-only and this header is also compiled as C++
+        GGML_OP_LIST(GGML_OP_METADATA_ENTRY)
+        #undef GGML_OP_METADATA_ENTRY
+    };
+
+    // Inline function to get the number of source operands for an operation
+    static inline int ggml_op_get_n_src(enum ggml_op op) {
+        if (op >= 0 && op < GGML_OP_COUNT) {
+            return GGML_OP_METADATA[op].n_src;
+        }
+        return -2; // invalid op (-1 is taken: it means "variable number of inputs")
+    }
+
     enum ggml_unary_op {
         GGML_UNARY_OP_ABS,
         GGML_UNARY_OP_SGN,
@@ -2186,6 +2203,21 @@ extern "C" {
     GGML_API void ggml_threadpool_params_init   (struct ggml_threadpool_params * p, int n_threads);
     GGML_API bool ggml_threadpool_params_match  (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);
 
+    // Compile-time check helper function
+    // Asserts that GGML_OP_METADATA is updated when ggml_op changes.
+    // Ensure all ggml_op values are handled. Missing case = compile error.
+    // Relies on compiler warnings treated as errors (-Werror=switch-enum or similar).
+    static inline void ggml_op_metadata_check(void) {
+        enum ggml_op op = GGML_OP_NONE; // Dummy value
+        switch (op) {
+#define GGML_OP_SWITCH_CASE(op_name, ...) case op_name: (void)GGML_OP_METADATA[op_name].n_src; break;
+            GGML_OP_LIST(GGML_OP_SWITCH_CASE)
+#undef GGML_OP_SWITCH_CASE
+            case GGML_OP_COUNT: break;
+            // NOTE: No default case. Compiler warning/error for unhandled enum value is the goal.
+        }
+    }
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index 4f0abb5a60f48..1131b2539dc20 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -754,6 +754,43 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
     memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
 }
 
+// Helper function to validate graph operands before computation.
+// Returns false (and logs) on the first null node, unknown op, or missing required source.
+static bool validate_graph_operands(const ggml_cgraph * graph) {
+    GGML_PRINT_DEBUG("[%s] Validating graph with %d nodes\n", __func__, graph->n_nodes);
+    // n_nodes is int; use a signed index so the %d format specifiers below are correct
+    for (int i = 0; i < graph->n_nodes; ++i) {
+        const ggml_tensor * node = graph->nodes[i];
+        // Initial null check added for safety.
+        if (node == nullptr) {
+            GGML_LOG_ERROR("[%s] Graph node %d is null.\n", __func__, i);
+            return false;
+        }
+
+        const int n_src = ggml_op_get_n_src(node->op);
+
+        if (n_src == -1) {
+            // Ops like GGML_OP_CUSTOM have variable inputs, cannot validate here.
+            GGML_PRINT_DEBUG("[%s] Skipping operand validation for node %d (op %s, name '%s') with variable inputs.\n", __func__, i, ggml_op_name(node->op), node->name);
+            continue;
+        } else if (n_src == -2) {
+            GGML_LOG_ERROR("[%s] Graph node %d (name '%s') has invalid op type %d.\n", __func__, i, node->name, (int)node->op);
+            return false;
+        } else if (n_src > GGML_MAX_SRC) {
+            GGML_LOG_ERROR("[%s] Graph node %d (op %s, name '%s') requires %d sources, exceeding GGML_MAX_SRC (%d).\n", __func__, i, ggml_op_name(node->op), node->name, n_src, GGML_MAX_SRC);
+            return false;
+        }
+
+        // Check required source operands
+        for (int s_idx = 0; s_idx < n_src; ++s_idx) {
+            if (node->src[s_idx] == nullptr) {
+                GGML_LOG_ERROR("[%s] Graph node %d (op %s, name '%s') missing required input src[%d].\n", __func__, i, ggml_op_name(node->op), node->name, s_idx);
+                return false;
+            }
+        }
+    }
+    GGML_PRINT_DEBUG("[%s] Graph validation successful\n", __func__);
+    return true;
+}
+
 static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
     std::vector<uint8_t> input;
@@ -1357,6 +1394,11 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
             return false;
         }
     }
+
+    // reject malformed graphs from remote clients before handing them to the backend
+    if (!validate_graph_operands(graph)) {
+        return false;
+    }
     ggml_status status = ggml_backend_graph_compute(backend, graph);
     response.result = status;
     return true;
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 8a6546240f46f..70812a5e96603 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1387,6 +1387,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     // initialize time system (required on Windows)
     ggml_time_init();
 
+    // no-op at runtime: exists so the compiler flags any ggml_op missing from GGML_OP_LIST
+    ggml_op_metadata_check();
+
     for (int i = 0; i < (1 << 16); ++i) {
         union {
             uint16_t u16;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 543db93402190..ab905a0fd6915 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4546,6 +4546,59 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_falcon(2));
 #endif
 
+    // Verify that the ggml_op_metadata_t correctly validates n_src
+    {
+        struct test_op_metadata_counts : public test_case {
+            std::string op_desc(ggml_tensor * t) override {
+                GGML_UNUSED(t);
+                return "OP_METADATA_COUNTS";
+            }
+
+            ggml_tensor * build_graph(ggml_context * ctx) override {
+                bool all_passed = true;
+
+                struct {
+                    ggml_op op;
+                    int expected_n_src;
+                    const char * name;
+                } test_ops[] = {
+                    {GGML_OP_NONE, 0, "NONE"},
+                    {GGML_OP_UNARY, 1, "UNARY"},
+                    {GGML_OP_ADD, 2, "ADD"},
+                    {GGML_OP_MUL, 2, "MUL"},
+                    {GGML_OP_ROPE, 3, "ROPE"},
+                    {GGML_OP_FLASH_ATTN_EXT, 4, "FLASH_ATTN_EXT"},
+                    {GGML_OP_CUSTOM, -1, "CUSTOM"}
+                };
+
+                // Test each operation's metadata
+                for (const auto & test : test_ops) {
+                    int n_src = ggml_op_get_n_src(test.op);
+                    if (n_src != test.expected_n_src) {
+                        fprintf(stderr, "ERROR: Expected n_src=%d for GGML_OP_%s but got %d\n",
+                                test.expected_n_src, test.name, n_src);
+                        all_passed = false;
+                    }
+                }
+
+                if (!all_passed) {
+                    GGML_ASSERT(all_passed && "One or more metadata checks failed"); // was GGML_ASSERT("..."): a string literal is always truthy, so the assert could never fire
+                }
+
+                // Create a dummy tensor that will be used for backend comparison
+                ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 10);
+                ggml_set_name(a, "a");
+
+                ggml_tensor * result = ggml_scale(ctx, a, 1.0f);
+                ggml_set_name(result, "result");
+
+                return result;
+            }
+        };
+
+        test_cases.push_back(std::make_unique<test_op_metadata_counts>());
+    }
+
     return test_cases;
 }