Skip to content

Commit 886fca5

Browse files
Merge pull request #15473 from diggerk:index_lookup_sparse_config
PiperOrigin-RevId: 401828338
2 parents 02fb8a0 + 0c807df commit 886fca5

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

keras/layers/preprocessing/index_lookup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ def get_config(self):
363363
"oov_token": self.oov_token,
364364
"mask_token": self.mask_token,
365365
"output_mode": self.output_mode,
366+
"sparse": self.sparse,
366367
"pad_to_max_tokens": self.pad_to_max_tokens,
367368
"vocabulary": utils.listify_tensors(self.input_vocabulary),
368369
"vocabulary_dtype": self.vocabulary_dtype,

keras/layers/preprocessing/index_lookup_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2365,6 +2365,30 @@ def test_static_table_config_weight_data_transfer_succeeds(self):
23652365
new_output_dataset = model.predict(input_array)
23662366
self.assertAllEqual(new_output_dataset, expected_output)
23672367

2368+
def test_sparse_output_across_saving(self):
2369+
vocab_data = ["earth", "wind", "and", "fire"]
2370+
input_array = np.array([["earth", "wind", "and", "fire"],
2371+
["fire", "and", "earth", "michigan"]])
2372+
2373+
expected_output = [[0., 1., 1., 1., 1.], [1., 1., 0., 1., 1.]]
2374+
2375+
layer_cls = index_lookup.IndexLookup
2376+
layer = layer_cls(
2377+
max_tokens=None,
2378+
num_oov_indices=1,
2379+
mask_token="",
2380+
oov_token="[OOV]",
2381+
vocabulary_dtype=tf.string,
2382+
vocabulary=vocab_data,
2383+
output_mode="multi_hot",
2384+
sparse=True)
2385+
config = layer.get_config()
2386+
layer = layer_cls.from_config(config)
2387+
2388+
output = layer(input_array)
2389+
self.assertIsInstance(output, tf.SparseTensor)
2390+
self.assertAllEqual(tf.sparse.to_dense(output), expected_output)
2391+
23682392

23692393
class EagerExecutionDisabled(keras_parameterized.TestCase,
23702394
preprocessing_test_utils.PreprocessingLayerTest):

0 commit comments

Comments
 (0)