Skip to content
Merged
71 changes: 71 additions & 0 deletions elasticdl_preprocessing/layers/round_identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_functional_ops, ragged_tensor


class RoundIdentity(tf.keras.layers.Layer):
    """Round a numeric feature to the nearest integer id.

    This layer transforms numeric inputs to integer outputs. It is a special
    case of bucketizing to bins: every value is rounded to the nearest
    integer (ties are rounded half to even by TensorFlow), and any rounded
    value outside the accepted range is replaced by `default_value`.

    Example :
    ```python
        layer = RoundIdentity(num_buckets=5)
        inp = np.asarray([[1.2], [1.6], [0.2], [3.1], [4.9]])
        layer(inp)
        [[1], [2], [0], [3], [5]]
    ```

    Arguments:
        num_buckets: Largest id that is kept. Rounded values in
            `[0, num_buckets]` pass through unchanged; anything below 0 or
            above `num_buckets` becomes `default_value`.
            NOTE(review): the upper bound is inclusive, so the outputs can
            take `num_buckets + 1` distinct values -- confirm this is what
            downstream consumers (e.g. embedding tables) expect.
        default_value: Id emitted for out-of-range inputs. Defaults to 0.
        **kwargs: Keyword arguments to construct a layer.

    Input shape: A numeric `Tensor`, `SparseTensor` or `RaggedTensor` of shape
        `[batch_size, d1, ..., dm]`

    Output shape: An int64 tensor of shape `[batch_size, d1, ..., dm]`

    """

    def __init__(self, num_buckets, default_value=0, **kwargs):
        # Forward the standard layer arguments (name, dtype, ...) so that
        # `from_config(get_config())` can rebuild the layer.
        super(RoundIdentity, self).__init__(**kwargs)
        # Keep the values exactly as passed (normally Python ints) so that
        # `get_config` stays serializable; the int64 casts happen at use.
        self.num_buckets = num_buckets
        self.default_value = default_value

    def call(self, inputs):
        """Round `inputs` element-wise, preserving the container type."""
        if isinstance(inputs, tf.SparseTensor):
            # Only the stored values change; indices and shape are reused.
            return tf.SparseTensor(
                indices=inputs.indices,
                values=self._round_and_truncate(inputs.values),
                dense_shape=inputs.dense_shape,
            )
        if ragged_tensor.is_ragged(inputs):
            return ragged_functional_ops.map_flat_values(
                self._round_and_truncate, inputs
            )
        return self._round_and_truncate(inputs)

    def _round_and_truncate(self, values):
        """Round a dense tensor to int64 and replace out-of-range ids."""
        ids = tf.cast(tf.keras.backend.round(values), tf.int64)
        upper_bound = tf.cast(self.num_buckets, tf.int64)
        fallback = tf.cast(self.default_value, tf.int64)
        out_of_range = tf.logical_or(ids < 0, ids > upper_bound)
        return tf.where(
            out_of_range,
            x=tf.fill(dims=tf.shape(ids), value=fallback),
            y=ids,
        )

    def compute_output_shape(self, input_shape):
        """The transformation is element-wise, so the shape is unchanged."""
        return input_shape

    def get_config(self):
        """Return the constructor arguments merged into the base config."""
        config = super(RoundIdentity, self).get_config()
        config.update(
            {
                "num_buckets": self.num_buckets,
                "default_value": self.default_value,
            }
        )
        return config
34 changes: 34 additions & 0 deletions elasticdl_preprocessing/tests/round_identity_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import unittest

import numpy as np
import tensorflow as tf

from elasticdl_preprocessing.layers.round_identity import RoundIdentity
from elasticdl_preprocessing.tests.test_utils import (
ragged_tensor_equal,
sparse_tensor_equal,
)


class RoundIdentityTest(unittest.TestCase):
    """Checks RoundIdentity on dense, ragged and sparse inputs."""

    def setUp(self):
        self.round_identity = RoundIdentity(num_buckets=10)
        # 0.5 is expected to become 0: TensorFlow rounds ties half to even.
        self.ragged_input = tf.ragged.constant([[1.1, 3.4], [0.5]])
        self.expected_ragged_out = tf.ragged.constant(
            [[1, 3], [0]], dtype=tf.int64
        )

    def test_round_identity_dense(self):
        dense_input = tf.constant([[1.2], [1.6], [0.2], [3.1], [4.9]])
        output = self.round_identity(dense_input)
        expected_out = np.array([[1], [2], [0], [3], [5]])
        self.assertEqual(output.dtype, tf.int64)
        # assert_array_equal reports the mismatching elements on failure.
        np.testing.assert_array_equal(output.numpy(), expected_out)

    def test_round_identity_ragged(self):
        ragged_output = self.round_identity(self.ragged_input)
        self.assertTrue(
            ragged_tensor_equal(ragged_output, self.expected_ragged_out)
        )

    def test_round_identity_sparse(self):
        sparse_output = self.round_identity(self.ragged_input.to_sparse())
        self.assertTrue(
            sparse_tensor_equal(
                sparse_output, self.expected_ragged_out.to_sparse()
            )
        )

    def test_round_identity_out_of_range(self):
        # Values that round below 0 or above num_buckets must be replaced
        # by default_value; in-range values pass through.
        layer = RoundIdentity(num_buckets=3, default_value=1)
        output = layer(tf.constant([[-1.2], [2.4], [4.6]]))
        expected_out = np.array([[1], [2], [1]])
        np.testing.assert_array_equal(output.numpy(), expected_out)
27 changes: 27 additions & 0 deletions elasticdl_preprocessing/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.ragged import ragged_tensor


def sparse_tensor_equal(sp_a, sp_b):
Expand All @@ -15,3 +17,28 @@ def sparse_tensor_equal(sp_a, sp_b):
return False

return True


def ragged_tensor_equal(rt_a, rt_b):
    """Return True if two ragged tensors have equal shape, dtype and values.

    The tensors are compared row by row along the outermost dimension. Rows
    that are themselves ragged are compared recursively; rows that are
    plain `tf.Tensor`s are compared by dtype and by value.

    Args:
        rt_a: First ragged tensor.
        rt_b: Second ragged tensor.

    Returns:
        True when both tensors match, False otherwise. A row that is ragged
        in one tensor but dense in the other counts as a mismatch.

    Note:
        Values are read with `.numpy()`, so this helper needs eager tensors,
        and the outermost dimension must be statically known because it is
        iterated with `range`.
    """
    if rt_a.shape.as_list() != rt_b.shape.as_list():
        return False
    # Checked up front so that tensors without any rows still have to
    # agree on dtype; the per-row check below never runs for them.
    if rt_a.dtype != rt_b.dtype:
        return False

    for i in range(rt_a.shape[0]):
        row_a = rt_a[i]
        row_b = rt_b[i]
        if ragged_tensor.is_ragged(row_a) and ragged_tensor.is_ragged(row_b):
            if not ragged_tensor_equal(row_a, row_b):
                return False
        elif isinstance(row_a, tf.Tensor) and isinstance(row_b, tf.Tensor):
            if row_a.dtype != row_b.dtype:
                return False
            if not np.array_equal(row_a.numpy(), row_b.numpy()):
                return False
        else:
            # One row is ragged and the other is dense (or neither is a
            # recognised tensor type).
            return False
    return True