diff --git a/elasticdl_preprocessing/layers/to_sparse.py b/elasticdl_preprocessing/layers/to_sparse.py new file mode 100644 index 000000000..412fce504 --- /dev/null +++ b/elasticdl_preprocessing/layers/to_sparse.py @@ -0,0 +1,66 @@ +import tensorflow as tf + + +class ToSparse(tf.keras.layers.Layer): + """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells. + If the input is already a `SparseTensor`, just return it. + + Example : + ```python + layer = ToSparse() + inp = tf.constant([["A", ""], ["B", "C"]], tf.string) + out = layer(inp) + ``` + The expected output is + ``` + tf.SparseTensor( + indices=np.array([[0, 0], [1, 0], [1, 1]]), + values=np.array(["A", "B", "C"]), + dense_shape=(2, 2), + ) + ``` + + Arguments: + ignore_value: Entries in inputs equal to this value will be + absent from the output `SparseTensor`. If `None`, default value of + inputs dtype will be used ('' for `str`, -1 for `int`). + + Input shape: A numeric or string `Tensor` of shape + `[batch_size, d1, ..., dm]` + + Output shape: An `SparseTensor` with the same shape as inputs + """ + + def __init__(self, ignore_value=None): + super(ToSparse, self).__init__() + self.ignore_value = ignore_value + + def call(self, inputs): + if isinstance(inputs, tf.SparseTensor): + return inputs + + ignore_value = self.ignore_value + if ignore_value is None: + if inputs.dtype == tf.string: + ignore_value = "" + elif inputs.dtype.is_integer: + ignore_value = -1 + else: + ignore_value = 0.0 + ignore_value = tf.cast(ignore_value, inputs.dtype) + indices = tf.where(tf.not_equal(inputs, ignore_value)) + values = tf.gather_nd(inputs, indices) + dense_shape = tf.shape(inputs, out_type=tf.int64) + return tf.SparseTensor( + indices=indices, values=values, dense_shape=dense_shape + ) + + def compute_output_shape(self, input_shape): + return input_shape + + def get_config(self): + config = { + "ignore_value": self.ignore_value, + } + base_config = super(ToSparse, self).get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/elasticdl_preprocessing/tests/to_sparse_test.py b/elasticdl_preprocessing/tests/to_sparse_test.py new file mode 100644 index 000000000..8f8aa84ce --- /dev/null +++ b/elasticdl_preprocessing/tests/to_sparse_test.py @@ -0,0 +1,47 @@ +import unittest + +import numpy as np +import tensorflow as tf + +from elasticdl_preprocessing.layers.to_sparse import ToSparse +from elasticdl_preprocessing.tests.test_utils import sparse_tensor_equal + + +class ToSparseTest(unittest.TestCase): + def test_to_sparse(self): + layer = ToSparse() + inp = tf.constant([["A", ""], ["B", "C"]], tf.string) + output = layer.call(inp) + expected_out = tf.SparseTensor( + indices=np.array([[0, 0], [1, 0], [1, 1]]), + values=np.array(["A", "B", "C"]), + dense_shape=(2, 2), + ) + self.assertTrue(sparse_tensor_equal(output, expected_out)) + + layer = ToSparse() + inp = tf.constant([[12, -1], [45, 78]], tf.int64) + output = layer.call(inp) + expected_out = tf.SparseTensor( + indices=np.array([[0, 0], [1, 0], [1, 1]]), + values=np.array([12, 45, 78]), + dense_shape=(2, 2), + ) + self.assertTrue(sparse_tensor_equal(output, expected_out)) + + def test_model_with_to_sparse(self): + inputs = tf.keras.Input(shape=(1,), dtype=tf.int32) + sparse_inputs = ToSparse(ignore_value=-1)(inputs) + model = tf.keras.Model(inputs=inputs, outputs=sparse_inputs) + out = model.call(tf.constant([[1], [-1], [2], [3]])) + + expect_out = tf.SparseTensor( + indices=tf.constant([[0, 0], [2, 0], [3, 0]], dtype=tf.int64), + values=tf.constant([1, 2, 3], dtype=tf.int32), + dense_shape=(4, 1), + ) + self.assertTrue(sparse_tensor_equal(out, expect_out)) + + +if __name__ == "__main__": + unittest.main()