Skip to content

Commit 453e4b9

Browse files
committed
Add Lookup Op and unit tests.
1 parent d1f0283 commit 453e4b9

File tree

9 files changed

+248
-27
lines changed

9 files changed

+248
-27
lines changed

.travis.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ addons:
55
- cmake
66
- python3.6-dev
77
- python3.6-venv
8+
- g++-4.9
9+
810
matrix:
911
fast_finish: true
1012
include:
@@ -16,6 +18,10 @@ matrix:
1618
os: linux
1719
rust: stable
1820

21+
env:
22+
- CC=gcc-4.9
23+
- CXX=g++-4.9
24+
1925
install:
2026
- |
2127
if [ "$TRAVIS_OS_NAME" == "linux" ]; then

ci/script.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ cd build
77
cmake ..
88
make
99

10-
ctest
10+
ctest -V

finalfusion-tf/kernel/FFLookupKernels.cc

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,63 @@ class CloseFFEmbeddingsOp : public OpKernel {
5656
};
5757

5858
REGISTER_KERNEL_BUILDER(
59-
Name("CloseFFEmbeddings").Device(DEVICE_CPU), CloseFFEmbeddingsOp);
59+
Name("CloseFFEmbeddings").Device(DEVICE_CPU),
60+
CloseFFEmbeddingsOp);
61+
62+
class FFLookupOp : public OpKernel {
63+
public:
64+
explicit FFLookupOp(OpKernelConstruction *context) : OpKernel(context) {
65+
OP_REQUIRES_OK(context, context->GetAttr("mask_empty_string", &mask_empty_string_));
66+
OP_REQUIRES_OK(context, context->GetAttr("mask_failed_lookup", &mask_failed_lookup_));
67+
OP_REQUIRES_OK(context, context->GetAttr("embedding_len", &embedding_len_));
68+
}
69+
70+
void Compute(OpKernelContext *context) override {
71+
FFLookup *lookup;
72+
OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &lookup));
73+
core::ScopedUnref unref(lookup);
74+
75+
// verify length from construction with actual length
76+
size_t const dims = lookup->dimensions();
77+
if (embedding_len_ != -1) {
78+
OP_REQUIRES(context,
79+
(dims == embedding_len_),
80+
errors::InvalidArgument("Actual embedding length (", dims, ") does not match provided length (",
81+
embedding_len_, ")"));
82+
}
83+
84+
// Get input tensor and flatten
85+
Tensor const &query_tensor = context->input(1);
86+
auto query = query_tensor.flat<string>();
87+
88+
// Set output shape: add new dim with dimensionality of embeddings
89+
TensorShape out_shape(query_tensor.shape());
90+
out_shape.AddDim(((int64) dims));
91+
92+
// Create output tensor and flatten
93+
Tensor *output_tensor = nullptr;
94+
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output_tensor));
95+
auto output_flat = output_tensor->flat<float>();
96+
97+
for (int i = 0; i < query.size(); i++) {
98+
std::vector<float> embedding = lookup->embedding(query(i));
99+
// optionally mask failed lookups and/or empty string. Generally, empty string will lead to a failed lookup.
100+
if ((query(i).empty() && mask_empty_string_) || (mask_failed_lookup_ && embedding.empty())) {
101+
std::memset(&output_flat(i * dims), 0., dims * sizeof(float));
102+
} else {
103+
// if no masking attributes are set and the embedding is empty, return error.
104+
OP_REQUIRES(context, !embedding.empty(), errors::InvalidArgument("Embedding lookup failed for: ", query(i)));
105+
std::memcpy(&output_flat(i * dims), embedding.data(), dims * sizeof(float));
106+
}
107+
}
108+
}
109+
110+
private:
111+
bool mask_empty_string_;
112+
bool mask_failed_lookup_;
113+
int embedding_len_;
114+
};
115+
116+
REGISTER_KERNEL_BUILDER(
117+
Name("FFLookup").Device(DEVICE_CPU),
118+
FFLookupOp);

finalfusion-tf/ops/FFLookupOps.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,27 @@ namespace tensorflow {
2222
REGISTER_OP("CloseFFEmbeddings")
2323
.Input("embeds: resource")
2424
.SetShapeFn(shape_inference::NoOutputs);
25+
26+
REGISTER_OP("FFLookup")
27+
.Input("embeds: resource")
28+
.Input("query: string")
29+
.Attr("embedding_len: int >= -1 = -1")
30+
.Attr("mask_empty_string: bool = true")
31+
.Attr("mask_failed_lookup: bool = true")
32+
.Output("embeddings: float")
33+
.SetShapeFn([](
34+
::tensorflow::shape_inference::InferenceContext *c
35+
) {
36+
ShapeHandle strings_shape = c->input(1);
37+
ShapeHandle output_shape;
38+
int embedding_len;
39+
TF_RETURN_IF_ERROR(c->GetAttr("embedding_len", &embedding_len));
40+
TF_RETURN_IF_ERROR(
41+
c->Concatenate(strings_shape, c->Vector(embedding_len), &output_shape)
42+
);
43+
ShapeHandle embeds = c->output(0);
44+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &embeds));
45+
c->set_output(0, output_shape);
46+
return Status::OK();
47+
});
2548
}

