diff --git a/examples/addition/addition.py b/examples/addition/addition.py index bc5f869e20..859c91fae6 100644 --- a/examples/addition/addition.py +++ b/examples/addition/addition.py @@ -1,14 +1,16 @@ -# TODO: Stop using v1 compatibility -import tensorflow.compat.v1 as tf +import tensorflow as tf +# check tensorflow version is 2.x +tf_major_version = tf.__version__.split('.')[0] +assert tf_major_version == '2' -tf.disable_eager_execution() -x = tf.placeholder(tf.int32, name = 'x') -y = tf.placeholder(tf.int32, name = 'y') -z = tf.add(x, y, name = 'z') +@tf.function +def add(x, y): + tf.add(x, y, name='z') -tf.variables_initializer(tf.global_variables(), name = 'init') +x = tf.TensorSpec((), dtype=tf.dtypes.int32, name='x') +y = tf.TensorSpec((), dtype=tf.dtypes.int32, name='y') -definition = tf.Session().graph_def +concrete_function = add.get_concrete_function(x, y) directory = 'examples/addition' -tf.train.write_graph(definition, directory, 'model.pb', as_text=False) +tf.io.write_graph(concrete_function.graph, directory, 'model.pb', as_text=False) diff --git a/examples/addition/model.pb b/examples/addition/model.pb index 19330df599..1a45a6aeae 100644 Binary files a/examples/addition/model.pb and b/examples/addition/model.pb differ