Skip to content

Commit d757e70

Browse files
authored
IndexLookup layer to map strings from a vocabulary to integer indices (#1864)
* Lookup layer to map strings from a vocabulary to integer indices * IndexLookup layer to map strings from a vocabulary to integer indices * Fix the docstring * Fix docstring * Add note for TF version * Fix by comments * Add an unit test to create model with IndexLookup layers
1 parent 55d8f94 commit d757e70

File tree

2 files changed

+176
-0
lines changed

2 files changed

+176
-0
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from __future__ import absolute_import, division, print_function
2+
3+
import collections
4+
5+
import tensorflow as tf
6+
from tensorflow.python.ops import lookup_ops
7+
8+
9+
class IndexLookup(tf.keras.layers.Layer):
10+
"""Maps strings to integer indices by looking up a vocabulary.
11+
12+
This layer transforms categorical inputs to zero-based integer by
13+
lookuping with a vocabulary list. TensorFlow 2.2 has developed
14+
`tf.keras.layers.preprocessing.IndexLookup` but not released it yet.
15+
So the layer is a simple temporary version. The codes in TensorFlow 2.2 is
16+
`tensorflow.python.keras.layers.preprocessing.index_lookup.IndexLookup`.
17+
18+
Note that the TensorFlow version with the layer must be greater than 2.0.0.
19+
20+
Example:
21+
```python
22+
layer = IndexLookup(vocabulary=['A', 'B', 'C'])
23+
inp = np.array([['A'], ['B'], ['C'], ['D'], ['E']])
24+
layer(inputs)
25+
```
26+
Then output will be `[[0], [1], [2], [3], [3]]`
27+
28+
Attributes:
29+
num_oov_tokens: The number of out-of-vocabulary tokens to use; defaults to
30+
1. If this value is more than 1,
31+
`hash(inputs) % num_oov_tokens + len(vocabulary)` converts OOV inputs
32+
to integer values.
33+
vocabulary: A list of vocabulary terms, or a path to a text file
34+
containing a vocabulary to load into this layer. The file should
35+
contain one token per line.
36+
37+
Input: A string `tf.Tensor`,`tf.SparseTensor` or
38+
`tf.RaggedTensor`.
39+
40+
Output: An int64 tensor with the same type as input.
41+
42+
"""
43+
44+
def __init__(self, vocabulary=None, num_oov_tokens=1, **kwargs):
45+
super(IndexLookup, self).__init__()
46+
self.num_oov_tokens = num_oov_tokens
47+
48+
if vocabulary is not None and isinstance(vocabulary, str):
49+
vocabulary = self._get_vocabulary_from_file(vocabulary)
50+
vocabulary_set = set(vocabulary)
51+
if len(vocabulary) != len(vocabulary_set):
52+
repeated_items = [
53+
item
54+
for item, count in collections.Counter(vocabulary).items()
55+
if count > 1
56+
]
57+
raise ValueError(
58+
"The passed vocabulary has at least one repeated "
59+
"term. Please uniquify your dataset before passing "
60+
"it to IndexLookup(). The repeated terms are %s"
61+
% repeated_items
62+
)
63+
self.vocabulary = vocabulary
64+
65+
def build(self, input_shape):
66+
self.table = lookup_ops.index_table_from_tensor(
67+
vocabulary_list=self.vocabulary,
68+
num_oov_buckets=self.num_oov_tokens,
69+
)
70+
71+
def call(self, inputs):
72+
if isinstance(inputs, tf.SparseTensor):
73+
lookup_id = self.table.lookup(inputs.values)
74+
output = tf.SparseTensor(
75+
indices=inputs.indices,
76+
values=lookup_id,
77+
dense_shape=inputs.dense_shape,
78+
)
79+
elif isinstance(inputs, tf.RaggedTensor):
80+
return tf.ragged.map_flat_values(self.table.lookup, inputs,)
81+
else:
82+
output = self.table.lookup(inputs)
83+
return tf.cast(output, tf.int64)
84+
85+
def _get_vocabulary_from_file(self, vocabulary_path):
86+
vocab = []
87+
with tf.io.gfile.GFile(vocabulary_path, "r") as reader:
88+
while True:
89+
# Get the next line, and break if it is None.
90+
text = reader.readline()
91+
if not text:
92+
break
93+
94+
# Convert the raw text into UTF8 and strip whitespace.
95+
if isinstance(text, str):
96+
token = text
97+
elif isinstance(text, bytes):
98+
token = text.decode("utf-8", "ignore")
99+
token = token.strip()
100+
vocab.append(token)
101+
return vocab
102+
103+
def vocab_size(self):
104+
return self._table.size().numpy()
105+
106+
def compute_output_shape(self, input_shape):
107+
return input_shape
108+
109+
def get_config(self):
110+
config = {
111+
"num_oov_tokens": self.num_oov_tokens,
112+
"vocabulary": None,
113+
}
114+
base_config = super(IndexLookup, self).get_config()
115+
return dict(list(base_config.items()) + list(config.items()))
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import os
2+
import tempfile
3+
import unittest
4+
5+
import numpy as np
6+
import tensorflow as tf
7+
8+
from elasticdl_preprocessing.layers.index_lookup import IndexLookup
9+
from elasticdl_preprocessing.tests.test_utils import (
10+
ragged_tensor_equal,
11+
sparse_tensor_equal,
12+
)
13+
14+
15+
class IndexLookupTest(unittest.TestCase):
16+
def test_lookup_with_list(self):
17+
lookup_layer = IndexLookup(vocabulary=["A", "B", "C"])
18+
self._check_lookup(lookup_layer)
19+
20+
def test_lookup_with_file(self):
21+
with tempfile.TemporaryDirectory() as temp_dir:
22+
vocab_file = os.path.join(temp_dir, "vocab_test.txt")
23+
with open(vocab_file, "w") as f:
24+
f.write("A\n")
25+
f.write("B\n")
26+
f.write("C\n")
27+
lookup_layer = IndexLookup(vocabulary=vocab_file)
28+
self._check_lookup(lookup_layer)
29+
30+
def test_model_with_lookup(self):
31+
inputs = tf.keras.Input(shape=(1,), dtype=tf.string)
32+
lookup_out = IndexLookup(vocabulary=["A", "B", "C"])(inputs)
33+
model = tf.keras.Model(inputs=inputs, outputs=lookup_out)
34+
out = model.call(tf.constant([["A"], ["C"], ["B"], ["D"], ["E"]]))
35+
self.assertTrue(
36+
np.array_equal(
37+
np.array([[0], [2], [1], [3], [3]], dtype=int), out.numpy()
38+
)
39+
)
40+
41+
def _check_lookup(self, lookup_layer):
42+
dense_input = tf.constant([["A"], ["B"], ["C"], ["D"], ["E"]])
43+
output = lookup_layer(dense_input)
44+
expected_out = np.array([[0], [1], [2], [3], [3]])
45+
self.assertTrue(np.array_equal(output.numpy(), expected_out))
46+
47+
ragged_input = tf.ragged.constant([["A", "B", "C"], ["D", "E"]])
48+
ragged_output = lookup_layer(ragged_input)
49+
expected_ragged_out = tf.ragged.constant(
50+
[[0, 1, 2], [3, 3]], dtype=tf.int64
51+
)
52+
self.assertTrue(
53+
ragged_tensor_equal(ragged_output, expected_ragged_out)
54+
)
55+
56+
sparse_input = ragged_input.to_sparse()
57+
sparse_output = lookup_layer(sparse_input)
58+
expected_sparse_out = expected_ragged_out.to_sparse()
59+
self.assertTrue(
60+
sparse_tensor_equal(sparse_output, expected_sparse_out)
61+
)

0 commit comments

Comments
 (0)