Skip to content

vulkan: Add fusion support for RMS_NORM+MUL #14366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml/include/ggml-backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ extern "C" {
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);

// Compare the output of two backends
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor *test_node);

// Tensor initialization
GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
Expand Down
6 changes: 5 additions & 1 deletion ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,11 @@ extern "C" {

void * extra; // extra things e.g. for ggml-cuda.cu

char padding[8];
// number of operations that use this tensor as a src
int32_t use_count;

// add padding if needed to make a multiple of GGML_MEM_ALIGN
char padding[4];
};

static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
Expand Down
53 changes: 35 additions & 18 deletions ggml/src/ggml-backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,8 +817,8 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
}
if (sched->debug > 1) {
ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node), node->use_count);
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
Expand Down Expand Up @@ -1826,7 +1826,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
ggml_free(copy.ctx_unallocated);
}

bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor *test_node) {
struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
if (copy.buffer == NULL) {
return false;
Expand All @@ -1837,28 +1837,45 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t

assert(g1->n_nodes == g2->n_nodes);

for (int i = 0; i < g1->n_nodes; i++) {
struct ggml_tensor * t1 = g1->nodes[i];
struct ggml_tensor * t2 = g2->nodes[i];
if (test_node != nullptr) {
// Compute the whole graph and only test the output for a specific tensor
ggml_backend_graph_compute(backend1, g1);
ggml_backend_graph_compute(backend2, g2);

assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
int test_node_idx = -1;
for (int i = 0; i < g1->n_nodes; i++) {
struct ggml_tensor * t1 = g1->nodes[i];
if (t1 == test_node) {
test_node_idx = i;
break;
}
}
GGML_ASSERT(test_node_idx != -1);

struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
} else {
for (int i = 0; i < g1->n_nodes; i++) {
struct ggml_tensor * t1 = g1->nodes[i];
struct ggml_tensor * t2 = g2->nodes[i];

ggml_backend_graph_compute(backend1, &g1v);
ggml_backend_graph_compute(backend2, &g2v);
assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));

if (ggml_is_view_op(t1->op)) {
continue;
}
struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);

// compare results, calculate rms etc
if (!callback(i, t1, t2, user_data)) {
break;
ggml_backend_graph_compute(backend1, &g1v);
ggml_backend_graph_compute(backend2, &g2v);

if (ggml_is_view_op(t1->op)) {
continue;
}

// compare results, calculate rms etc
if (!callback(i, t1, t2, user_data)) {
break;
}
}
}

ggml_backend_graph_copy_free(copy);

return true;
Expand Down
63 changes: 63 additions & 0 deletions ggml/src/ggml-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -467,10 +467,73 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)

// return true if the node's results are only used by N other nodes
// and can be fused into their calculations.
static inline bool ggml_node_has_N_uses(const struct ggml_tensor * node, int32_t N) {
// check the use count against how many we're replacing
if (node->use_count != N) {
return false;
}

// if node is a view, some other node might be using the intermediate result
// via the view source.
if (node->view_src) {
return false;
}

// If the user requested output for the node, can't fuse
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
return false;
}

return true;
}

// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[]
// and are fusable. Nodes are considered fusable according to this function if:
// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).
// - all nodes except the last are src[0] of the following node.
// - all nodes are the same shape.
// TODO: Consider allowing GGML_OP_NONE nodes in between
static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op *ops, int num_ops) {
if (node_idx + num_ops > cgraph->n_nodes) {
return false;
}

for (int i = 0; i < num_ops; ++i) {
struct ggml_tensor *node = cgraph->nodes[node_idx + i];
if (node->op != ops[i]) {
return false;
}
if (i < num_ops && !ggml_node_has_N_uses(node, 1)) {
return false;
}
if (i > 0) {
struct ggml_tensor *prev = cgraph->nodes[node_idx + i - 1];
if (node->src[0] != prev) {
return false;
}
if (!ggml_are_same_shape(node, prev)) {
return false;
}
}
}
return true;
}

#ifdef __cplusplus
}
#endif

#ifdef __cplusplus
#include <initializer_list>

// nicer C++ syntax for ggml_can_fuse
inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
}
#endif

