Skip to content

Commit 3833990

Browse files
author
Nikhil Thorat
authored
Modularize the WASM kernels into separate build targets. (#2173)
- Each kernel gets its own cc_library and file.
- A new cc_library encapsulates all of them: ":all_kernels".
- backend.cc contains global state and backend utilities.
- Add linkstatic=1 so the compiler yells before emscripten fails at linking (it doesn't support dynamic linking, e.g. extern globals).
DEV
1 parent 7c2ee51 commit 3833990

File tree

7 files changed

+269
-122
lines changed

7 files changed

+269
-122
lines changed

.vscode/settings.json

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,97 @@
3535
"typescript.tsdk": "${workspaceRoot}/node_modules/typescript/lib",
3636
"clang-format.executable": "${workspaceRoot}/node_modules/.bin/clang-format",
3737
"files.associations": {
38-
"memory": "cpp"
38+
"memory": "cpp",
39+
"any": "cpp",
40+
"array": "cpp",
41+
"atomic": "cpp",
42+
"strstream": "cpp",
43+
"*.tcc": "cpp",
44+
"bitset": "cpp",
45+
"cctype": "cpp",
46+
"chrono": "cpp",
47+
"cinttypes": "cpp",
48+
"clocale": "cpp",
49+
"cmath": "cpp",
50+
"codecvt": "cpp",
51+
"complex": "cpp",
52+
"condition_variable": "cpp",
53+
"cstdarg": "cpp",
54+
"cstddef": "cpp",
55+
"cstdint": "cpp",
56+
"cstdio": "cpp",
57+
"cstdlib": "cpp",
58+
"cstring": "cpp",
59+
"ctime": "cpp",
60+
"cwchar": "cpp",
61+
"cwctype": "cpp",
62+
"deque": "cpp",
63+
"list": "cpp",
64+
"unordered_map": "cpp",
65+
"unordered_set": "cpp",
66+
"vector": "cpp",
67+
"exception": "cpp",
68+
"algorithm": "cpp",
69+
"filesystem": "cpp",
70+
"functional": "cpp",
71+
"iterator": "cpp",
72+
"map": "cpp",
73+
"memory_resource": "cpp",
74+
"numeric": "cpp",
75+
"optional": "cpp",
76+
"random": "cpp",
77+
"ratio": "cpp",
78+
"regex": "cpp",
79+
"set": "cpp",
80+
"string": "cpp",
81+
"string_view": "cpp",
82+
"system_error": "cpp",
83+
"tuple": "cpp",
84+
"type_traits": "cpp",
85+
"utility": "cpp",
86+
"fstream": "cpp",
87+
"future": "cpp",
88+
"initializer_list": "cpp",
89+
"iomanip": "cpp",
90+
"iosfwd": "cpp",
91+
"iostream": "cpp",
92+
"istream": "cpp",
93+
"limits": "cpp",
94+
"mutex": "cpp",
95+
"new": "cpp",
96+
"ostream": "cpp",
97+
"shared_mutex": "cpp",
98+
"sstream": "cpp",
99+
"stdexcept": "cpp",
100+
"streambuf": "cpp",
101+
"thread": "cpp",
102+
"cfenv": "cpp",
103+
"typeinfo": "cpp",
104+
"valarray": "cpp",
105+
"variant": "cpp",
106+
"charconv": "cpp",
107+
"__bit_reference": "cpp",
108+
"__config": "cpp",
109+
"__debug": "cpp",
110+
"__errc": "cpp",
111+
"__functional_base": "cpp",
112+
"__hash_table": "cpp",
113+
"__locale": "cpp",
114+
"__mutex_base": "cpp",
115+
"__node_handle": "cpp",
116+
"__nullptr": "cpp",
117+
"__split_buffer": "cpp",
118+
"__sso_allocator": "cpp",
119+
"__std_stream": "cpp",
120+
"__string": "cpp",
121+
"__threading_support": "cpp",
122+
"__tree": "cpp",
123+
"__tuple": "cpp",
124+
"bit": "cpp",
125+
"ios": "cpp",
126+
"locale": "cpp",
127+
"queue": "cpp",
128+
"stack": "cpp",
129+
"*.ipp": "cpp"
39130
}
40131
}

tfjs-backend-wasm/src/backend_wasm.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ async function init(): Promise<{wasm: BackendWasmModule}> {
212212
dispose: wasm.cwrap('dispose', voidReturnType, []),
213213
add: wasm.cwrap('add', voidReturnType, ['number', 'number', 'number']),
214214
batchMatMul: wasm.cwrap(
215-
'batchMatMul', voidReturnType,
215+
'batch_matmul', voidReturnType,
216216
[
217217
'number', 'number', 'number', 'number', 'number', 'number',
218218
'number', 'number', 'number', 'number', 'number', 'number', 'number'

tfjs-backend-wasm/src/cc/BUILD

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
cc_binary(
22
name = "tfjs-backend-wasm.js",
3-
srcs = ["lib.cc"],
43
linkopts = [
54
"-s ALLOW_MEMORY_GROWTH=1",
65
"-s DEFAULT_LIBRARY_FUNCS_TO_INCLUDE=[]",
@@ -15,19 +14,53 @@ cc_binary(
1514
"-s MALLOC=emmalloc",
1615
],
1716
deps = [
18-
":kernels",
19-
":util",
17+
":backend",
18+
":all_kernels",
2019
],
2120
)
2221

2322
cc_library(
24-
name = "kernels",
25-
srcs = ["kernels.cc"],
26-
hdrs = ["kernels.h"],
27-
deps = ["@xnnpack//:XNNPACK"],
23+
name = "backend",
24+
srcs = ["backend.cc"],
25+
hdrs = ["backend.h"],
26+
linkstatic = 1,
27+
deps = [
28+
":util",
29+
"@xnnpack//:XNNPACK",
30+
],
2831
)
2932

3033
cc_library(
31-
name = "util",
32-
srcs = ["util.h"],
34+
name = "all_kernels",
35+
linkstatic = 1,
36+
deps = [
37+
":add",
38+
":batch_matmul",
39+
]
40+
)
41+
42+
# Kernel target for `add`: builds add.cc on top of the ":backend" and ":util"
# targets. linkstatic = 1 is set on every library in this file; per the commit
# message, this makes the compiler complain before emscripten fails at link
# time, because emscripten does not support dynamic linking.
cc_library(
43+
name = "add",
44+
srcs = ["add.cc"],
45+
linkstatic = 1,
46+
deps = [
47+
":backend",
48+
":util",
49+
],
50+
)
51+
52+
# Kernel target for `batch_matmul`: builds batch_matmul.cc on top of the
# ":backend" and ":util" targets. linkstatic = 1 mirrors the other libraries in
# this file (emscripten does not support dynamic linking, per the commit
# message).
cc_library(
53+
name = "batch_matmul",
54+
srcs = ["batch_matmul.cc"],
55+
linkstatic = 1,
56+
deps = [
57+
":backend",
58+
":util",
59+
],
60+
)
61+
62+
# Header-only helpers shared by the backend and the kernel targets.
# util.h is listed under hdrs (not srcs) so that it is part of this target's
# public interface: headers placed in srcs are private to the rule, whereas the
# targets that depend on ":util" need to #include "src/cc/util.h".
cc_library(
    name = "util",
    hdrs = ["util.h"],
    linkstatic = 1,
)

tfjs-backend-wasm/src/cc/add.cc

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/* Copyright 2019 Google Inc. All Rights Reserved.
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
* ===========================================================================*/
14+
15+
#include <emscripten.h>
16+
#include <math.h>
17+
#include <cstdio>
18+
#include <map>
19+
#include <vector>
20+
21+
#include "src/cc/backend.h"
22+
#include "src/cc/util.h"
23+
24+
// Element-wise addition with wrap-around broadcasting.
//
// Writes max(a_size, b_size) elements to out_buf, where
//   out_buf[i] = a_buf[i % a_size] + b_buf[i % b_size]
// so the shorter operand is cycled over the longer one.
//
// Params:
//   a_buf, a_size: first operand and its element count.
//   b_buf, b_size: second operand and its element count.
//   out_buf: destination; must have room for max(a_size, b_size) elements.
//
// If either operand has no elements (size <= 0) nothing is written: there is
// no value to broadcast, and `i % 0` would be undefined behaviour.
template <class T>
void add_impl(T* a_buf, int a_size, T* b_buf, int b_size, T* out_buf) {
  if (a_size <= 0 || b_size <= 0) {
    return;
  }
  // Plain comparison instead of std::max: this file does not include
  // <algorithm>, so std::max would only be available transitively.
  const int size = a_size > b_size ? a_size : b_size;
  for (int i = 0; i < size; ++i) {
    out_buf[i] = a_buf[i % a_size] + b_buf[i % b_size];
  }
}

// Explicit instantiations for the element types this kernel is dispatched on
// (float32, int32 and boolean tensors).
template void add_impl<float>(float* a_buf, int a_size, float* b_buf,
                              int b_size, float* out_buf);
template void add_impl<int>(int* a_buf, int a_size, int* b_buf, int b_size,
                            int* out_buf);
template void add_impl<bool>(bool* a_buf, int a_size, bool* b_buf, int b_size,
                             bool* out_buf);
38+
39+
namespace tfjs {
40+
// We use C-style API to interface with Javascript.
41+
extern "C" {
42+
43+
// Exported `add` kernel: looks up the three tensors by id and adds a and b
// element-wise into out. The dtype of the first operand decides which typed
// buffer pointer (f32 / i32 / b) is used for all three tensors; an
// unrecognised dtype only produces a warning.
EMSCRIPTEN_KEEPALIVE
void add(int a_id, int b_id, int out_id) {
  const auto lhs = backend::get_tensor_info(a_id);
  const auto rhs = backend::get_tensor_info(b_id);
  const auto dst = backend::get_tensor_info(out_id);

  if (lhs.dtype == DType::float32) {
    add_impl(lhs.buf.f32, lhs.size, rhs.buf.f32, rhs.size, dst.buf.f32);
  } else if (lhs.dtype == DType::int32) {
    add_impl(lhs.buf.i32, lhs.size, rhs.buf.i32, rhs.size, dst.buf.i32);
  } else if (lhs.dtype == DType::boolean) {
    add_impl(lhs.buf.b, lhs.size, rhs.buf.b, rhs.size, dst.buf.b);
  } else {
    util::warn("Add for tensor ids %d and %d failed. Unknown dtype %d", a_id,
               b_id, lhs.dtype);
  }
}
66+
67+
} // extern "C"
68+
} // namespace tfjs

tfjs-backend-wasm/src/cc/lib.cc renamed to tfjs-backend-wasm/src/cc/backend.cc

Lines changed: 8 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -13,44 +13,24 @@
1313
* ===========================================================================*/
1414

1515
#include <emscripten.h>
16-
#include <math.h>
1716
#include <xnnpack.h>
1817
#include <cstdio>
1918
#include <map>
2019
#include <vector>
2120

22-
#include "src/cc/kernels.h"
21+
#include "src/cc/backend.h"
2322
#include "src/cc/util.h"
2423

25-
namespace tfjs {
26-
27-
enum DType {
28-
float32 = 0,
29-
int32 = 1,
30-
boolean = 2,
31-
};
32-
33-
// A union of pointers that points to memory for a given tensor.
34-
union DataPtrUnion {
35-
float *f32;
36-
int *i32;
37-
bool *b;
38-
};
39-
40-
// Holds information about a tensor such as dtype, shape and pointer to its data
41-
// in memory.
42-
struct TensorInfo {
43-
// Pointer to the bytes where the data is allocated.
44-
DataPtrUnion buf;
45-
DType dtype;
46-
std::vector<int> shape;
47-
// Total number of elements.
48-
int size;
49-
};
50-
24+
namespace {
5125
// Maps a unique tensor id to info about that tensor. The map owns all of its
5226
// entries.
5327
std::map<int, TensorInfo> data;
28+
} // namespace
29+
30+
namespace tfjs {
31+
namespace backend {
32+
// Looks up the tensor registered under tensor_id in the file-local `data` map
// and returns its TensorInfo by value (i.e. a copy of the map entry).
// The lookup uses std::map::at, which throws std::out_of_range for an id that
// is not in the map.
TensorInfo get_tensor_info(int tensor_id) { return data.at(tensor_id); }
33+
} // namespace backend
5434

5535
// We use C-style API to interface with Javascript.
5636
extern "C" {
@@ -103,52 +83,6 @@ void dispose_data(int data_id) {
10383
data.erase(data_id);
10484
}
10585

106-
EMSCRIPTEN_KEEPALIVE
107-
void add(int a_id, int b_id, int out_id) {
108-
const auto a_info = data.at(a_id);
109-
const auto b_info = data.at(b_id);
110-
const auto out_info = data.at(out_id);
111-
switch (a_info.dtype) {
112-
case DType::float32:
113-
kernels::add(a_info.buf.f32, a_info.size, b_info.buf.f32, b_info.size,
114-
out_info.buf.f32);
115-
break;
116-
case DType::int32:
117-
kernels::add(a_info.buf.i32, a_info.size, b_info.buf.i32, b_info.size,
118-
out_info.buf.i32);
119-
break;
120-
case DType::boolean:
121-
kernels::add(a_info.buf.b, a_info.size, b_info.buf.b, b_info.size,
122-
out_info.buf.b);
123-
break;
124-
default:
125-
util::warn("Add for tensor ids %d and %d failed. Unknown dtype %d", a_id,
126-
b_id, a_info.dtype);
127-
}
128-
}
129-
130-
EMSCRIPTEN_KEEPALIVE
131-
void batchMatMul(int a_id, int b_id, int shared_dim, int left_dim,
132-
int right_dim, int batch_dim, int a_batch, int a_outer_step,
133-
int a_inner_step, int b_batch, int b_outer_step,
134-
int b_inner_step, int out_id) {
135-
const auto a_info = data.at(a_id);
136-
const auto b_info = data.at(b_id);
137-
const auto out_info = data.at(out_id);
138-
switch (a_info.dtype) {
139-
case DType::float32:
140-
kernels::batchMatMul(a_info.buf.f32, b_info.buf.f32, shared_dim, left_dim,
141-
right_dim, batch_dim, a_batch, a_outer_step,
142-
a_inner_step, b_batch, b_outer_step, b_inner_step,
143-
out_info.buf.f32);
144-
break;
145-
default:
146-
util::warn(
147-
"batchMatMul for tensor ids %d and %d failed. Unknown dtype %d", a_id,
148-
b_id, a_info.dtype);
149-
}
150-
}
151-
15286
EMSCRIPTEN_KEEPALIVE
15387
void dispose() {
15488
for (auto const &element : data) {

0 commit comments

Comments
 (0)