tests/CMakeLists.txt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@ include(CTest)
33
file(COPY testdata/test.fifu DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/data)
44
message(${CMAKE_CURRENT_BINARY_DIR})
55

6-
add_test(NAME python-init-close
7-
COMMAND pytest ${CMAKE_CURRENT_SOURCE_DIR}
6+
add_test(NAME eager-mode
7+
COMMAND pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_eager_mode.py
8+
WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
9+
)
10+
11+
add_test(NAME graph-mode
12+
COMMAND pytest ${CMAKE_CURRENT_SOURCE_DIR}/test_graph_mode.py
813
WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
914
)

tests/conftest.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,14 @@
1-
import os
21
import platform
32

43
import pytest
54
import tensorflow as tf
65

7-
tf.enable_eager_execution()
8-
96

107
@pytest.fixture
11-
def ff_lib(tests_root):
8+
def ff_lib():
129
if platform.system() == "Darwin":
1310
LIB_SUFFIX = ".dylib"
1411
else:
1512
LIB_SUFFIX = ".so"
1613

1714
yield tf.load_op_library("./finalfusion-tf/libfinalfusion_tf" + LIB_SUFFIX)
18-
19-
20-
@pytest.fixture
21-
def tests_root():
22-
yield os.path.dirname(__file__)

tests/test_eager_mode.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import numpy as np
2+
import pytest
3+
import tensorflow as tf
4+
5+
tf.enable_eager_execution()
6+
7+
8+
def test_init_and_close(ff_lib):
9+
embeddings = ff_lib.ff_embeddings()
10+
ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", mmap=False)
11+
ff_lib.close_ff_embeddings(embeddings)
12+
13+
14+
def test_init_and_close_mmap(ff_lib):
15+
embeddings = ff_lib.ff_embeddings()
16+
ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", mmap=True)
17+
ff_lib.close_ff_embeddings(embeddings)
18+
19+
20+
def test_eager_lookup(ff_lib):
21+
embeddings = ff_lib.ff_embeddings()
22+
ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", mmap=False)
23+
24+
ber = ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False)
25+
ber_list = ff_lib.ff_lookup(embeddings, ["Berlin"], mask_empty_string=False, mask_failed_lookup=False)
26+
ber_tensor = ff_lib.ff_lookup(embeddings, [["Berlin"]], mask_empty_string=False, mask_failed_lookup=False)
27+
28+
assert ber.shape == (100,)
29+
assert ber_list.shape == (1, 100)
30+
assert ber_tensor.shape == (1, 1, 100)
31+
32+
ff_lib.close_ff_embeddings(embeddings)
33+
34+
35+
def test_eager_lookup_masked(ff_lib):
36+
embeddings = ff_lib.ff_embeddings()
37+
ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False)
38+
tuebingen_masked = ff_lib.ff_lookup(embeddings, "Tübingen", mask_empty_string=False, mask_failed_lookup=True,
39+
embedding_len=100)
40+
empty_masked = ff_lib.ff_lookup(embeddings, "", mask_empty_string=True, mask_failed_lookup=False, embedding_len=100)
41+
empty_masked_through_fail = ff_lib.ff_lookup(embeddings, "", mask_empty_string=False, mask_failed_lookup=True,
42+
embedding_len=100)
43+
assert np.allclose(tuebingen_masked, 0.)
44+
assert np.allclose(empty_masked, 0.)
45+
assert np.allclose(empty_masked_through_fail, 0.)
46+
ff_lib.close_ff_embeddings(embeddings)
47+
48+
49+
def test_eager_errors(ff_lib):
50+
embeddings = ff_lib.ff_embeddings()
51+
with pytest.raises(tf.errors.UnknownError):
52+
ff_lib.initialize_ff_embeddings(embeddings, "foo.fifu", False)
53+
54+
ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False)
55+
56+
with pytest.raises(tf.errors.AlreadyExistsError):
57+
ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False)
58+
59+
with pytest.raises(tf.errors.InvalidArgumentError):
60+
ff_lib.ff_lookup(embeddings, "Tübingen", mask_empty_string=False, mask_failed_lookup=False, embedding_len=100)
61+
62+
# shape mismatch, 10 vs. actual 100
63+
with pytest.raises(tf.errors.InvalidArgumentError):
64+
ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False, embedding_len=10)
65+
66+
with pytest.raises(tf.errors.InvalidArgumentError):
67+
ff_lib.ff_lookup(embeddings, "", mask_empty_string=False, mask_failed_lookup=False, embedding_len=100)
68+
69+
ff_lib.close_ff_embeddings(embeddings)
70+
with pytest.raises(tf.errors.NotFoundError):
71+
ff_lib.close_ff_embeddings(embeddings)

