From be3c822fbbfeacf44d377d78298982d442e07af9 Mon Sep 17 00:00:00 2001 From: workingloong Date: Tue, 17 Mar 2020 11:27:27 +0800 Subject: [PATCH 1/7] Add RoundIdentity layer to transform numeric values to integer ids --- elasticdl_preprocessing/layers/__init__.py | 0 .../layers/round_identity.py | 78 +++++++++++++++++++ elasticdl_preprocessing/tests/__init__.py | 0 .../tests/round_identity_test.py | 27 +++++++ 4 files changed, 105 insertions(+) create mode 100644 elasticdl_preprocessing/layers/__init__.py create mode 100644 elasticdl_preprocessing/layers/round_identity.py create mode 100644 elasticdl_preprocessing/tests/__init__.py create mode 100644 elasticdl_preprocessing/tests/round_identity_test.py diff --git a/elasticdl_preprocessing/layers/__init__.py b/elasticdl_preprocessing/layers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/elasticdl_preprocessing/layers/round_identity.py b/elasticdl_preprocessing/layers/round_identity.py new file mode 100644 index 000000000..0cdf76295 --- /dev/null +++ b/elasticdl_preprocessing/layers/round_identity.py @@ -0,0 +1,78 @@ +import tensorflow as tf + +from tensorflow.python.framework import dtypes +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.ops.ragged import ragged_functional_ops + +NUMERIC_TYPES = frozenset( + [dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16, dtypes.int32, + dtypes.int64, dtypes.uint8, dtypes.qint8, dtypes.qint32, dtypes.quint8, + dtypes.complex64]) + + +class RoundIdentity(tf.keras.layers.Layer): + """Implements numeric feature roundding with a max value. + + This layer transforms numeric inputs to integer output. It is a special + case of bucketizing to bins. The max value in the layer is the number of + bins. + + Example : + ```python + layer = RoundIdentity(max_value=5) + inp = np.asarray([[1.2], [1.6], [0.2], [3.1], [4.9]]) + layer(inputs) + [[1], [2], [0], [3], [5]] + ``` + + Arguments: + num_buckets: Range of inputs and outputs is `[0, num_buckets)`. + **kwargs: Keyword arguments to construct a layer. + + Input shape: A numeric tensor of shape + `[batch_size, d1, ..., dm]` + + Output shape: An int64 tensor of shape `[batch_size, d1, ..., dm]` + + """ + def __init__(self, num_buckets, default_value=0): + super(RoundIdentity, self).__init__() + self.num_buckets = tf.cast(num_buckets, tf.int64) + self.default_value = tf.cast(default_value, tf.int64) + + def call(self, inputs): + if isinstance(inputs, tf.SparseTensor): + id_values = self._round_and_truncate(inputs.values) + result = tf.SparseTensor( + indices=inputs.indices, + values=id_values, + dense_shape=inputs.dense_shape, + ) + elif ragged_tensor.is_ragged(inputs): + result = ragged_functional_ops.map_flat_values( + self._round_and_truncate, inputs + ) + else: + result = self._round_and_truncate(inputs) + return tf.cast(result, tf.int64) + + def _round_and_truncate(self, values): + values = tf.keras.backend.round(values) + values = tf.cast(values, tf.int64) + values = tf.where( + tf.logical_or(values < 0, values > self.num_buckets), + x=tf.fill(dims=tf.shape(values), value=self.default_value), + y=values + ) + return values + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + 'num_buckets': self.num_buckets, + 'default_value': self.default_value, + } + base_config = super(RoundIdentity, self).get_config() + return dict(list(base_config.items()) + list(config.items())) \ No newline at end of file diff --git a/elasticdl_preprocessing/tests/__init__.py b/elasticdl_preprocessing/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/elasticdl_preprocessing/tests/round_identity_test.py b/elasticdl_preprocessing/tests/round_identity_test.py new file mode 100644 index 000000000..6e1a657cd --- /dev/null +++ b/elasticdl_preprocessing/tests/round_identity_test.py @@ -0,0 +1,27 @@ +import os +import unittest +import numpy as np +import tensorflow as tf + +from elasticdl_preprocessing.layers.round_identity import RoundIdentity + + +class RoundIdentityTest(unittest.TestCase): + def test_round_indentity(self): + round_identity = RoundIdentity(num_buckets=10) + + dense_input = tf.constant([[1.2], [1.6], [0.2], [3.1], [4.9]]) + output = round_identity(dense_input) + expected_out = np.array([[1], [2], [0], [3], [5]]) + self.assertTrue(np.array_equal(output.numpy(), expected_out)) + + ragged_input = tf.ragged.constant([[1.1, 3.4], [0.5]]) + ragged_output = round_identity(ragged_input) + ragged_output_values = ragged_output.values.numpy() + expected_out = np.array([1.0, 3.0, 0.0]) + self.assertTrue(np.array_equal(ragged_output_values, expected_out)) + + sparse_input = ragged_input.to_sparse() + sparse_output = round_identity(sparse_input) + sparse_output_values = sparse_output.values + self.assertTrue(np.array_equal(sparse_output_values, expected_out)) From ea319da4273f542c271e12575f25cfde4aea46de Mon Sep 17 00:00:00 2001 From: workingloong Date: Tue, 17 Mar 2020 14:43:08 +0800 Subject: [PATCH 2/7] Fix the docstring --- elasticdl_preprocessing/layers/round_identity.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/elasticdl_preprocessing/layers/round_identity.py b/elasticdl_preprocessing/layers/round_identity.py index 0cdf76295..6766fe230 100644 --- a/elasticdl_preprocessing/layers/round_identity.py +++ b/elasticdl_preprocessing/layers/round_identity.py @@ -4,14 +4,9 @@ from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_functional_ops -NUMERIC_TYPES = frozenset( - [dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16, dtypes.int32, - dtypes.int64, dtypes.uint8, dtypes.qint8, dtypes.qint32, dtypes.quint8, - dtypes.complex64]) - class RoundIdentity(tf.keras.layers.Layer): - """Implements numeric feature roundding with a max value. + """Implements numeric feature rounding with a max value. This layer transforms numeric inputs to integer output. It is a special case of bucketizing to bins. The max value in the layer is the number of From 69495b669adf015739963a4b569cd03e7104b930 Mon Sep 17 00:00:00 2001 From: workingloong Date: Tue, 17 Mar 2020 14:45:52 +0800 Subject: [PATCH 3/7] Config pre-commit with elasticdl_preprocessing --- .travis.yml | 2 +- elasticdl_preprocessing/layers/round_identity.py | 14 ++++++-------- .../tests/round_identity_test.py | 2 +- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/.travis.yml b/.travis.yml index 04d7c298f..44a4c2d25 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,7 +28,7 @@ jobs: name: "Pre-commit Check" script: - docker build --target dev -t elasticdl:dev -f elasticdl/docker/Dockerfile . - - docker run --rm -it -v $PWD:/work -w /work elasticdl:dev bash -c "pre-commit run --files $(find elasticdl/python model_zoo setup.py scripts/ -name '*.py' -print0 | tr '\0' ' ') $(find elasticdl/pkg -name '*.go' -print0 | tr '\0' ' ')" + - docker run --rm -it -v $PWD:/work -w /work elasticdl:dev bash -c "pre-commit run --files $(find elasticdl/python elasticdl_preprocessing model_zoo setup.py scripts/ -name '*.py' -print0 | tr '\0' ' ') $(find elasticdl/pkg -name '*.go' -print0 | tr '\0' ' ')" - stage: tests name: "Tests" script: diff --git a/elasticdl_preprocessing/layers/round_identity.py b/elasticdl_preprocessing/layers/round_identity.py index 6766fe230..caa736fb5 100644 --- a/elasticdl_preprocessing/layers/round_identity.py +++ b/elasticdl_preprocessing/layers/round_identity.py @@ -1,8 +1,5 @@ import tensorflow as tf - -from tensorflow.python.framework import dtypes -from tensorflow.python.ops.ragged import ragged_tensor -from tensorflow.python.ops.ragged import ragged_functional_ops +from tensorflow.python.ops.ragged import ragged_functional_ops, ragged_tensor class RoundIdentity(tf.keras.layers.Layer): @@ -30,6 +27,7 @@ class RoundIdentity(tf.keras.layers.Layer): Output shape: An int64 tensor of shape `[batch_size, d1, ..., dm]` """ + def __init__(self, num_buckets, default_value=0): super(RoundIdentity, self).__init__() self.num_buckets = tf.cast(num_buckets, tf.int64) @@ -57,7 +55,7 @@ def _round_and_truncate(self, values): values = tf.where( tf.logical_or(values < 0, values > self.num_buckets), x=tf.fill(dims=tf.shape(values), value=self.default_value), - y=values + y=values, ) return values @@ -66,8 +64,8 @@ def compute_output_shape(self, input_shape): def get_config(self): config = { - 'num_buckets': self.num_buckets, - 'default_value': self.default_value, + "num_buckets": self.num_buckets, + "default_value": self.default_value, } base_config = super(RoundIdentity, self).get_config() - return dict(list(base_config.items()) + list(config.items())) \ No newline at end of file + return dict(list(base_config.items()) + list(config.items())) diff --git a/elasticdl_preprocessing/tests/round_identity_test.py b/elasticdl_preprocessing/tests/round_identity_test.py index 6e1a657cd..1425f6419 100644 --- a/elasticdl_preprocessing/tests/round_identity_test.py +++ b/elasticdl_preprocessing/tests/round_identity_test.py @@ -1,5 +1,5 @@ -import os import unittest + import numpy as np import tensorflow as tf From 4a25f3471d122dcbdca96efbe3f82e02d6764686 Mon Sep 17 00:00:00 2001 From: workingloong Date: Thu, 19 Mar 2020 11:15:40 +0800 Subject: [PATCH 4/7] Fix the docstring by comments --- elasticdl_preprocessing/layers/round_identity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elasticdl_preprocessing/layers/round_identity.py b/elasticdl_preprocessing/layers/round_identity.py index caa736fb5..81048bbb1 100644 --- a/elasticdl_preprocessing/layers/round_identity.py +++ b/elasticdl_preprocessing/layers/round_identity.py @@ -21,7 +21,7 @@ class RoundIdentity(tf.keras.layers.Layer): num_buckets: Range of inputs and outputs is `[0, num_buckets)`. **kwargs: Keyword arguments to construct a layer. - Input shape: A numeric tensor of shape + Input shape: A numeric `Tensor`, `SparseTensor` or `RaggedTensor` of shape `[batch_size, d1, ..., dm]` Output shape: An int64 tensor of shape `[batch_size, d1, ..., dm]` From 9dca79cbfbf41bd2e619c0ebbb6c7eae31db0a47 Mon Sep 17 00:00:00 2001 From: workingloong Date: Thu, 19 Mar 2020 15:33:53 +0800 Subject: [PATCH 5/7] Add a method to check ragged tensors equal --- .../tests/round_identity_test.py | 17 ++++++++---- elasticdl_preprocessing/tests/test_utils.py | 27 +++++++++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/elasticdl_preprocessing/tests/round_identity_test.py b/elasticdl_preprocessing/tests/round_identity_test.py index 1425f6419..4e3599c2a 100644 --- a/elasticdl_preprocessing/tests/round_identity_test.py +++ b/elasticdl_preprocessing/tests/round_identity_test.py @@ -4,6 +4,10 @@ import tensorflow as tf from elasticdl_preprocessing.layers.round_identity import RoundIdentity +from elasticdl_preprocessing.tests.test_utils import ( + ragged_tensor_equal, + sparse_tensor_equal, +) class RoundIdentityTest(unittest.TestCase): @@ -17,11 +21,14 @@ def test_round_indentity(self): ragged_input = tf.ragged.constant([[1.1, 3.4], [0.5]]) ragged_output = round_identity(ragged_input) - ragged_output_values = ragged_output.values.numpy() - expected_out = np.array([1.0, 3.0, 0.0]) - self.assertTrue(np.array_equal(ragged_output_values, expected_out)) + expected_ragged_out = tf.ragged.constant([[1, 3], [0]], dtype=tf.int64) + self.assertTrue( + ragged_tensor_equal(ragged_output, expected_ragged_out) + ) sparse_input = ragged_input.to_sparse() sparse_output = round_identity(sparse_input) - sparse_output_values = sparse_output.values - self.assertTrue(np.array_equal(sparse_output_values, expected_out)) + expected_sparse_out = expected_ragged_out.to_sparse() + self.assertTrue( + sparse_tensor_equal(sparse_output, expected_sparse_out) + ) diff --git a/elasticdl_preprocessing/tests/test_utils.py b/elasticdl_preprocessing/tests/test_utils.py index bf1370b13..3dfa43550 100644 --- a/elasticdl_preprocessing/tests/test_utils.py +++ b/elasticdl_preprocessing/tests/test_utils.py @@ -1,4 +1,6 @@ import numpy as np +import tensorflow as tf +from tensorflow.python.ops.ragged import ragged_functional_ops, ragged_tensor def sparse_tensor_equal(sp_a, sp_b): @@ -15,3 +17,28 @@ def sparse_tensor_equal(sp_a, sp_b): return False return True + + +def ragged_tensor_equal(rt_a, rt_b): + print(rt_a, rt_b) + if rt_a.shape.as_list() != rt_b.shape.as_list(): + return False + + for i in range(rt_a.shape[0]): + sub_rt_a = rt_a[i] + sub_rt_b = rt_b[i] + if ragged_tensor.is_ragged(sub_rt_a) and ragged_tensor.is_ragged( + sub_rt_b + ): + if not ragged_tensor_equal(sub_rt_a, sub_rt_b): + return False + elif isinstance(sub_rt_a, tf.Tensor) and isinstance( + sub_rt_b, tf.Tensor + ): + if sub_rt_a.dtype != sub_rt_b.dtype: + return False + if not np.array_equal(sub_rt_a.numpy(), sub_rt_b.numpy()): + return False + else: + return False + return True From 1c39600cb70c676d9b5677379db3bd2c60585921 Mon Sep 17 00:00:00 2001 From: workingloong Date: Thu, 19 Mar 2020 17:51:14 +0800 Subject: [PATCH 6/7] Remove unused import --- elasticdl_preprocessing/layers/round_identity.py | 2 +- elasticdl_preprocessing/tests/test_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/elasticdl_preprocessing/layers/round_identity.py b/elasticdl_preprocessing/layers/round_identity.py index 81048bbb1..47a8b040b 100644 --- a/elasticdl_preprocessing/layers/round_identity.py +++ b/elasticdl_preprocessing/layers/round_identity.py @@ -3,7 +3,7 @@ class RoundIdentity(tf.keras.layers.Layer): - """Implements numeric feature rounding with a max value. + """Cast a numeric feature into a discrete integer value. This layer transforms numeric inputs to integer output. It is a special case of bucketizing to bins. The max value in the layer is the number of diff --git a/elasticdl_preprocessing/tests/test_utils.py b/elasticdl_preprocessing/tests/test_utils.py index 3dfa43550..001db7cc5 100644 --- a/elasticdl_preprocessing/tests/test_utils.py +++ b/elasticdl_preprocessing/tests/test_utils.py @@ -1,6 +1,6 @@ import numpy as np import tensorflow as tf -from tensorflow.python.ops.ragged import ragged_functional_ops, ragged_tensor +from tensorflow.python.ops.ragged import ragged_tensor def sparse_tensor_equal(sp_a, sp_b): From fab901a7f28627d0a8300471c3888c32e478d058 Mon Sep 17 00:00:00 2001 From: workingloong Date: Thu, 19 Mar 2020 19:27:44 +0800 Subject: [PATCH 7/7] Fix example --- elasticdl_preprocessing/layers/round_identity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elasticdl_preprocessing/layers/round_identity.py b/elasticdl_preprocessing/layers/round_identity.py index 47a8b040b..066de78b5 100644 --- a/elasticdl_preprocessing/layers/round_identity.py +++ b/elasticdl_preprocessing/layers/round_identity.py @@ -13,7 +13,7 @@ class RoundIdentity(tf.keras.layers.Layer): ```python layer = RoundIdentity(max_value=5) inp = np.asarray([[1.2], [1.6], [0.2], [3.1], [4.9]]) - layer(inputs) + layer(inp) [[1], [2], [0], [3], [5]] ```