diff --git a/elasticdl_preprocessing/layers/round_identity.py b/elasticdl_preprocessing/layers/round_identity.py new file mode 100644 index 000000000..066de78b5 --- /dev/null +++ b/elasticdl_preprocessing/layers/round_identity.py @@ -0,0 +1,71 @@ +import tensorflow as tf +from tensorflow.python.ops.ragged import ragged_functional_ops, ragged_tensor + + +class RoundIdentity(tf.keras.layers.Layer): + """Cast a numeric feature into a discrete integer value. + + This layer transforms numeric inputs to integer output. It is a special + case of bucketizing to bins. The max value in the layer is the number of + bins. + + Example : + ```python + layer = RoundIdentity(max_value=5) + inp = np.asarray([[1.2], [1.6], [0.2], [3.1], [4.9]]) + layer(inp) + [[1], [2], [0], [3], [5]] + ``` + + Arguments: + num_buckets: Range of inputs and outputs is `[0, num_buckets)`. + **kwargs: Keyword arguments to construct a layer. + + Input shape: A numeric `Tensor`, `SparseTensor` or `RaggedTensor` of shape + `[batch_size, d1, ..., dm]` + + Output shape: An int64 tensor of shape `[batch_size, d1, ..., dm]` + + """ + + def __init__(self, num_buckets, default_value=0): + super(RoundIdentity, self).__init__() + self.num_buckets = tf.cast(num_buckets, tf.int64) + self.default_value = tf.cast(default_value, tf.int64) + + def call(self, inputs): + if isinstance(inputs, tf.SparseTensor): + id_values = self._round_and_truncate(inputs.values) + result = tf.SparseTensor( + indices=inputs.indices, + values=id_values, + dense_shape=inputs.dense_shape, + ) + elif ragged_tensor.is_ragged(inputs): + result = ragged_functional_ops.map_flat_values( + self._round_and_truncate, inputs + ) + else: + result = self._round_and_truncate(inputs) + return tf.cast(result, tf.int64) + + def _round_and_truncate(self, values): + values = tf.keras.backend.round(values) + values = tf.cast(values, tf.int64) + values = tf.where( + tf.logical_or(values < 0, values > self.num_buckets), + x=tf.fill(dims=tf.shape(values), value=self.default_value), + y=values, + ) + return values + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "num_buckets": self.num_buckets, + "default_value": self.default_value, + } + base_config = super(RoundIdentity, self).get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/elasticdl_preprocessing/tests/round_identity_test.py b/elasticdl_preprocessing/tests/round_identity_test.py new file mode 100644 index 000000000..4e3599c2a --- /dev/null +++ b/elasticdl_preprocessing/tests/round_identity_test.py @@ -0,0 +1,34 @@ +import unittest + +import numpy as np +import tensorflow as tf + +from elasticdl_preprocessing.layers.round_identity import RoundIdentity +from elasticdl_preprocessing.tests.test_utils import ( + ragged_tensor_equal, + sparse_tensor_equal, +) + + +class RoundIdentityTest(unittest.TestCase): + def test_round_indentity(self): + round_identity = RoundIdentity(num_buckets=10) + + dense_input = tf.constant([[1.2], [1.6], [0.2], [3.1], [4.9]]) + output = round_identity(dense_input) + expected_out = np.array([[1], [2], [0], [3], [5]]) + self.assertTrue(np.array_equal(output.numpy(), expected_out)) + + ragged_input = tf.ragged.constant([[1.1, 3.4], [0.5]]) + ragged_output = round_identity(ragged_input) + expected_ragged_out = tf.ragged.constant([[1, 3], [0]], dtype=tf.int64) + self.assertTrue( + ragged_tensor_equal(ragged_output, expected_ragged_out) + ) + + sparse_input = ragged_input.to_sparse() + sparse_output = round_identity(sparse_input) + expected_sparse_out = expected_ragged_out.to_sparse() + self.assertTrue( + sparse_tensor_equal(sparse_output, expected_sparse_out) + ) diff --git a/elasticdl_preprocessing/tests/test_utils.py b/elasticdl_preprocessing/tests/test_utils.py index bf1370b13..001db7cc5 100644 --- a/elasticdl_preprocessing/tests/test_utils.py +++ b/elasticdl_preprocessing/tests/test_utils.py @@ -1,4 +1,6 @@ import numpy as np +import tensorflow as tf +from tensorflow.python.ops.ragged import ragged_tensor def sparse_tensor_equal(sp_a, sp_b): @@ -15,3 +17,28 @@ def sparse_tensor_equal(sp_a, sp_b): return False return True + + +def ragged_tensor_equal(rt_a, rt_b): + print(rt_a, rt_b) + if rt_a.shape.as_list() != rt_b.shape.as_list(): + return False + + for i in range(rt_a.shape[0]): + sub_rt_a = rt_a[i] + sub_rt_b = rt_b[i] + if ragged_tensor.is_ragged(sub_rt_a) and ragged_tensor.is_ragged( + sub_rt_b + ): + if not ragged_tensor_equal(sub_rt_a, sub_rt_b): + return False + elif isinstance(sub_rt_a, tf.Tensor) and isinstance( + sub_rt_b, tf.Tensor + ): + if sub_rt_a.dtype != sub_rt_b.dtype: + return False + if not np.array_equal(sub_rt_a.numpy(), sub_rt_b.numpy()): + return False + else: + return False + return True