From 035bc94fee8106744d598eb83e8f9aa4cc5ea86b Mon Sep 17 00:00:00 2001 From: workingloong Date: Fri, 20 Mar 2020 14:20:55 +0800 Subject: [PATCH 1/9] Layer to convert dense tensor to sparse tensor dropping ignore_value cells --- elasticdl_preprocessing/layers/to_sparse.py | 48 +++++++++++++++++++ .../tests/to_sparse_test.py | 33 +++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 elasticdl_preprocessing/layers/to_sparse.py create mode 100644 elasticdl_preprocessing/tests/to_sparse_test.py diff --git a/elasticdl_preprocessing/layers/to_sparse.py b/elasticdl_preprocessing/layers/to_sparse.py new file mode 100644 index 000000000..6a3ce0c3f --- /dev/null +++ b/elasticdl_preprocessing/layers/to_sparse.py @@ -0,0 +1,48 @@ +import tensorflow as tf + + +class ToSparse(tf.keras.layers.Layer): + """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells. + If `input_tensor` is already a `SparseTensor`, just return it. + + Example : + ```python + layer = ToSparse() + inp = tf.constant([["A", ""], ["B", "C"]], tf.string) + layer.call(inp) + tf.SparseTensor( + indices=np.array([[0, 0], [1, 0], [1, 1]]), + values=np.array(["A", "B", "C"]), + dense_shape=(2, 2), + ) + ``` + + Arguments: + ignore_value: Entries in `dense_tensor` equal to this value will be + absent from the output `SparseTensor`. If `None`, default value of + `dense_tensor`'s dtype will be used ('' for `str`, -1 for `int`). + + Input shape: A numeric or string `Tensor` of shape + `[batch_size, d1, ..., dm]` + + Output shape: An `SparseTensor` with the same shape as inputs + """ + def __init__(self, ignore_value=None): + super(ToSparse, self).__init__() + self.ignore_value = ignore_value + + def call(self, inputs): + if isinstance(inputs, tf.SparseTensor): + return inputs + if self.ignore_value is None: + if inputs.dtype == tf.string: + ignore_value = '' + elif inputs.dtype.is_integer: + ignore_value = -1 # Embedding layer cannot use -1 + ignore_value = tf.cast(ignore_value, inputs.dtype) + indices = tf.where(tf.not_equal(inputs, ignore_value)) + values = tf.gather_nd(inputs, indices) + dense_shape = tf.shape(inputs, out_type=tf.int64) + return tf.SparseTensor( + indices=indices, values=values, dense_shape=dense_shape + ) diff --git a/elasticdl_preprocessing/tests/to_sparse_test.py b/elasticdl_preprocessing/tests/to_sparse_test.py new file mode 100644 index 000000000..c52f62df2 --- /dev/null +++ b/elasticdl_preprocessing/tests/to_sparse_test.py @@ -0,0 +1,33 @@ +import unittest + +import numpy as np +import tensorflow as tf + +from elasticdl_preprocessing.layers.to_sparse import ToSparse +from elasticdl_preprocessing.tests.test_utils import sparse_tensor_equal + + +class ToSparseTest(unittest.TestCase): + def test_to_sparse(self): + layer = ToSparse() + inp = tf.constant([["A", ""], ["B", "C"]], tf.string) + output = layer.call(inp) + expected_out = tf.SparseTensor( + indices=np.array([[0, 0], [1, 0], [1, 1]]), + values=np.array(["A", "B", "C"]), + dense_shape=(2, 2), + ) + self.assertTrue(sparse_tensor_equal(output, expected_out)) + + inp = tf.constant([[12, -1], [45, 78]], tf.int64) + output = layer.call(inp) + expected_out = tf.SparseTensor( + indices=np.array([[0, 0], [1, 0], [1, 1]]), + values=np.array([12, 45, 78]), + dense_shape=(2, 2), + ) + self.assertTrue(sparse_tensor_equal(output, expected_out)) + + +if __name__ == "__main__": + unittest.main() From 01aa93a2a1c35018f47af69e5ec98e114880b2a9 Mon Sep 17 00:00:00 2001 From: workingloong Date: Fri, 20 Mar 2020 14:33:43 +0800 Subject: [PATCH 2/9] Reformat code --- elasticdl_preprocessing/layers/to_sparse.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/elasticdl_preprocessing/layers/to_sparse.py b/elasticdl_preprocessing/layers/to_sparse.py index 6a3ce0c3f..c113eb725 100644 --- a/elasticdl_preprocessing/layers/to_sparse.py +++ b/elasticdl_preprocessing/layers/to_sparse.py @@ -27,6 +27,7 @@ class ToSparse(tf.keras.layers.Layer): Output shape: An `SparseTensor` with the same shape as inputs """ + def __init__(self, ignore_value=None): super(ToSparse, self).__init__() self.ignore_value = ignore_value @@ -36,7 +37,7 @@ def call(self, inputs): return inputs if self.ignore_value is None: if inputs.dtype == tf.string: - ignore_value = '' + ignore_value = "" elif inputs.dtype.is_integer: ignore_value = -1 # Embedding layer cannot use -1 ignore_value = tf.cast(ignore_value, inputs.dtype) From 17e837dfe3b3acc5b7170085069dee6f138ef100 Mon Sep 17 00:00:00 2001 From: workingloong Date: Fri, 20 Mar 2020 14:57:41 +0800 Subject: [PATCH 3/9] add get_config method --- elasticdl_preprocessing/layers/to_sparse.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/elasticdl_preprocessing/layers/to_sparse.py b/elasticdl_preprocessing/layers/to_sparse.py index c113eb725..18bac1570 100644 --- a/elasticdl_preprocessing/layers/to_sparse.py +++ b/elasticdl_preprocessing/layers/to_sparse.py @@ -47,3 +47,13 @@ def call(self, inputs): return tf.SparseTensor( indices=indices, values=values, dense_shape=dense_shape ) + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "ignore_value": self.ignore_value, + } + base_config = super(ToSparse, self).get_config() + return dict(list(base_config.items()) + list(config.items())) From 83b40f35d5abe1486607b04173fd7966634bf012 Mon Sep 17 00:00:00 2001 From: workingloong Date: Fri, 20 Mar 2020 15:25:41 +0800 Subject: [PATCH 4/9] fix the default value of ignore_value --- elasticdl_preprocessing/layers/to_sparse.py | 9 +++++---- elasticdl_preprocessing/tests/to_sparse_test.py | 1 + 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/elasticdl_preprocessing/layers/to_sparse.py b/elasticdl_preprocessing/layers/to_sparse.py index 18bac1570..3586ddb3c 100644 --- a/elasticdl_preprocessing/layers/to_sparse.py +++ b/elasticdl_preprocessing/layers/to_sparse.py @@ -37,11 +37,12 @@ def call(self, inputs): return inputs if self.ignore_value is None: if inputs.dtype == tf.string: - ignore_value = "" + print("string") + self.ignore_value = "" elif inputs.dtype.is_integer: - ignore_value = -1 # Embedding layer cannot use -1 - ignore_value = tf.cast(ignore_value, inputs.dtype) - indices = tf.where(tf.not_equal(inputs, ignore_value)) + self.ignore_value = -1 + self.ignore_value = tf.cast(self.ignore_value, inputs.dtype) + indices = tf.where(tf.not_equal(inputs, self.ignore_value)) values = tf.gather_nd(inputs, indices) dense_shape = tf.shape(inputs, out_type=tf.int64) return tf.SparseTensor( diff --git a/elasticdl_preprocessing/tests/to_sparse_test.py b/elasticdl_preprocessing/tests/to_sparse_test.py index c52f62df2..750b7582f 100644 --- a/elasticdl_preprocessing/tests/to_sparse_test.py +++ b/elasticdl_preprocessing/tests/to_sparse_test.py @@ -19,6 +19,7 @@ def test_to_sparse(self): ) self.assertTrue(sparse_tensor_equal(output, expected_out)) + layer = ToSparse() inp = tf.constant([[12, -1], [45, 78]], tf.int64) output = layer.call(inp) expected_out = tf.SparseTensor( From b9446c7850bd19084c4d358977ea8262e6ad1ec3 Mon Sep 17 00:00:00 2001 From: workingloong Date: Fri, 20 Mar 2020 15:31:09 +0800 Subject: [PATCH 5/9] Fix the docstring --- elasticdl_preprocessing/layers/to_sparse.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/elasticdl_preprocessing/layers/to_sparse.py b/elasticdl_preprocessing/layers/to_sparse.py index 3586ddb3c..6102b67dd 100644 --- a/elasticdl_preprocessing/layers/to_sparse.py +++ b/elasticdl_preprocessing/layers/to_sparse.py @@ -3,7 +3,7 @@ class ToSparse(tf.keras.layers.Layer): """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells. - If `input_tensor` is already a `SparseTensor`, just return it. + If the input is already a `SparseTensor`, just return it. Example : ```python @@ -18,9 +18,9 @@ class ToSparse(tf.keras.layers.Layer): ``` Arguments: - ignore_value: Entries in `dense_tensor` equal to this value will be + ignore_value: Entries in inputs equal to this value will be absent from the output `SparseTensor`. If `None`, default value of - `dense_tensor`'s dtype will be used ('' for `str`, -1 for `int`). + inputs dtype will be used ('' for `str`, -1 for `int`). Input shape: A numeric or string `Tensor` of shape `[batch_size, d1, ..., dm]` From d071af503fae60505fbfda5c8addd757e201e08a Mon Sep 17 00:00:00 2001 From: workingloong Date: Fri, 20 Mar 2020 16:54:05 +0800 Subject: [PATCH 6/9] Remove print code --- elasticdl_preprocessing/layers/to_sparse.py | 1 - 1 file changed, 1 deletion(-) diff --git a/elasticdl_preprocessing/layers/to_sparse.py b/elasticdl_preprocessing/layers/to_sparse.py index 6102b67dd..6922ceb4b 100644 --- a/elasticdl_preprocessing/layers/to_sparse.py +++ b/elasticdl_preprocessing/layers/to_sparse.py @@ -37,7 +37,6 @@ def call(self, inputs): return inputs if self.ignore_value is None: if inputs.dtype == tf.string: - print("string") self.ignore_value = "" elif inputs.dtype.is_integer: self.ignore_value = -1 From 278894df0cfa92842af7a4481c8ab89fdc8952d8 Mon Sep 17 00:00:00 2001 From: workingloong Date: Mon, 23 Mar 2020 11:15:25 +0800 Subject: [PATCH 7/9] Fix the docstring by comments --- elasticdl_preprocessing/layers/to_sparse.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/elasticdl_preprocessing/layers/to_sparse.py b/elasticdl_preprocessing/layers/to_sparse.py index 6922ceb4b..8ee5ba8dd 100644 --- a/elasticdl_preprocessing/layers/to_sparse.py +++ b/elasticdl_preprocessing/layers/to_sparse.py @@ -9,8 +9,11 @@ class ToSparse(tf.keras.layers.Layer): ```python layer = ToSparse() inp = tf.constant([["A", ""], ["B", "C"]], tf.string) - layer.call(inp) - tf.SparseTensor( + out = layer(inp) + ``` + The expected output is + ``` + tf.SparseTensor( indices=np.array([[0, 0], [1, 0], [1, 1]]), values=np.array(["A", "B", "C"]), dense_shape=(2, 2), From 4ef65e772498ed5f943733d9d2380a63b7677905 Mon Sep 17 00:00:00 2001 From: workingloong Date: Wed, 25 Mar 2020 17:30:55 +0800 Subject: [PATCH 8/9] Set ignore_value=0.0 for other dtypes --- elasticdl_preprocessing/layers/to_sparse.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/elasticdl_preprocessing/layers/to_sparse.py b/elasticdl_preprocessing/layers/to_sparse.py index 8ee5ba8dd..3aaffb27d 100644 --- a/elasticdl_preprocessing/layers/to_sparse.py +++ b/elasticdl_preprocessing/layers/to_sparse.py @@ -43,6 +43,8 @@ def call(self, inputs): self.ignore_value = "" elif inputs.dtype.is_integer: self.ignore_value = -1 + else: + self.ignore_value = 0.0 self.ignore_value = tf.cast(self.ignore_value, inputs.dtype) indices = tf.where(tf.not_equal(inputs, self.ignore_value)) values = tf.gather_nd(inputs, indices) From 6c2b8dfc6a835bca9c620124e90c7c624ec98a50 Mon Sep 17 00:00:00 2001 From: workingloong Date: Wed, 25 Mar 2020 21:01:54 +0800 Subject: [PATCH 9/9] Add unit test to create model with ToSparse layers --- elasticdl_preprocessing/layers/to_sparse.py | 14 ++++++++------ elasticdl_preprocessing/tests/to_sparse_test.py | 13 +++++++++++++ 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/elasticdl_preprocessing/layers/to_sparse.py b/elasticdl_preprocessing/layers/to_sparse.py index 3aaffb27d..412fce504 100644 --- a/elasticdl_preprocessing/layers/to_sparse.py +++ b/elasticdl_preprocessing/layers/to_sparse.py @@ -38,15 +38,17 @@ def __init__(self, ignore_value=None): def call(self, inputs): if isinstance(inputs, tf.SparseTensor): return inputs - if self.ignore_value is None: + + ignore_value = self.ignore_value + if ignore_value is None: if inputs.dtype == tf.string: - self.ignore_value = "" + ignore_value = "" elif inputs.dtype.is_integer: - self.ignore_value = -1 + ignore_value = -1 else: - self.ignore_value = 0.0 - self.ignore_value = tf.cast(self.ignore_value, inputs.dtype) - indices = tf.where(tf.not_equal(inputs, self.ignore_value)) + ignore_value = 0.0 + ignore_value = tf.cast(ignore_value, inputs.dtype) + indices = tf.where(tf.not_equal(inputs, ignore_value)) values = tf.gather_nd(inputs, indices) dense_shape = tf.shape(inputs, out_type=tf.int64) return tf.SparseTensor( diff --git a/elasticdl_preprocessing/tests/to_sparse_test.py b/elasticdl_preprocessing/tests/to_sparse_test.py index 750b7582f..8f8aa84ce 100644 --- a/elasticdl_preprocessing/tests/to_sparse_test.py +++ b/elasticdl_preprocessing/tests/to_sparse_test.py @@ -29,6 +29,19 @@ def test_to_sparse(self): ) self.assertTrue(sparse_tensor_equal(output, expected_out)) + def test_model_with_to_sparse(self): + inputs = tf.keras.Input(shape=(1,), dtype=tf.int32) + sparse_inputs = ToSparse(ignore_value=-1)(inputs) + model = tf.keras.Model(inputs=inputs, outputs=sparse_inputs) + out = model.call(tf.constant([[1], [-1], [2], [3]])) + + expect_out = tf.SparseTensor( + indices=tf.constant([[0, 0], [2, 0], [3, 0]], dtype=tf.int64), + values=tf.constant([1, 2, 3], dtype=tf.int32), + dense_shape=(4, 1), + ) + self.assertTrue(sparse_tensor_equal(out, expect_out)) + if __name__ == "__main__": unittest.main()