Skip to content

Commit 9826272

Browse files
committed
mtl : adapt the MNIST example as starter
1 parent 98c267f commit 9826272

File tree

4 files changed

+459
-0
lines changed

4 files changed

+459
-0
lines changed

examples/mtl/CMakeLists.txt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,28 @@ set(TARGET mtl-export)
22
add_executable(${TARGET} mtl-export.cpp)
33
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
44
target_compile_features(${TARGET} PRIVATE cxx_std_11)
5+
56
if(TARGET BUILD_INFO)
67
add_dependencies(${TARGET} BUILD_INFO)
78
endif()
9+
10+
if (APPLE)
11+
#
12+
# mtl
13+
14+
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
15+
find_library(METAL_FRAMEWORK Metal REQUIRED)
16+
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
17+
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
18+
19+
set(TEST_TARGET mtl)
20+
add_executable(${TEST_TARGET} mtl.cpp mtl.h mtl.m)
21+
target_link_libraries(${TEST_TARGET} PRIVATE
22+
ggml
23+
${FOUNDATION_LIBRARY}
24+
${METAL_FRAMEWORK}
25+
${METALKIT_FRAMEWORK}
26+
${METALPERFORMANCE_FRAMEWORK}
27+
)
28+
endif()
29+

examples/mtl/mtl.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#include "ggml.h"
2+
#include "mtl.h"
3+
4+
#include <cstdio>
5+
#include <cstring>
6+
#include <cstdlib>
7+
8+
int main(int argc, char ** argv) {
9+
ggml_time_init();
10+
11+
if (argc != 2) {
12+
fprintf(stderr, "Usage: %s llama.ggml\n", argv[0]);
13+
return -1;
14+
}
15+
16+
const char * fname_cgraph = argv[1];
17+
18+
// load the compute graph
19+
struct ggml_context * ctx_data = NULL;
20+
struct ggml_context * ctx_eval = NULL;
21+
22+
struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
23+
gf.n_threads = 1;
24+
25+
// allocate eval context
26+
// needed during ggml_graph_compute() to allocate a work tensor
27+
static size_t buf_size = gf.work_size; // TODO
28+
static void * buf = malloc(buf_size);
29+
30+
struct ggml_init_params params = {
31+
/*.mem_size =*/ buf_size,
32+
/*.mem_buffer =*/ buf,
33+
/*.no_alloc =*/ false,
34+
};
35+
36+
struct ggml_context * ctx_work = ggml_init(params);
37+
38+
// this allocates all Metal resources and memory buffers
39+
auto * ctx_mtl = llama_mtl_init(ctx_data, ctx_eval, ctx_work, &gf);
40+
41+
// the actual inference happens here
42+
llama_mtl_eval(ctx_mtl, &gf);
43+
44+
llama_mtl_free(ctx_mtl);
45+
46+
ggml_free(ctx_work);
47+
ggml_free(ctx_data);
48+
ggml_free(ctx_eval);
49+
50+
return 0;
51+
}
52+

examples/mtl/mtl.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#pragma once
2+
3+
struct ggml_context;
4+
struct ggml_cgraph;
5+
6+
#ifdef __cplusplus
7+
extern "C" {
8+
#endif
9+
10+
struct ggml_mtl_context;
11+
12+
struct ggml_mtl_context * llama_mtl_init(
13+
struct ggml_context * ctx_data,
14+
struct ggml_context * ctx_eval,
15+
struct ggml_context * ctx_work,
16+
struct ggml_cgraph * gf);
17+
18+
void llama_mtl_free(struct ggml_mtl_context * ctx);
19+
20+
// return 0 on success
21+
int llama_mtl_eval(
22+
struct ggml_mtl_context * ctx,
23+
struct ggml_cgraph * gf);
24+
25+
#ifdef __cplusplus
26+
}
27+
#endif
28+

0 commit comments

Comments
 (0)