tests/test_graph_mode.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import numpy as np
2+
import pytest
3+
import tensorflow as tf
4+
5+
6+
def test_graph_lookup(ff_lib):
7+
embeddings = ff_lib.ff_embeddings()
8+
init = ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False)
9+
10+
ber = ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False, embedding_len=100)
11+
assert ber.shape == (100,)
12+
13+
ber_list = ff_lib.ff_lookup(embeddings, ["Berlin"], mask_empty_string=False, mask_failed_lookup=False,
14+
embedding_len=100)
15+
assert ber_list.shape == (1, 100)
16+
17+
ber_tensor = ff_lib.ff_lookup(embeddings, [["Berlin"]], mask_empty_string=False, mask_failed_lookup=False,
18+
embedding_len=100)
19+
assert ber_tensor.shape == (1, 1, 100)
20+
21+
ber_no_shape = ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False)
22+
assert ber_no_shape.shape.rank == 1
23+
assert ber_no_shape.shape[0].value is None
24+
25+
ber_list_no_shape = ff_lib.ff_lookup(embeddings, ["Berlin"], mask_empty_string=False, mask_failed_lookup=False)
26+
assert ber_list_no_shape.shape.rank == 2
27+
assert ber_list_no_shape.shape[0].value == tf.Dimension(1)
28+
assert ber_list_no_shape.shape[1].value is None
29+
30+
with tf.Session() as sess:
31+
sess.run([init])
32+
res = sess.run([ber, ber_list, ber_tensor])
33+
assert res[0].shape == (100,)
34+
assert res[1].shape == (1, 100)
35+
assert res[2].shape == (1, 1, 100)
36+
sess.run([ff_lib.close_ff_embeddings(embeddings)])
37+
38+
39+
def test_graph_lookup_masked(ff_lib):
40+
embeddings = ff_lib.ff_embeddings()
41+
init = ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", True)
42+
tuebingen_masked = ff_lib.ff_lookup(embeddings, "Tübingen", mask_empty_string=False, mask_failed_lookup=True,
43+
embedding_len=100)
44+
empty_masked = ff_lib.ff_lookup(embeddings, "", mask_empty_string=True, mask_failed_lookup=False, embedding_len=100)
45+
empty_masked_through_fail = ff_lib.ff_lookup(embeddings, "", mask_empty_string=False, mask_failed_lookup=True,
46+
embedding_len=100)
47+
with tf.Session() as sess:
48+
sess.run([init])
49+
res = sess.run([tuebingen_masked, empty_masked, empty_masked_through_fail])
50+
assert np.allclose(res, 0.)
51+
52+
53+
def test_graph_errors(ff_lib):
54+
embeddings = ff_lib.ff_embeddings()
55+
tuebingen_unmasked = ff_lib.ff_lookup(embeddings, "Tübingen", mask_empty_string=False, mask_failed_lookup=False,
56+
embedding_len=100)
57+
ber_bad_shape = ff_lib.ff_lookup(embeddings, "Berlin", mask_empty_string=False, mask_failed_lookup=False,
58+
embedding_len=10)
59+
assert ber_bad_shape.shape == (10,)
60+
empty_unmasked = ff_lib.ff_lookup(embeddings, "", mask_empty_string=False, mask_failed_lookup=False,
61+
embedding_len=100)
62+
63+
with tf.Session() as sess:
64+
with pytest.raises(tf.errors.UnknownError):
65+
sess.run([ff_lib.initialize_ff_embeddings(embeddings, "foo.fifu", False)])
66+
67+
sess.run([ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False)])
68+
69+
with pytest.raises(tf.errors.AlreadyExistsError):
70+
sess.run([ff_lib.initialize_ff_embeddings(embeddings, "tests/data/test.fifu", False)])
71+
with pytest.raises(tf.errors.InvalidArgumentError):
72+
sess.run([tuebingen_unmasked])
73+
with pytest.raises(tf.errors.InvalidArgumentError):
74+
sess.run([empty_unmasked])
75+
with pytest.raises(tf.errors.InvalidArgumentError):
76+
sess.run([ber_bad_shape])
77+
sess.run([ff_lib.close_ff_embeddings(embeddings)])
78+
with pytest.raises(tf.errors.NotFoundError):
79+
sess.run([ff_lib.close_ff_embeddings(embeddings)])

tests/test_init_close.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

0 commit comments

Comments
 (0)