Description
We cannot export the transform graph from tf.Transform and the graph of a Keras model together into a single SavedModel.
Background
Let's focus on these three stages in the end-to-end machine learning lifecycle of TFX:
Transform stage: The tf.Transform component transforms the raw data into the data used to train a machine learning model. It also exports the transform logic in the SavedModel format. We name it saved_model_trans.
Training stage: TensorFlow uses the preprocessed data to train a model. After completing the training job, it exports the model graph together with the transform graph (from saved_model_trans) as one SavedModel. We name it saved_model_final.
Serving stage: TensorFlow Serving loads saved_model_final to serve inference requests. The schema of an inference request is exactly the same as the schema of the raw data in the transform stage.
Current state: how do we export the transform graph and the model graph together into one SavedModel?
√ Estimator (works well):
In the official tf.Transform tutorial, the model is exported by calling estimator.export_saved_model(exported_model_dir, serving_input_fn). Internally, export_saved_model first calls serving_input_fn to load the transform graph from the SavedModel, then calls the estimator's model_fn to build the model graph, combines the two graphs into one, and finally exports the result as a single SavedModel. Please check the code snippet.
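The call order inside export_saved_model can be mimicked with a toy exporter. This sketch is neither the tutorial's code snippet nor the tf.estimator API: export_with_serving_input_fn, the calls log and the hard-coded scaling constants are all hypothetical. It only mirrors the sequence described above: load the transform first, then build the model, then hand back one artifact that accepts raw requests.

```python
calls = []  # records the order in which the toy exporter invokes the hooks

def serving_input_fn():
    # Hypothetical stand-in for loading the transform graph from saved_model_trans.
    calls.append("serving_input_fn")

    def transform(raw):
        return {"x_scaled": (raw["x"] - 2.5) / 0.5}

    return transform

def model_fn():
    # Hypothetical stand-in for building the model graph.
    calls.append("model_fn")

    def predict(features):
        return 3.0 * features["x_scaled"] + 1.0

    return predict

def export_with_serving_input_fn(model_fn, serving_input_fn):
    transform = serving_input_fn()  # 1) transform graph first
    predict = model_fn()            # 2) then the model graph

    def exported(raw_request):      # 3) one combined artifact
        return predict(transform(raw_request))

    return exported

exported = export_with_serving_input_fn(model_fn, serving_input_fn)
print(calls)                   # ['serving_input_fn', 'model_fn']
print(exported({"x": 3.5}))    # 7.0
```

Because the exporter receives the transform through a hook, the caller never has to apply the transform by hand at serving time.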
× Keras (does not work):
In TF 2.0, we define a model with Keras and export it by calling tf.saved_model.save. This SavedModel contains only the model definition, including the feature columns and the NN structure. Unlike the Estimator's export_saved_model, tf.saved_model.save has no serving_input_fn parameter, so it lacks the flexibility to combine the transform graph and the model graph for inference. This breaks the integration between tf.Transform and TensorFlow 2.0.
We want to improve the Keras export path (tf.saved_model.save) to support this, so that TensorFlow 2.0 works well with tf.Transform.