feat(rpc): compile-time op metadata & RPC graph validation #13167

Closed
214 changes: 123 additions & 91 deletions ggml/include/ggml.h
@@ -425,102 +425,119 @@ extern "C" {
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
};

// X-macro list of tensor operations and their source-operand counts.
// Format: XX(op_name, n_src); n_src == -1 marks ops with a variable number of sources (e.g. GGML_OP_CUSTOM).
#define GGML_OP_LIST(XX) \
XX(GGML_OP_NONE, 0) \
XX(GGML_OP_DUP, 1) \
XX(GGML_OP_ADD, 2) \
XX(GGML_OP_ADD1, 2) \
XX(GGML_OP_ACC, 2) \
XX(GGML_OP_SUB, 2) \
XX(GGML_OP_MUL, 2) \
XX(GGML_OP_DIV, 2) \
XX(GGML_OP_SQR, 1) \
XX(GGML_OP_SQRT, 1) \
XX(GGML_OP_LOG, 1) \
XX(GGML_OP_SIN, 1) \
XX(GGML_OP_COS, 1) \
XX(GGML_OP_SUM, 1) \
XX(GGML_OP_SUM_ROWS, 1) \
XX(GGML_OP_MEAN, 1) \
XX(GGML_OP_ARGMAX, 1) \
XX(GGML_OP_COUNT_EQUAL, 2) \
XX(GGML_OP_REPEAT, 2) \
XX(GGML_OP_REPEAT_BACK, 2) \
XX(GGML_OP_CONCAT, 2) \
XX(GGML_OP_SILU_BACK, 2) \
XX(GGML_OP_NORM, 1) \
XX(GGML_OP_RMS_NORM, 1) \
XX(GGML_OP_RMS_NORM_BACK, 2) \
XX(GGML_OP_GROUP_NORM, 1) \
XX(GGML_OP_L2_NORM, 1) \
XX(GGML_OP_MUL_MAT, 2) \
XX(GGML_OP_MUL_MAT_ID, 3) \
XX(GGML_OP_OUT_PROD, 2) \
XX(GGML_OP_SCALE, 1) \
XX(GGML_OP_SET, 2) \
XX(GGML_OP_CPY, 2) \
XX(GGML_OP_CONT, 1) \
XX(GGML_OP_RESHAPE, 1) \
XX(GGML_OP_VIEW, 1) \
XX(GGML_OP_PERMUTE, 1) \
XX(GGML_OP_TRANSPOSE, 1) \
XX(GGML_OP_GET_ROWS, 2) \
XX(GGML_OP_GET_ROWS_BACK, 3) \
XX(GGML_OP_DIAG, 1) \
XX(GGML_OP_DIAG_MASK_INF, 1) \
XX(GGML_OP_DIAG_MASK_ZERO, 1) \
XX(GGML_OP_SOFT_MAX, 2) \
XX(GGML_OP_SOFT_MAX_BACK, 2) \
XX(GGML_OP_ROPE, 3) \
XX(GGML_OP_ROPE_BACK, 3) \
XX(GGML_OP_CLAMP, 1) \
XX(GGML_OP_CONV_TRANSPOSE_1D, 2) \
XX(GGML_OP_IM2COL, 2) \
XX(GGML_OP_IM2COL_BACK, 2) \
XX(GGML_OP_CONV_2D_DW, 2) \
XX(GGML_OP_CONV_TRANSPOSE_2D, 2) \
XX(GGML_OP_POOL_1D, 1) \
XX(GGML_OP_POOL_2D, 1) \
XX(GGML_OP_POOL_2D_BACK, 2) \
XX(GGML_OP_UPSCALE, 1) \
XX(GGML_OP_PAD, 1) \
XX(GGML_OP_PAD_REFLECT_1D, 1) \
XX(GGML_OP_ARANGE, 0) \
XX(GGML_OP_TIMESTEP_EMBEDDING, 1) \
XX(GGML_OP_ARGSORT, 1) \
XX(GGML_OP_LEAKY_RELU, 1) \
XX(GGML_OP_FLASH_ATTN_EXT, 4) \
XX(GGML_OP_FLASH_ATTN_BACK, 4) \
XX(GGML_OP_SSM_CONV, 2) \
XX(GGML_OP_SSM_SCAN, 6) \
XX(GGML_OP_WIN_PART, 1) \
XX(GGML_OP_WIN_UNPART, 1) \
XX(GGML_OP_GET_REL_POS, 1) \
XX(GGML_OP_ADD_REL_POS, 3) \
XX(GGML_OP_RWKV_WKV6, 6) \
XX(GGML_OP_GATED_LINEAR_ATTN, 5) \
XX(GGML_OP_RWKV_WKV7, 7) \
XX(GGML_OP_UNARY, 1) \
XX(GGML_OP_MAP_CUSTOM1, 1) \
XX(GGML_OP_MAP_CUSTOM2, 2) \
XX(GGML_OP_MAP_CUSTOM3, 3) \
XX(GGML_OP_CUSTOM, -1) \
XX(GGML_OP_CROSS_ENTROPY_LOSS, 2) \
XX(GGML_OP_CROSS_ENTROPY_LOSS_BACK, 3) \
XX(GGML_OP_OPT_STEP_ADAMW, 5)

// available tensor operations:
enum ggml_op {
GGML_OP_NONE = 0,

GGML_OP_DUP,
GGML_OP_ADD,
GGML_OP_ADD1,
GGML_OP_ACC,
GGML_OP_SUB,
GGML_OP_MUL,
GGML_OP_DIV,
GGML_OP_SQR,
GGML_OP_SQRT,
GGML_OP_LOG,
GGML_OP_SIN,
GGML_OP_COS,
GGML_OP_SUM,
GGML_OP_SUM_ROWS,
GGML_OP_MEAN,
GGML_OP_ARGMAX,
GGML_OP_COUNT_EQUAL,
GGML_OP_REPEAT,
GGML_OP_REPEAT_BACK,
GGML_OP_CONCAT,
GGML_OP_SILU_BACK,
GGML_OP_NORM, // normalize
GGML_OP_RMS_NORM,
GGML_OP_RMS_NORM_BACK,
GGML_OP_GROUP_NORM,
GGML_OP_L2_NORM,

GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
GGML_OP_OUT_PROD,

GGML_OP_SCALE,
GGML_OP_SET,
GGML_OP_CPY,
GGML_OP_CONT,
GGML_OP_RESHAPE,
GGML_OP_VIEW,
GGML_OP_PERMUTE,
GGML_OP_TRANSPOSE,
GGML_OP_GET_ROWS,
GGML_OP_GET_ROWS_BACK,
GGML_OP_DIAG,
GGML_OP_DIAG_MASK_INF,
GGML_OP_DIAG_MASK_ZERO,
GGML_OP_SOFT_MAX,
GGML_OP_SOFT_MAX_BACK,
GGML_OP_ROPE,
GGML_OP_ROPE_BACK,
GGML_OP_CLAMP,
GGML_OP_CONV_TRANSPOSE_1D,
GGML_OP_IM2COL,
GGML_OP_IM2COL_BACK,
GGML_OP_CONV_2D_DW,
GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
GGML_OP_POOL_2D_BACK,
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD,
GGML_OP_PAD_REFLECT_1D,
GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,

GGML_OP_FLASH_ATTN_EXT,
GGML_OP_FLASH_ATTN_BACK,
GGML_OP_SSM_CONV,
GGML_OP_SSM_SCAN,
GGML_OP_WIN_PART,
GGML_OP_WIN_UNPART,
GGML_OP_GET_REL_POS,
GGML_OP_ADD_REL_POS,
GGML_OP_RWKV_WKV6,
GGML_OP_GATED_LINEAR_ATTN,
GGML_OP_RWKV_WKV7,

GGML_OP_UNARY,

GGML_OP_MAP_CUSTOM1,
GGML_OP_MAP_CUSTOM2,
GGML_OP_MAP_CUSTOM3,

GGML_OP_CUSTOM,

GGML_OP_CROSS_ENTROPY_LOSS,
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
GGML_OP_OPT_STEP_ADAMW,

#define GGML_OP_ENUM_MEMBER(op_name, ...) op_name,
GGML_OP_LIST(GGML_OP_ENUM_MEMBER)
#undef GGML_OP_ENUM_MEMBER
GGML_OP_COUNT,
};

// metadata for ggml_op
typedef struct {
int n_src; // number of source tensors; -1 = variable (e.g. GGML_OP_CUSTOM)
} ggml_op_metadata_t;

static const ggml_op_metadata_t GGML_OP_METADATA[GGML_OP_COUNT] = {
#define GGML_OP_METADATA_ENTRY(op_name, n_src_val) [op_name] = {.n_src = n_src_val},
GGML_OP_LIST(GGML_OP_METADATA_ENTRY)
#undef GGML_OP_METADATA_ENTRY
};
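
For readers unfamiliar with the X-macro pattern, here is an abridged sketch of what the preprocessor generates from GGML_OP_LIST for the enum and the metadata table above (illustrative only, not literal preprocessor output):

    enum ggml_op {
        GGML_OP_NONE,
        GGML_OP_DUP,
        GGML_OP_ADD,
        /* ... one enumerator per GGML_OP_LIST entry ... */
        GGML_OP_COUNT,
    };

    static const ggml_op_metadata_t GGML_OP_METADATA[GGML_OP_COUNT] = {
        [GGML_OP_NONE] = {.n_src = 0},
        [GGML_OP_DUP]  = {.n_src = 1},
        [GGML_OP_ADD]  = {.n_src = 2},
        /* ... one entry per GGML_OP_LIST entry ... */
    };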

// Returns the number of source operands for an operation, -1 for variable-input ops, or -2 for an invalid op.
static inline int ggml_op_get_n_src(enum ggml_op op) {
if (op >= 0 && op < GGML_OP_COUNT) {
return GGML_OP_METADATA[op].n_src;
}
return -2; // invalid op
}
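
As a hedged illustration of the helper's return conventions (this tiny program is not part of the PR; the expected values follow directly from the GGML_OP_LIST entries above):

    #include "ggml.h"
    #include <stdio.h>

    int main(void) {
        printf("GGML_OP_ADD    -> %d\n", ggml_op_get_n_src(GGML_OP_ADD));    // 2
        printf("GGML_OP_ROPE   -> %d\n", ggml_op_get_n_src(GGML_OP_ROPE));   // 3
        printf("GGML_OP_CUSTOM -> %d\n", ggml_op_get_n_src(GGML_OP_CUSTOM)); // -1 (variable inputs)
        printf("out of range   -> %d\n", ggml_op_get_n_src((enum ggml_op) GGML_OP_COUNT)); // -2 (invalid op)
        return 0;
    }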

enum ggml_unary_op {
GGML_UNARY_OP_ABS,
GGML_UNARY_OP_SGN,
@@ -2186,6 +2203,21 @@ extern "C" {
GGML_API void ggml_threadpool_params_init (struct ggml_threadpool_params * p, int n_threads);
GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);

// Compile-time consistency check: ensures GGML_OP_METADATA stays in sync with ggml_op.
// Every enum value must appear as a case below, so a missing entry becomes a build
// failure when switch-coverage warnings are treated as errors (-Werror=switch-enum or similar).
static inline void ggml_op_metadata_check(void) {
enum ggml_op op = GGML_OP_NONE; // Dummy value
switch (op) {
#define GGML_OP_SWITCH_CASE(op_name, ...) case op_name: (void)GGML_OP_METADATA[op_name].n_src; break;
GGML_OP_LIST(GGML_OP_SWITCH_CASE)
#undef GGML_OP_SWITCH_CASE
case GGML_OP_COUNT: break;
// NOTE: No default case. Compiler warning/error for unhandled enum value is the goal.
}
}
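
For context, an abridged sketch of what the switch above expands to (again illustrative, not literal output); because there is no default case, any ggml_op enumerator added outside GGML_OP_LIST has no case label here and trips -Wswitch / -Wswitch-enum, which becomes a build failure under -Werror=switch-enum:

    switch (op) {
        case GGML_OP_NONE: (void)GGML_OP_METADATA[GGML_OP_NONE].n_src; break;
        case GGML_OP_DUP:  (void)GGML_OP_METADATA[GGML_OP_DUP].n_src;  break;
        /* ... one case per GGML_OP_LIST entry ... */
        case GGML_OP_COUNT: break;
    }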

#ifdef __cplusplus
}
#endif
42 changes: 42 additions & 0 deletions ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -754,6 +754,43 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
}

// Helper function to validate graph operands before computation
static bool validate_graph_operands(const ggml_cgraph *graph) {
GGML_PRINT_DEBUG("[%s] Validating graph with %d nodes\n", __func__, graph->n_nodes);
for (int i = 0; i < graph->n_nodes; ++i) {
const ggml_tensor* node = graph->nodes[i];
// Initial null check added for safety.
if (node == nullptr) {
GGML_LOG_ERROR("[%s] Graph node %d is null.\n", __func__, i);
return false;
}

const int n_src = ggml_op_get_n_src(node->op);

if (n_src == -1) {
// Ops like GGML_OP_CUSTOM take a variable number of inputs and cannot be validated here.
GGML_PRINT_DEBUG("[%s] Skipping operand validation for node %d (op %s, name '%s') with variable inputs.\n", __func__, i, ggml_op_name(node->op), node->name);
continue;
} else if (n_src == -2) {
GGML_LOG_ERROR("[%s] Graph node %d (name '%s') has invalid op type %d.\n", __func__, i, node->name, (int)node->op);
return false;
} else if (n_src > GGML_MAX_SRC) {
GGML_LOG_ERROR("[%s] Graph node %d (op %s, name '%s') requires %d sources, exceeding GGML_MAX_SRC (%d).\n", __func__, i, ggml_op_name(node->op), node->name, n_src, GGML_MAX_SRC);
return false;
}

// Check required source operands
for (int s_idx = 0; s_idx < n_src; ++s_idx) {
if (node->src[s_idx] == nullptr) {
GGML_LOG_ERROR("[%s] Graph node %d (op %s, name '%s') missing required input src[%d].\n", __func__, i, ggml_op_name(node->op), node->name, s_idx);
return false;
}
}
}
GGML_PRINT_DEBUG("[%s] Graph validation successful\n", __func__);
return true;
}
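
Since validate_graph_operands is file-static, a quick way to exercise it (a hedged sketch that would live in ggml-rpc.cpp, not part of the PR) is to build a small graph and null out a required source, which should make the helper return false:

    static bool demo_validate_missing_src(void) {
        struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ false };
        struct ggml_context * ctx = ggml_init(params);
        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        struct ggml_tensor * c = ggml_add(ctx, a, b);      // GGML_OP_ADD expects n_src == 2
        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, c);
        c->src[1] = nullptr;                               // simulate a corrupted/truncated client payload
        bool rejected = !validate_graph_operands(gf);      // expected: true
        ggml_free(ctx);
        return rejected;
    }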

static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
std::vector<uint8_t> input;
@@ -1357,6 +1394,11 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
return false;
}
}

if (!validate_graph_operands(graph)) {
return false;
}

ggml_status status = ggml_backend_graph_compute(backend, graph);
response.result = status;
return true;
3 changes: 3 additions & 0 deletions ggml/src/ggml.c
@@ -1387,6 +1387,9 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
// initialize time system (required on Windows)
ggml_time_init();

ggml_op_metadata_check();


for (int i = 0; i < (1 << 16); ++i) {
union {
uint16_t u16;
53 changes: 53 additions & 0 deletions tests/test-backend-ops.cpp
@@ -4546,6 +4546,59 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_falcon(2));
#endif

// Verify that ggml_op_get_n_src() reports the expected n_src for a sample of ops
{
struct test_op_metadata_counts : public test_case {
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return "OP_METADATA_COUNTS";
}

ggml_tensor * build_graph(ggml_context * ctx) override {
bool all_passed = true;

struct {
ggml_op op;
int expected_n_src;
const char* name;
} test_ops[] = {
{GGML_OP_NONE, 0, "NONE"},
{GGML_OP_UNARY, 1, "UNARY"},
{GGML_OP_ADD, 2, "ADD"},
{GGML_OP_MUL, 2, "MUL"},
{GGML_OP_ROPE, 3, "ROPE"},
{GGML_OP_FLASH_ATTN_EXT, 4, "FLASH_ATTN_EXT"},
{GGML_OP_CUSTOM, -1, "CUSTOM"}
};

// Test each operation's metadata
for (const auto& test : test_ops) {
int n_src = ggml_op_get_n_src(test.op);
if (n_src != test.expected_n_src) {
fprintf(stderr, "ERROR: Expected n_src=%d for GGML_OP_%s but got %d\n",
test.expected_n_src, test.name, n_src);
all_passed = false;
}
}

if (!all_passed) {
GGML_ASSERT(false && "One or more metadata checks failed");
}

// Create a dummy tensor that will be used for backend comparison
ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 10);
ggml_set_name(a, "a");

ggml_tensor * result = ggml_scale(ctx, a, 1.0f);
ggml_set_name(result, "result");

return result;
}
};

test_cases.push_back(std::make_unique<test_op_metadata_counts>());
}

return test_cases;
}
