Skip to content

Commit 29a8060

Browse files
authored
Layer to convert Tensor to SparseTensor dropping ignore values (#1860)
* Layer to convert dense tensor to sparse tensor dropping ignore_value cells * Reformat code * add get_config method * fix the default value of ignore_value * Fix the docstring * Remove print code * Fix the docstring by comments * Set ignore_value=0.0 for other dtypes * Add unit test to create model with ToSparse layers
1 parent d757e70 commit 29a8060

File tree

2 files changed

+113
-0
lines changed

2 files changed

+113
-0
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import tensorflow as tf
2+
3+
4+
class ToSparse(tf.keras.layers.Layer):
5+
"""Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
6+
If the input is already a `SparseTensor`, just return it.
7+
8+
Example :
9+
```python
10+
layer = ToSparse()
11+
inp = tf.constant([["A", ""], ["B", "C"]], tf.string)
12+
out = layer(inp)
13+
```
14+
The expected output is
15+
```
16+
tf.SparseTensor(
17+
indices=np.array([[0, 0], [1, 0], [1, 1]]),
18+
values=np.array(["A", "B", "C"]),
19+
dense_shape=(2, 2),
20+
)
21+
```
22+
23+
Arguments:
24+
ignore_value: Entries in inputs equal to this value will be
25+
absent from the output `SparseTensor`. If `None`, default value of
26+
inputs dtype will be used ('' for `str`, -1 for `int`).
27+
28+
Input shape: A numeric or string `Tensor` of shape
29+
`[batch_size, d1, ..., dm]`
30+
31+
Output shape: An `SparseTensor` with the same shape as inputs
32+
"""
33+
34+
def __init__(self, ignore_value=None):
35+
super(ToSparse, self).__init__()
36+
self.ignore_value = ignore_value
37+
38+
def call(self, inputs):
39+
if isinstance(inputs, tf.SparseTensor):
40+
return inputs
41+
42+
ignore_value = self.ignore_value
43+
if ignore_value is None:
44+
if inputs.dtype == tf.string:
45+
ignore_value = ""
46+
elif inputs.dtype.is_integer:
47+
ignore_value = -1
48+
else:
49+
ignore_value = 0.0
50+
ignore_value = tf.cast(ignore_value, inputs.dtype)
51+
indices = tf.where(tf.not_equal(inputs, ignore_value))
52+
values = tf.gather_nd(inputs, indices)
53+
dense_shape = tf.shape(inputs, out_type=tf.int64)
54+
return tf.SparseTensor(
55+
indices=indices, values=values, dense_shape=dense_shape
56+
)
57+
58+
def compute_output_shape(self, input_shape):
59+
return input_shape
60+
61+
def get_config(self):
62+
config = {
63+
"ignore_value": self.ignore_value,
64+
}
65+
base_config = super(ToSparse, self).get_config()
66+
return dict(list(base_config.items()) + list(config.items()))
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import unittest
2+
3+
import numpy as np
4+
import tensorflow as tf
5+
6+
from elasticdl_preprocessing.layers.to_sparse import ToSparse
7+
from elasticdl_preprocessing.tests.test_utils import sparse_tensor_equal
8+
9+
10+
class ToSparseTest(unittest.TestCase):
11+
def test_to_sparse(self):
12+
layer = ToSparse()
13+
inp = tf.constant([["A", ""], ["B", "C"]], tf.string)
14+
output = layer.call(inp)
15+
expected_out = tf.SparseTensor(
16+
indices=np.array([[0, 0], [1, 0], [1, 1]]),
17+
values=np.array(["A", "B", "C"]),
18+
dense_shape=(2, 2),
19+
)
20+
self.assertTrue(sparse_tensor_equal(output, expected_out))
21+
22+
layer = ToSparse()
23+
inp = tf.constant([[12, -1], [45, 78]], tf.int64)
24+
output = layer.call(inp)
25+
expected_out = tf.SparseTensor(
26+
indices=np.array([[0, 0], [1, 0], [1, 1]]),
27+
values=np.array([12, 45, 78]),
28+
dense_shape=(2, 2),
29+
)
30+
self.assertTrue(sparse_tensor_equal(output, expected_out))
31+
32+
def test_model_with_to_sparse(self):
33+
inputs = tf.keras.Input(shape=(1,), dtype=tf.int32)
34+
sparse_inputs = ToSparse(ignore_value=-1)(inputs)
35+
model = tf.keras.Model(inputs=inputs, outputs=sparse_inputs)
36+
out = model.call(tf.constant([[1], [-1], [2], [3]]))
37+
38+
expect_out = tf.SparseTensor(
39+
indices=tf.constant([[0, 0], [2, 0], [3, 0]], dtype=tf.int64),
40+
values=tf.constant([1, 2, 3], dtype=tf.int32),
41+
dense_shape=(4, 1),
42+
)
43+
self.assertTrue(sparse_tensor_equal(out, expect_out))
44+
45+
46+
if __name__ == "__main__":
47+
unittest.main()

0 commit comments

Comments
 (0)