#ifdef __cplusplus
#include <vector>

Expand Down
56 changes: 41 additions & 15 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ struct vk_device_struct {
vk_pipeline pipeline_norm_f32;
vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
vk_pipeline pipeline_rms_norm_mul_f32;
vk_pipeline pipeline_rms_norm_back_f32;
vk_pipeline pipeline_l2_norm_f32;

Expand Down Expand Up @@ -978,6 +979,10 @@ struct ggml_backend_vk_context {

vk_command_pool compute_cmd_pool;
vk_command_pool transfer_cmd_pool;

// number of additional consecutive nodes that are being fused with the
// node currently being processed
bool num_additional_fused_ops {};
};

static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
Expand Down Expand Up @@ -2655,7 +2660,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

Expand Down Expand Up @@ -6418,7 +6424,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr;
case GGML_OP_RMS_NORM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rms_norm_f32;
return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
}
return nullptr;
case GGML_OP_RMS_NORM_BACK:
Expand Down Expand Up @@ -7518,18 +7524,19 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
}

static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);

ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
op_params[0], 0.0f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
op_params[0], 0.0f, 0,
}, dryrun);
}

Expand Down Expand Up @@ -8724,7 +8731,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* t

// Returns true if node has enqueued work into the queue, false otherwise
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
if (ggml_is_empty(node) || !node->buffer) {
return false;
}
Expand Down Expand Up @@ -8962,8 +8969,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod

break;
case GGML_OP_RMS_NORM:
ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);

if (ctx->num_additional_fused_ops > 0) {
// fused rms_norm + mul
ggml_tensor *mul = cgraph->nodes[node_idx + 1];
ggml_vk_rms_norm(ctx, compute_ctx, src0, mul->src[1], mul, dryrun);
} else {
ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
}
break;
case GGML_OP_RMS_NORM_BACK:
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
Expand Down Expand Up @@ -9698,10 +9710,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg

uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
}
ggml_vk_build_graph(ctx, cgraph, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
}
i += ctx->num_additional_fused_ops;
ctx->num_additional_fused_ops = 0;
}
if (ctx->device->need_compiles) {
ggml_vk_load_shaders(ctx->device);
Expand Down Expand Up @@ -9763,14 +9780,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
}

if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
}

// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
bool submit = (submitted_nodes >= nodes_per_submit) ||
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
(i == last_node) ||
(i + ctx->num_additional_fused_ops == last_node) ||
(almost_ready && !ctx->almost_ready_fence_pending);

bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
bool enqueued = ggml_vk_build_graph(ctx, cgraph, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);

if (vk_perf_logger_enabled) {
if (ctx->compute_ctx.expired()) {
Expand All @@ -9780,7 +9801,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
} else {
compute_ctx = ctx->compute_ctx.lock();
}
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
// If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
}
}

if (enqueued) {
Expand All @@ -9802,6 +9826,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
}
submit_count++;
}
i += ctx->num_additional_fused_ops;
ctx->num_additional_fused_ops = 0;
}

if (vk_perf_logger_enabled) {
Expand Down
15 changes: 12 additions & 3 deletions ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#version 450

#include "generic_unary_head.comp"
#include "generic_binary_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512

layout (constant_id = 1) const bool do_multiply = false;

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

shared FLOAT_TYPE sum[BLOCK_SIZE];
Expand All @@ -25,6 +27,7 @@ void main() {
const uint stride_sample = p.nb03;

uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();

sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
Expand All @@ -46,7 +49,13 @@ void main() {
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));

[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
if (do_multiply) {
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
}
} else {
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
}
}
}
2 changes: 1 addition & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ void process_shaders() {
// Norms
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -1608,6 +1608,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
/*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
/*.name =*/ { 0 },
/*.extra =*/ NULL,
/*.use_count =*/ 0,
/*.padding =*/ { 0 },
};

Expand Down Expand Up @@ -5817,6 +5818,7 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
/* unknown order, just fall back to using i*/ i;
if (node->src[k]) {
ggml_visit_parents(cgraph, node->src[k]);
node->src[k]->use_count++;
}
}

Expand Down
Loading
Loading