-
Notifications
You must be signed in to change notification settings - Fork 116
Layer to convert Tensor to SparseTensor dropping ignore values #1860
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
brightcoder01
merged 10 commits into
sql-machine-learning:develop
from
workingloong:to_sparse
Mar 26, 2020
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
035bc94
Layer to convert dense tensor to sparse tensor dropping ignore_value …
workingloong 01aa93a
Reformat code
workingloong 17e837d
add get_config method
workingloong 83b40f3
fix the default value of ignore_value
workingloong b9446c7
Fix the docstring
workingloong d071af5
Remove print code
workingloong 278894d
Fix the docstring by comments
workingloong 4ef65e7
Set ignore_value=0.0 for other dtypes
workingloong 3a1da78
Merge branch 'develop' into to_sparse
workingloong 6c2b8df
Add unit test to create model with ToSparse layers
workingloong File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import tensorflow as tf | ||
|
||
|
||
class ToSparse(tf.keras.layers.Layer): | ||
"""Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells. | ||
If the input is already a `SparseTensor`, just return it. | ||
|
||
Example : | ||
```python | ||
layer = ToSparse() | ||
inp = tf.constant([["A", ""], ["B", "C"]], tf.string) | ||
out = layer(inp) | ||
``` | ||
The expected output is | ||
``` | ||
tf.SparseTensor( | ||
indices=np.array([[0, 0], [1, 0], [1, 1]]), | ||
values=np.array(["A", "B", "C"]), | ||
dense_shape=(2, 2), | ||
) | ||
``` | ||
|
||
Arguments: | ||
ignore_value: Entries in inputs equal to this value will be | ||
absent from the output `SparseTensor`. If `None`, default value of | ||
inputs dtype will be used ('' for `str`, -1 for `int`). | ||
|
||
Input shape: A numeric or string `Tensor` of shape | ||
`[batch_size, d1, ..., dm]` | ||
|
||
Output shape: An `SparseTensor` with the same shape as inputs | ||
""" | ||
|
||
def __init__(self, ignore_value=None): | ||
super(ToSparse, self).__init__() | ||
self.ignore_value = ignore_value | ||
|
||
def call(self, inputs): | ||
if isinstance(inputs, tf.SparseTensor): | ||
return inputs | ||
|
||
ignore_value = self.ignore_value | ||
if ignore_value is None: | ||
if inputs.dtype == tf.string: | ||
ignore_value = "" | ||
elif inputs.dtype.is_integer: | ||
ignore_value = -1 | ||
else: | ||
ignore_value = 0.0 | ||
ignore_value = tf.cast(ignore_value, inputs.dtype) | ||
indices = tf.where(tf.not_equal(inputs, ignore_value)) | ||
values = tf.gather_nd(inputs, indices) | ||
dense_shape = tf.shape(inputs, out_type=tf.int64) | ||
return tf.SparseTensor( | ||
indices=indices, values=values, dense_shape=dense_shape | ||
) | ||
|
||
def compute_output_shape(self, input_shape): | ||
return input_shape | ||
|
||
def get_config(self): | ||
config = { | ||
"ignore_value": self.ignore_value, | ||
} | ||
base_config = super(ToSparse, self).get_config() | ||
return dict(list(base_config.items()) + list(config.items())) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from elasticdl_preprocessing.layers.to_sparse import ToSparse | ||
from elasticdl_preprocessing.tests.test_utils import sparse_tensor_equal | ||
|
||
|
||
class ToSparseTest(unittest.TestCase): | ||
def test_to_sparse(self): | ||
layer = ToSparse() | ||
inp = tf.constant([["A", ""], ["B", "C"]], tf.string) | ||
output = layer.call(inp) | ||
expected_out = tf.SparseTensor( | ||
indices=np.array([[0, 0], [1, 0], [1, 1]]), | ||
values=np.array(["A", "B", "C"]), | ||
dense_shape=(2, 2), | ||
) | ||
self.assertTrue(sparse_tensor_equal(output, expected_out)) | ||
|
||
layer = ToSparse() | ||
inp = tf.constant([[12, -1], [45, 78]], tf.int64) | ||
output = layer.call(inp) | ||
expected_out = tf.SparseTensor( | ||
indices=np.array([[0, 0], [1, 0], [1, 1]]), | ||
values=np.array([12, 45, 78]), | ||
dense_shape=(2, 2), | ||
) | ||
self.assertTrue(sparse_tensor_equal(output, expected_out)) | ||
|
||
def test_model_with_to_sparse(self): | ||
inputs = tf.keras.Input(shape=(1,), dtype=tf.int32) | ||
sparse_inputs = ToSparse(ignore_value=-1)(inputs) | ||
model = tf.keras.Model(inputs=inputs, outputs=sparse_inputs) | ||
out = model.call(tf.constant([[1], [-1], [2], [3]])) | ||
|
||
expect_out = tf.SparseTensor( | ||
indices=tf.constant([[0, 0], [2, 0], [3, 0]], dtype=tf.int64), | ||
values=tf.constant([1, 2, 3], dtype=tf.int32), | ||
dense_shape=(4, 1), | ||
) | ||
self.assertTrue(sparse_tensor_equal(out, expect_out)) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we expose this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-1 for
int
seems dangerous as this is application specificThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-1 for
int
is the default value and the layer will ignore it during transformation.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use -1 as the default ignore_value for
int
type is also the implementation inside feature column. Please check the code snippet.