diff --git a/RELEASE.md b/RELEASE.md index 6ef49ea9d4..fbafb8db13 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -9,6 +9,7 @@ most likely you discovered a bug and should not use an f-string in the first place. If it is truly your intention to print the placeholder (not its resolved value) for debugging purposes, use `repr()` or `!r` instead. +* Drop supports for the Estimator API. ### For Pipeline Authors @@ -224,7 +225,7 @@ ## Bug Fixes and Other Changes -* Support to task type "workerpool1" of CLUSTER_SPEC in Vertex AI training's +* Support to task type "workerpool1" of CLUSTER_SPEC in Vertex AI training's service according to the changes of task type in Tuner component. * Propagates unexpected import failures in the public v1 module. @@ -2887,4 +2888,4 @@ the 1.1.x release for TFX library. ### For component authors -* N/A \ No newline at end of file +* N/A diff --git a/docs/guide/evaluator.md b/docs/guide/evaluator.md index a1a72ab15e..639c4ff1e4 100644 --- a/docs/guide/evaluator.md +++ b/docs/guide/evaluator.md @@ -66,9 +66,7 @@ import tensorflow_model_analysis as tfma eval_config = tfma.EvalConfig( model_specs=[ # This assumes a serving model with signature 'serving_default'. If - # using estimator based EvalSavedModel, add signature_name='eval' and - # remove the label_key. Note, if using a TFLite model, then you must set - # model_type='tf_lite'. + # using a TFLite model, then you must set model_type='tf_lite'. tfma.ModelSpec(label_key='') ], metrics_specs=[ diff --git a/docs/guide/fairness_indicators.md b/docs/guide/fairness_indicators.md index 7f891d1408..771cdf0d05 100644 --- a/docs/guide/fairness_indicators.md +++ b/docs/guide/fairness_indicators.md @@ -43,16 +43,6 @@ an evaluation set that does, or considering proxy features within your feature set that may highlight outcome disparities. For additional guidance, see [here](https://tensorflow.org/responsible_ai/fairness_indicators/guide/guidance). -### Model - -You can use the Tensorflow Estimator class to build your model. Support for -Keras models is coming soon to TFMA. If you would like to run TFMA on a Keras -model, please see the “Model-Agnostic TFMA” section below. - -After your Estimator is trained, you will need to export a saved model for -evaluation purposes. To learn more, see the -[TFMA guide](https://www.tensorflow.org/tfx/model_analysis/get_started). - ### Configuring Slices Next, define the slices you would like to evaluate on: diff --git a/docs/guide/index.md b/docs/guide/index.md index cf70a88ecf..95eb0b6b56 100644 --- a/docs/guide/index.md +++ b/docs/guide/index.md @@ -438,23 +438,6 @@ using the exact same code during both training and inference. Using the modeling code, including the SavedModel from the Transform component, you can consume your training and evaluation data and train your model. -When working with Estimator based models, the last section of your modeling -code should save your model as both a SavedModel and an EvalSavedModel. Saving -as an EvalSavedModel ensures the metrics used at training time are also -available during evaluation (note that this is not required for keras based -models). Saving an EvalSavedModel requires that you import the -[TensorFlow Model Analysis (TFMA)](tfma.md) library in your Trainer component. - -```python -import tensorflow_model_analysis as tfma -... 
- -tfma.export.export_eval_savedmodel( - estimator=estimator, - export_dir_base=eval_model_dir, - eval_input_receiver_fn=receiver_fn) -``` - An optional [Tuner](tuner.md) component can be added before Trainer to tune the hyperparameters (e.g., number of layers) for the model. With the given model and hyperparameters' search space, tuning algorithm will find the best diff --git a/docs/guide/keras.md b/docs/guide/keras.md index f0870b8200..9f85393b89 100644 --- a/docs/guide/keras.md +++ b/docs/guide/keras.md @@ -38,54 +38,10 @@ they become available in TF 2.x, you can follow the ## Estimator -The Estimator API has been retained in TensorFlow 2.x, but is not the focus of -new features and development. Code written in TensorFlow 1.x or 2.x using -Estimators will continue to work as expected in TFX. +The Estimator API has been fully dropped since TensorFlow 2.16, we decided to +discontinue the support for it. -Here is an end-to-end TFX example using pure Estimator: -[Taxi example (Estimator)](https://github.com/tensorflow/tfx/blob/r0.21/tfx/examples/chicago_taxi_pipeline/taxi_utils.py) - -## Keras with `model_to_estimator` - -Keras models can be wrapped with the `tf.keras.estimator.model_to_estimator` -function, which allows them to work as if they were Estimators. To use this: - -1. Build a Keras model. -2. Pass the compiled model into `model_to_estimator`. -3. Use the result of `model_to_estimator` in Trainer, the way you would - typically use an Estimator. - -```py -# Build a Keras model. -def _keras_model_builder(): - """Creates a Keras model.""" - ... - - model = tf.keras.Model(inputs=inputs, outputs=output) - model.compile() - - return model - - -# Write a typical trainer function -def trainer_fn(trainer_fn_args, schema): - """Build the estimator, using model_to_estimator.""" - ... - - # Model to estimator - estimator = tf.keras.estimator.model_to_estimator( - keras_model=_keras_model_builder(), config=run_config) - - return { - 'estimator': estimator, - ... - } -``` - -Other than the user module file of Trainer, the rest of the pipeline remains -unchanged. - -## Native Keras (i.e. Keras without `model_to_estimator`) +## Native Keras (i.e. Keras without Estimator) !!! Note Full support for all features in Keras is in progress, in most cases, @@ -101,7 +57,7 @@ Here are several examples with native Keras: 'Hello world' end-to-end example. * [MNIST](https://github.com/tensorflow/tfx/blob/master/tfx/examples/mnist/mnist_pipeline_native_keras.py) ([module file](https://github.com/tensorflow/tfx/blob/master/tfx/examples/mnist/mnist_utils_native_keras.py)): - Image and TFLite end-to-end example. + Image end-to-end example. * [Taxi](https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_native_keras.py) ([module file](https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_utils_native_keras.py)): end-to-end example with advanced Transform usage. @@ -132,11 +88,6 @@ will be discussed in the following Trainer and Evaluator sections. #### Trainer -To configure native Keras, the `GenericExecutor` needs to be set for Trainer -component to replace the default Estimator based executor. For details, please -check -[here](trainer.md#configuring-the-trainer-component). - ##### Keras Module file with Transform The training module file must contains a `run_fn` which will be called by the @@ -296,9 +247,4 @@ validate the current model compared with previous models. 
With this change, the Pusher component now consumes a blessing result from Evaluator instead of ModelValidator. -The new Evaluator supports Keras models as well as Estimator models. The -`_eval_input_receiver_fn` and eval saved model which were required previously -will no longer be needed with Keras, since Evaluator is now based on the same -`SavedModel` that is used for serving. - [See Evaluator for more information](evaluator.md). diff --git a/docs/guide/modelval.md b/docs/guide/modelval.md index b2bafc63a5..9dc68d3a28 100644 --- a/docs/guide/modelval.md +++ b/docs/guide/modelval.md @@ -33,9 +33,7 @@ import tensorflow_model_analysis as tfma eval_config = tfma.EvalConfig( model_specs=[ - # This assumes a serving model with signature 'serving_default'. If - # using estimator based EvalSavedModel, add signature_name: 'eval' and - # remove the label_key. + # This assumes a serving model with signature 'serving_default'. tfma.ModelSpec(label_key='') ], metrics_specs=[ diff --git a/docs/guide/train.md b/docs/guide/train.md index 395db2814f..092c2876fe 100644 --- a/docs/guide/train.md +++ b/docs/guide/train.md @@ -22,59 +22,3 @@ a [Transform](transform.md) component, and the layers of the Transform model sho be included with your model so that when you export your SavedModel and EvalSavedModel they will include the transformations that were created by the [Transform](transform.md) component. - -A typical TensorFlow model design for TFX looks like this: - -```python -def _build_estimator(tf_transform_dir, - config, - hidden_units=None, - warm_start_from=None): - """Build an estimator for predicting the tipping behavior of taxi riders. - - Args: - tf_transform_dir: directory in which the tf-transform model was written - during the preprocessing step. - config: tf.contrib.learn.RunConfig defining the runtime environment for the - estimator (including model_dir). - hidden_units: [int], the layer sizes of the DNN (input layer first) - warm_start_from: Optional directory to warm start from. - - Returns: - Resulting DNNLinearCombinedClassifier. 
- """ - metadata_dir = os.path.join(tf_transform_dir, - transform_fn_io.TRANSFORMED_METADATA_DIR) - transformed_metadata = metadata_io.read_metadata(metadata_dir) - transformed_feature_spec = transformed_metadata.schema.as_feature_spec() - - transformed_feature_spec.pop(_transformed_name(_LABEL_KEY)) - - real_valued_columns = [ - tf.feature_column.numeric_column(key, shape=()) - for key in _transformed_names(_DENSE_FLOAT_FEATURE_KEYS) - ] - categorical_columns = [ - tf.feature_column.categorical_column_with_identity( - key, num_buckets=_VOCAB_SIZE + _OOV_SIZE, default_value=0) - for key in _transformed_names(_VOCAB_FEATURE_KEYS) - ] - categorical_columns += [ - tf.feature_column.categorical_column_with_identity( - key, num_buckets=_FEATURE_BUCKET_COUNT, default_value=0) - for key in _transformed_names(_BUCKET_FEATURE_KEYS) - ] - categorical_columns += [ - tf.feature_column.categorical_column_with_identity( - key, num_buckets=num_buckets, default_value=0) - for key, num_buckets in zip( - _transformed_names(_CATEGORICAL_FEATURE_KEYS), # - _MAX_CATEGORICAL_FEATURE_VALUES) - ] - return tf.estimator.DNNLinearCombinedClassifier( - config=config, - linear_feature_columns=categorical_columns, - dnn_feature_columns=real_valued_columns, - dnn_hidden_units=hidden_units or [100, 70, 50, 25], - warm_start_from=warm_start_from) -``` diff --git a/docs/guide/trainer.md b/docs/guide/trainer.md index ba80f2e4ca..596dcbeec2 100644 --- a/docs/guide/trainer.md +++ b/docs/guide/trainer.md @@ -29,14 +29,14 @@ Trainer emits: At least one model for inference/serving (typically in SavedModel We provide support for alternate model formats such as [TFLite](https://www.tensorflow.org/lite) through the [Model Rewriting Library](https://github.com/tensorflow/tfx/blob/master/tfx/components/trainer/rewriting/README.md). -See the link to the Model Rewriting Library for examples of how to convert both Estimator and Keras +See the link to the Model Rewriting Library for examples of how to convert Keras models. ## Generic Trainer Generic trainer enables developers to use any TensorFlow model API with the -Trainer component. In addition to TensorFlow Estimators, developers can use -Keras models or custom training loops. For details, please see the +Trainer component. Developers can use Keras models or custom training loops. +For details, please see the [RFC for generic trainer](https://github.com/tensorflow/community/blob/master/rfcs/20200117-tfx-generic-trainer.md). ### Configuring the Trainer Component @@ -57,10 +57,8 @@ trainer = Trainer( ``` Trainer invokes a training module, which is specified in the `module_file` -parameter. Instead of `trainer_fn`, a `run_fn` is required in the module file if -the `GenericExecutor` is specified in the `custom_executor_spec`. The -`trainer_fn` was responsible for creating the model. In addition to that, -`run_fn` also needs to handle the training part and output the trained model to +parameter. A `run_fn` is required in the module file, +and it needs to handle the training part and output the trained model to a the desired location given by [FnArgs](https://github.com/tensorflow/tfx/blob/master/tfx/components/trainer/fn_args_utils.py): diff --git a/docs/tutorials/tfx/cloud-ai-platform-pipelines.md b/docs/tutorials/tfx/cloud-ai-platform-pipelines.md index 7edd78f6ab..40977a0d05 100644 --- a/docs/tutorials/tfx/cloud-ai-platform-pipelines.md +++ b/docs/tutorials/tfx/cloud-ai-platform-pipelines.md @@ -333,9 +333,6 @@ Here is brief description of the Python files. 
- `features.py` `features_test.py` — defines features for the model - `preprocessing.py` / `preprocessing_test.py` — defines preprocessing jobs using `tf::Transform` - - `estimator` - This directory contains an Estimator based model. - - `constants.py` — defines constants of the model - - `model.py` / `model_test.py` — defines DNN model using TF estimator - `keras` - This directory contains a Keras based model. - `constants.py` — defines constants of the model - `model.py` / `model_test.py` — defines DNN model using Keras diff --git a/docs/tutorials/tfx/tfx_for_mobile.md b/docs/tutorials/tfx/tfx_for_mobile.md index ec12a0575c..8de3b697a1 100644 --- a/docs/tutorials/tfx/tfx_for_mobile.md +++ b/docs/tutorials/tfx/tfx_for_mobile.md @@ -30,10 +30,6 @@ The TFX Trainer expects a user-defined `run_fn` to be specified in a module file. This `run_fn` defines the model to be trained, trains it for the specified number of iterations, and exports the trained model. -In the rest of this section, we provide code snippets which show the changes -required to invoke the TFLite rewriter and export a TFLite model. All of this -code is located in the `run_fn` of the [MNIST TFLite module](https://github.com/tensorflow/tfx/blob/master/tfx/examples/mnist/mnist_utils_native_keras_lite.py). - As shown in the code below, we must first create a signature that takes a `Tensor` for every feature as input. Note that this is a departure from most existing models in TFX, which take diff --git a/tfx/components/testdata/module_file/trainer_module.py b/tfx/components/testdata/module_file/trainer_module.py index bf46404c88..4fdc7550e6 100644 --- a/tfx/components/testdata/module_file/trainer_module.py +++ b/tfx/components/testdata/module_file/trainer_module.py @@ -13,33 +13,29 @@ # limitations under the License. """Python source file include taxi pipeline functions and necesasry utils. -For a TFX pipeline to successfully run, a preprocessing_fn and a -_build_estimator function needs to be provided. This file contains both. - -This file is equivalent to examples/chicago_taxi/trainer/model.py and -examples/chicago_taxi/preprocess.py. +The utilities in this file are used to build a model with native Keras. +This module file will be used in Transform and generic Trainer. """ -import absl +from typing import Optional + +from absl import logging import tensorflow as tf -from tensorflow import estimator as tf_estimator -import tensorflow_model_analysis as tfma import tensorflow_transform as tft -from tensorflow_transform.tf_metadata import schema_utils -from tfx.components.trainer import executor -from tfx.utils import io_utils -from tfx.utils import path_utils -from tfx_bsl.public.tfxio import TensorFlowDatasetOptions -from tensorflow_metadata.proto.v0 import schema_pb2 - +from tfx.components.trainer import fn_args_utils +from tfx_bsl.tfxio import dataset_options # Categorical features are assumed to each have a maximum value in the dataset. 
-_MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 12] +_MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 13] _CATEGORICAL_FEATURE_KEYS = [ - 'trip_start_hour', 'trip_start_day', 'trip_start_month', - 'pickup_census_tract', 'dropoff_census_tract', 'pickup_community_area', - 'dropoff_community_area' + 'trip_start_hour', + 'trip_start_day', + 'trip_start_month', + 'pickup_census_tract', + 'dropoff_census_tract', + 'pickup_community_area', + 'dropoff_community_area', ] _DENSE_FLOAT_FEATURE_KEYS = ['trip_miles', 'fare', 'trip_seconds'] @@ -48,8 +44,10 @@ _FEATURE_BUCKET_COUNT = 10 _BUCKET_FEATURE_KEYS = [ - 'pickup_latitude', 'pickup_longitude', 'dropoff_latitude', - 'dropoff_longitude' + 'pickup_latitude', + 'pickup_longitude', + 'dropoff_latitude', + 'dropoff_longitude', ] # Number of vocabulary terms used for encoding VOCAB_FEATURES by tf.transform @@ -76,276 +74,295 @@ def _transformed_names(keys): return [_transformed_name(key) for key in keys] -# Tf.Transform considers these features as "raw" -def _get_raw_feature_spec(schema): - return schema_utils.schema_as_feature_spec(schema).feature_spec - +def _fill_in_missing(x): + """Replace missing values in a SparseTensor. -def _gzip_reader_fn(filenames): - """Small utility returning a record reader that can read gzip'ed files.""" - return tf.data.TFRecordDataset(filenames, compression_type='GZIP') - - -def _build_estimator(config, hidden_units=None, warm_start_from=None): - """Build an estimator for predicting the tipping behavior of taxi riders. + Fills in missing values of `x` with '' or 0, and converts to a dense tensor. Args: - config: tf.estimator.RunConfig defining the runtime environment for the - estimator (including model_dir). - hidden_units: [int], the layer sizes of the DNN (input layer first) - warm_start_from: Optional directory to warm start from. + x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1 + in the second dimension. Returns: - A dict of the following: - - estimator: The estimator that will be used for training and eval. - - train_spec: Spec for training. - - eval_spec: Spec for eval. - - eval_input_receiver_fn: Input function for eval. + A rank 1 tensor where missing values of `x` have been filled in. """ - real_valued_columns = [ - tf.feature_column.numeric_column(key, shape=()) - for key in _transformed_names(_DENSE_FLOAT_FEATURE_KEYS) - ] - categorical_columns = [ - tf.feature_column.categorical_column_with_identity( - key, num_buckets=_VOCAB_SIZE + _OOV_SIZE, default_value=0) - for key in _transformed_names(_VOCAB_FEATURE_KEYS) - ] - categorical_columns += [ - tf.feature_column.categorical_column_with_identity( - key, num_buckets=_FEATURE_BUCKET_COUNT, default_value=0) - for key in _transformed_names(_BUCKET_FEATURE_KEYS) - ] - categorical_columns += [ - tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension - key, - num_buckets=num_buckets, - default_value=0) for key, num_buckets in zip( - _transformed_names(_CATEGORICAL_FEATURE_KEYS), - _MAX_CATEGORICAL_FEATURE_VALUES) - ] - return tf_estimator.DNNLinearCombinedClassifier( - config=config, - linear_feature_columns=categorical_columns, - dnn_feature_columns=real_valued_columns, - dnn_hidden_units=hidden_units or [100, 70, 50, 25], - warm_start_from=warm_start_from) - - -def _example_serving_receiver_fn(tf_transform_output, schema): - """Build the serving in inputs. 
+ if not isinstance(x, tf.sparse.SparseTensor): + return x + + default_value = '' if x.dtype == tf.string else 0 + dense_tensor = tf.sparse.to_dense( + tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]), + default_value, + ) + return dense_tensor + + +def _get_tf_examples_serving_signature(model, tf_transform_output): + """Returns a serving signature that accepts `tensorflow.Example`.""" + model.tft_layer_inference = tf_transform_output.transform_features_layer() + + @tf.function( + input_signature=[ + tf.TensorSpec(shape=[None], dtype=tf.string, name='examples') + ] + ) + def serve_tf_examples_fn(serialized_tf_example): + raw_feature_spec = tf_transform_output.raw_feature_spec() + raw_feature_spec.pop(_LABEL_KEY) + raw_features = tf.io.parse_example(serialized_tf_example, raw_feature_spec) + transformed_features = model.tft_layer_inference(raw_features) + logging.info('serve_transformed_features = %s', transformed_features) + + outputs = model(transformed_features) + return {'outputs': outputs} + + return serve_tf_examples_fn + + +def _get_transform_features_signature(model, tf_transform_output): + """Returns a serving signature that accepts `tensorflow.Example`.""" + model.tft_layer_eval = tf_transform_output.transform_features_layer() + + @tf.function( + input_signature=[ + tf.TensorSpec(shape=[None], dtype=tf.string, name='examples') + ] + ) + def transform_features_fn(serialized_tf_example): + raw_feature_spec = tf_transform_output.raw_feature_spec() + raw_features = tf.io.parse_example(serialized_tf_example, raw_feature_spec) + transformed_features = model.tft_layer_eval(raw_features) + logging.info('eval_transformed_features = %s', transformed_features) + return transformed_features + + return transform_features_fn + + +def _input_fn( + file_pattern: list[str], + data_accessor: fn_args_utils.DataAccessor, + tf_transform_output: tft.TFTransformOutput, + batch_size: int = 200, +) -> tf.data.Dataset: + """Generates features and label for tuning/training. Args: + file_pattern: List of paths or patterns of input tfrecord files. + data_accessor: fn_args_utils.DataAccessor for converting input to + RecordBatch. tf_transform_output: A TFTransformOutput. - schema: the schema of the input data. + batch_size: representing the number of consecutive elements of returned + dataset to combine in a single batch Returns: - Tensorflow graph which parses examples, applying tf-transform to them. + A dataset that contains (features, indices) tuple where features is a + dictionary of Tensors, and indices is a single Tensor of label indices. """ - raw_feature_spec = _get_raw_feature_spec(schema) - raw_feature_spec.pop(_LABEL_KEY) - - raw_input_fn = tf_estimator.export.build_parsing_serving_input_receiver_fn( - raw_feature_spec, default_batch_size=None) - serving_input_receiver = raw_input_fn() - - transformed_features = tf_transform_output.transform_raw_features( - serving_input_receiver.features) - - return tf_estimator.export.ServingInputReceiver( - transformed_features, serving_input_receiver.receiver_tensors) + return data_accessor.tf_dataset_factory( + file_pattern, + dataset_options.TensorFlowDatasetOptions( + batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY) + ), + tf_transform_output.transformed_metadata.schema, + ).repeat() -def _eval_input_receiver_fn(tf_transform_output, schema): - """Build everything needed for the tf-model-analysis to run the model. 
+def _build_keras_model( + hidden_units: Optional[list[int]] = None, +) -> tf.keras.Model: + """Creates a DNN Keras model for classifying taxi data. Args: - tf_transform_output: A TFTransformOutput. - schema: the schema of the input data. + hidden_units: [int], the layer sizes of the DNN (input layer first). Returns: - EvalInputReceiver function, which contains: - - Tensorflow graph which parses raw untransformed features, applies the - tf-transform preprocessing operators. - - Set of raw, untransformed features. - - Label against which predictions will be compared. + A Wide and Deep keras Model. """ - # Notice that the inputs are raw features, not transformed features here. - raw_feature_spec = _get_raw_feature_spec(schema) - - serialized_tf_example = tf.compat.v1.placeholder( - dtype=tf.string, shape=[None], name='input_example_tensor') - - # Add a parse_example operator to the tensorflow graph, which will parse - # raw, untransformed, tf examples. - features = tf.io.parse_example( - serialized=serialized_tf_example, features=raw_feature_spec) - - # Now that we have our raw examples, process them through the tf-transform - # function computed during the preprocessing step. - transformed_features = tf_transform_output.transform_raw_features( - features) + # Following values are hard coded for simplicity in this example, + # However prefarably they should be passsed in as hparams. - # The key name MUST be 'examples'. - receiver_tensors = {'examples': serialized_tf_example} - - # NOTE: Model is driven by transformed features (since training works on the - # materialized output of TFT, but slicing will happen on raw features. - features.update(transformed_features) - - return tfma.export.EvalInputReceiver( - features=features, - receiver_tensors=receiver_tensors, - labels=transformed_features[_transformed_name(_LABEL_KEY)]) + # Keras needs the feature definitions at compile time. + deep_input = { + colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype=tf.float32) + for colname in _transformed_names(_DENSE_FLOAT_FEATURE_KEYS) + } + wide_vocab_input = { + colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32') + for colname in _transformed_names(_VOCAB_FEATURE_KEYS) + } + wide_bucket_input = { + colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32') + for colname in _transformed_names(_BUCKET_FEATURE_KEYS) + } + wide_categorical_input = { + colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32') + for colname in _transformed_names(_CATEGORICAL_FEATURE_KEYS) + } + input_layers = { + **deep_input, + **wide_vocab_input, + **wide_bucket_input, + **wide_categorical_input, + } - -def _input_fn( - filenames, data_accessor, tf_transform_output, batch_size=200): - """Generates features and labels for training or evaluation. + # TODO(b/161952382): Replace with Keras premade models and + # Keras preprocessing layers. 
+ deep = tf.keras.layers.concatenate( + [tf.keras.layers.Normalization()(layer) for layer in deep_input.values()] + ) + for numnodes in (hidden_units or [100, 70, 50, 25]): + deep = tf.keras.layers.Dense(numnodes)(deep) + + wide_layers = [] + for key in _transformed_names(_VOCAB_FEATURE_KEYS): + wide_layers.append( + tf.keras.layers.CategoryEncoding(num_tokens=_VOCAB_SIZE + _OOV_SIZE)( + input_layers[key] + ) + ) + for key in _transformed_names(_BUCKET_FEATURE_KEYS): + wide_layers.append( + tf.keras.layers.CategoryEncoding(num_tokens=_FEATURE_BUCKET_COUNT)( + input_layers[key] + ) + ) + for key, num_tokens in zip( + _transformed_names(_CATEGORICAL_FEATURE_KEYS), + _MAX_CATEGORICAL_FEATURE_VALUES, + ): + wide_layers.append( + tf.keras.layers.CategoryEncoding(num_tokens=num_tokens)( + input_layers[key] + ) + ) + wide = tf.keras.layers.concatenate(wide_layers) + + output = tf.keras.layers.Dense(1, activation='sigmoid')( + tf.keras.layers.concatenate([deep, wide]) + ) + output = tf.squeeze(output, -1) + + model = tf.keras.Model(input_layers, output) + model.compile( + loss='binary_crossentropy', + optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), + metrics=[tf.keras.metrics.BinaryAccuracy()], + ) + model.summary(print_fn=logging.info) + return model + + +def stats_options_updater_fn(unused_stats_type, stats_options): + """Callback function for setting pre and post-transform stats options. Args: - filenames: [str] list of CSV files to read data from. - data_accessor: fn_args_utils.DataAccessor. - tf_transform_output: A TFTransformOutput. - batch_size: int First dimension size of the Tensors returned by input_fn + unused_stats_type: a stats_options_util.StatsType object. + stats_options: a tfdv.StatsOptions object. Returns: - A (features, indices) tuple where features is a dictionary of - Tensors, and indices is a single Tensor of label indices. + An updated tfdv.StatsOptions object. """ - dataset = data_accessor.tf_dataset_factory( - filenames, - TensorFlowDatasetOptions( - batch_size=batch_size, - label_key=_transformed_name(_LABEL_KEY)), - tf_transform_output.transformed_metadata.schema) + return stats_options - return tf.compat.v1.data.make_one_shot_iterator( - dataset).get_next() - -# TFX will call this function -def trainer_fn(trainer_fn_args, schema): - """Build the estimator using the high level API. +# TFX Transform will call this function. +def preprocessing_fn(inputs): + """tf.transform's callback function for preprocessing inputs. Args: - trainer_fn_args: Holds args used to train the model as name/value pairs. - schema: Holds the schema of the training examples. + inputs: map from feature keys to raw not-yet-transformed features. Returns: - A dict of the following: - - estimator: The estimator that will be used for training and eval. - - train_spec: Spec for training. - - eval_spec: Spec for eval. - - eval_input_receiver_fn: Input function for eval. + Map from string feature key to transformed feature operations. 
""" - if trainer_fn_args.hyperparameters: - hp = trainer_fn_args.hyperparameters - first_dnn_layer_size = hp.get('first_dnn_layer_size') - num_dnn_layers = hp.get('num_dnn_layers') - dnn_decay_factor = hp.get('dnn_decay_factor') - else: - # Number of nodes in the first layer of the DNN - first_dnn_layer_size = 100 - num_dnn_layers = 4 - dnn_decay_factor = 0.7 - - train_batch_size = 40 - eval_batch_size = 40 - - tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output) - - train_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda - trainer_fn_args.train_files, - trainer_fn_args.data_accessor, - tf_transform_output, - batch_size=train_batch_size) - - eval_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda - trainer_fn_args.eval_files, - trainer_fn_args.data_accessor, - tf_transform_output, - batch_size=eval_batch_size) - - train_spec = tf_estimator.TrainSpec( # pylint: disable=g-long-lambda - train_input_fn, - max_steps=trainer_fn_args.train_steps) - - serving_receiver_fn = lambda: _example_serving_receiver_fn( # pylint: disable=g-long-lambda - tf_transform_output, schema) - - exporter = tf_estimator.FinalExporter('chicago-taxi', serving_receiver_fn) - eval_spec = tf_estimator.EvalSpec( - eval_input_fn, - steps=trainer_fn_args.eval_steps, - exporters=[exporter], - name='chicago-taxi-eval') - - run_config = tf_estimator.RunConfig( - save_checkpoints_steps=999, - # keep_checkpoint_max must be more than the number of worker replicas - # nodes if training distributed, in order to avoid race condition. - keep_checkpoint_max=5) - - export_dir = path_utils.serving_model_dir(trainer_fn_args.model_run_dir) - run_config = run_config.replace(model_dir=export_dir) - warm_start_from = trainer_fn_args.base_model - - estimator = _build_estimator( - # Construct layers sizes with exponetial decay - hidden_units=[ - max(2, int(first_dnn_layer_size * dnn_decay_factor**i)) - for i in range(num_dnn_layers) - ], - config=run_config, - warm_start_from=warm_start_from) - - # Create an input receiver for TFMA processing - receiver_fn = lambda: _eval_input_receiver_fn( # pylint: disable=g-long-lambda - tf_transform_output, schema) - - return { - 'estimator': estimator, - 'train_spec': train_spec, - 'eval_spec': eval_spec, - 'eval_input_receiver_fn': receiver_fn - } - - -# TFX generic trainer will call this function -def run_fn(fn_args: executor.TrainerFnArgs): + outputs = {} + for key in _DENSE_FLOAT_FEATURE_KEYS: + # If sparse make it dense, setting nan's to 0 or '', and apply zscore. + outputs[_transformed_name(key)] = tft.scale_to_z_score( + _fill_in_missing(inputs[key]) + ) + + for key in _VOCAB_FEATURE_KEYS: + # Build a vocabulary for this feature. + outputs[_transformed_name(key)] = tft.compute_and_apply_vocabulary( + _fill_in_missing(inputs[key]), + top_k=_VOCAB_SIZE, + num_oov_buckets=_OOV_SIZE, + ) + + for key in _BUCKET_FEATURE_KEYS: + outputs[_transformed_name(key)] = tft.bucketize( + _fill_in_missing(inputs[key]), _FEATURE_BUCKET_COUNT + ) + + for key in _CATEGORICAL_FEATURE_KEYS: + outputs[_transformed_name(key)] = _fill_in_missing(inputs[key]) + + # Was this passenger a big tipper? + taxi_fare = _fill_in_missing(inputs[_FARE_KEY]) + tips = _fill_in_missing(inputs[_LABEL_KEY]) + outputs[_transformed_name(_LABEL_KEY)] = tf.where( + tf.math.is_nan(taxi_fare), + tf.cast(tf.zeros_like(taxi_fare), tf.int64), + # Test if the tip was > 20% of the fare. 
+ tf.cast( + tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))), tf.int64 + ), + ) + + return outputs + + +# TFX Trainer will call this function. +def run_fn(fn_args: fn_args_utils.FnArgs): """Train the model based on given args. Args: fn_args: Holds args used to train the model as name/value pairs. """ - schema = io_utils.parse_pbtxt_file(fn_args.schema_file, schema_pb2.Schema()) - - training_spec = trainer_fn(fn_args, schema) - - # Train the model - absl.logging.info('Training model.') - tf_estimator.train_and_evaluate(training_spec['estimator'], - training_spec['train_spec'], - training_spec['eval_spec']) - - # Export an eval savedmodel for TFMA - # NOTE: When trained in distributed training cluster, eval_savedmodel must be - # exported only by the chief worker. - absl.logging.info('Exporting eval_savedmodel for TFMA.') - tfma.export.export_eval_savedmodel( - estimator=training_spec['estimator'], - export_dir_base=path_utils.eval_model_dir(fn_args.model_run_dir), - eval_input_receiver_fn=training_spec['eval_input_receiver_fn']) - - # TODO(b/160795287): Deprecate estimator based executor. - # Copy serving and eval model from model_run to model artifact directory. - serving_source = path_utils.serving_model_path(fn_args.model_run_dir) - io_utils.copy_dir(serving_source, fn_args.serving_model_dir) - - eval_source = path_utils.eval_model_path(fn_args.model_run_dir) - io_utils.copy_dir(eval_source, fn_args.eval_model_dir) - - absl.logging.info('Training complete. Model written to %s', - fn_args.serving_model_dir) - absl.logging.info('Exported eval_savedmodel to %s.', fn_args.eval_model_dir) + # Number of nodes in the first layer of the DNN + first_dnn_layer_size = 100 + num_dnn_layers = 4 + dnn_decay_factor = 0.7 + + tf_transform_output = tft.TFTransformOutput(fn_args.transform_graph_path) + + train_dataset = _input_fn( + fn_args.train_files, fn_args.data_accessor, tf_transform_output, 40 + ) + eval_dataset = _input_fn( + fn_args.eval_files, fn_args.data_accessor, tf_transform_output, 40 + ) + + mirrored_strategy = tf.distribute.MirroredStrategy() + with mirrored_strategy.scope(): + model = _build_keras_model( + # Construct layers sizes with exponetial decay + hidden_units=[ + max(2, int(first_dnn_layer_size * dnn_decay_factor**i)) + for i in range(num_dnn_layers) + ] + ) + + # Write logs to path + tensorboard_callback = tf.keras.callbacks.TensorBoard( + log_dir=fn_args.model_run_dir, update_freq='epoch' + ) + + model.fit( + train_dataset, + steps_per_epoch=fn_args.train_steps, + validation_data=eval_dataset, + validation_steps=fn_args.eval_steps, + callbacks=[tensorboard_callback], + ) + + signatures = { + 'serving_default': _get_tf_examples_serving_signature( + model, tf_transform_output + ), + 'transform_features': _get_transform_features_signature( + model, tf_transform_output + ), + } + model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures) diff --git a/tfx/components/testdata/module_file/transform_module.py b/tfx/components/testdata/module_file/transform_module.py deleted file mode 100644 index eac211009b..0000000000 --- a/tfx/components/testdata/module_file/transform_module.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright 2019 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Python source file include taxi pipeline functions and necesasry utils. - -For a TFX pipeline to successfully run, a preprocessing_fn and a -_build_estimator function needs to be provided. This file contains both. - -This file is equivalent to examples/chicago_taxi/trainer/model.py and -examples/chicago_taxi/preprocess.py. -""" - -import tensorflow as tf -import tensorflow_transform as tft - - -_CATEGORICAL_FEATURE_KEYS = [ - 'trip_start_hour', 'trip_start_day', 'trip_start_month', - 'pickup_census_tract', 'dropoff_census_tract', 'pickup_community_area', - 'dropoff_community_area' -] - -_DENSE_FLOAT_FEATURE_KEYS = ['trip_miles', 'fare', 'trip_seconds'] - -# Number of buckets used by tf.transform for encoding each feature. -_FEATURE_BUCKET_COUNT = 10 - -_BUCKET_FEATURE_KEYS = [ - 'pickup_latitude', 'pickup_longitude', 'dropoff_latitude', - 'dropoff_longitude' -] - -# Number of vocabulary terms used for encoding VOCAB_FEATURES by tf.transform -_VOCAB_SIZE = 1000 - -# Count of out-of-vocab buckets in which unrecognized VOCAB_FEATURES are hashed. -_OOV_SIZE = 10 - -_VOCAB_FEATURE_KEYS = [ - 'payment_type', - 'company', -] - -# Keys -_LABEL_KEY = 'tips' -_FARE_KEY = 'fare' - - -def _transformed_name(key): - return key + '_xf' - - -def _fill_in_missing(x): - """Replace missing values in a SparseTensor. - - Fills in missing values of `x` with '' or 0, and converts to a dense tensor. - - Args: - x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1 - in the second dimension. - - Returns: - A rank 1 tensor where missing values of `x` have been filled in. - """ - if not isinstance(x, tf.sparse.SparseTensor): - return x - - default_value = '' if x.dtype == tf.string else 0 - return tf.squeeze( - tf.sparse.to_dense( - tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]), - default_value), - axis=1) - - -@tf.function -def _identity(x): - """Make sure everything still works when there is a tf.function used.""" - return x - - -def preprocessing_fn(inputs, custom_config): - """tf.transform's callback function for preprocessing inputs. - - Args: - inputs: map from feature keys to raw not-yet-transformed features. - custom_config: additional properties for pre-processing. - - Returns: - Map from string feature key to transformed features. - """ - outputs = {} - for key in _DENSE_FLOAT_FEATURE_KEYS: - # If sparse make it dense, setting nan's to 0 or '', and apply zscore. - outputs[_transformed_name(key)] = tft.scale_to_z_score( - _fill_in_missing(_identity(inputs[key]))) - - for key in _VOCAB_FEATURE_KEYS: - # Build a vocabulary for this feature. - outputs[_transformed_name(key)] = tft.compute_and_apply_vocabulary( - _fill_in_missing(inputs[key]), - top_k=custom_config.get('VOCAB_SIZE', _VOCAB_SIZE), - num_oov_buckets=custom_config.get('OOV_SIZE', _OOV_SIZE)) - - for key in _BUCKET_FEATURE_KEYS: - outputs[_transformed_name(key)] = tft.bucketize( - _fill_in_missing(inputs[key]), _FEATURE_BUCKET_COUNT) - - for key in _CATEGORICAL_FEATURE_KEYS: - outputs[_transformed_name(key)] = _fill_in_missing(inputs[key]) - - # Was this passenger a big tipper? 
- taxi_fare = _fill_in_missing(inputs[_FARE_KEY]) - tips = _fill_in_missing(inputs[_LABEL_KEY]) - outputs[_transformed_name(_LABEL_KEY)] = tf.compat.v1.where( - tf.math.is_nan(taxi_fare), - tf.cast(tf.zeros_like(taxi_fare), tf.int64), - # Test if the tip was > 20% of the fare. - tf.cast( - tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))), tf.int64)) - - return outputs - - -def stats_options_updater_fn(unused_stats_type, stats_options): - """Callback function for setting pre and post-transform stats options. - - Args: - unused_stats_type: a stats_options_util.StatsType object. - stats_options: a tfdv.StatsOptions object. - - Returns: - An updated tfdv.StatsOptions object. - """ - return stats_options diff --git a/tfx/components/testdata/transform/transform_graph/transform_fn/saved_model.pb b/tfx/components/testdata/transform/transform_graph/transform_fn/saved_model.pb index 238753cec0..2d4389ff2c 100644 Binary files a/tfx/components/testdata/transform/transform_graph/transform_fn/saved_model.pb and b/tfx/components/testdata/transform/transform_graph/transform_fn/saved_model.pb differ diff --git a/tfx/components/testdata/transform/transform_graph/transformed_metadata/schema.pbtxt b/tfx/components/testdata/transform/transform_graph/transformed_metadata/schema.pbtxt index a0bf9aefb8..9fcc61ca73 100644 --- a/tfx/components/testdata/transform/transform_graph/transformed_metadata/schema.pbtxt +++ b/tfx/components/testdata/transform/transform_graph/transformed_metadata/schema.pbtxt @@ -10,6 +10,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -19,6 +22,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -28,6 +34,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -42,6 +51,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -56,6 +68,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -65,6 +80,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -79,6 +97,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -88,6 +109,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -97,6 +121,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -111,6 +138,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -125,6 +155,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -134,6 +167,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -143,6 +179,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -152,6 +191,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -161,6 +203,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -170,6 +215,9 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } feature { @@ -179,6 +227,8 @@ feature { min_fraction: 1.0 } shape { + dim { + size: 1 + } } } -# generate_legacy_feature_spec: false diff --git a/tfx/components/testdata/transform/transformed_examples/Split-eval/transformed_examples-00000-of-00001.gz b/tfx/components/testdata/transform/transformed_examples/Split-eval/transformed_examples-00000-of-00001.gz index 376504a59d..49b883e95f 100644 Binary files a/tfx/components/testdata/transform/transformed_examples/Split-eval/transformed_examples-00000-of-00001.gz and 
b/tfx/components/testdata/transform/transformed_examples/Split-eval/transformed_examples-00000-of-00001.gz differ diff --git a/tfx/components/testdata/transform/transformed_examples/Split-train/transformed_examples-00000-of-00001.gz b/tfx/components/testdata/transform/transformed_examples/Split-train/transformed_examples-00000-of-00001.gz index 874b435c17..103266d34f 100644 Binary files a/tfx/components/testdata/transform/transformed_examples/Split-train/transformed_examples-00000-of-00001.gz and b/tfx/components/testdata/transform/transformed_examples/Split-train/transformed_examples-00000-of-00001.gz differ diff --git a/tfx/components/trainer/component.py b/tfx/components/trainer/component.py index 93dd4052cc..4efd9beb64 100644 --- a/tfx/components/trainer/component.py +++ b/tfx/components/trainer/component.py @@ -80,8 +80,6 @@ def __init__( hyperparameters: Optional[types.BaseChannel] = None, module_file: Optional[Union[str, data_types.RuntimeParameter]] = None, run_fn: Optional[Union[str, data_types.RuntimeParameter]] = None, - # TODO(b/147702778): deprecate trainer_fn. - trainer_fn: Optional[Union[str, data_types.RuntimeParameter]] = None, train_args: Optional[Union[trainer_pb2.TrainArgs, data_types.RuntimeParameter]] = None, eval_args: Optional[Union[trainer_pb2.EvalArgs, @@ -122,21 +120,6 @@ def run_fn(trainer.fn_args_utils.FnArgs) and the trained model must be saved to `FnArgs.serving_model_dir` when this function is executed. - For Estimator based Executor, The `module_file` must implement a function - named `trainer_fn` at its top level. The function must have the - following signature. - ``` python - def trainer_fn(trainer.fn_args_utils.FnArgs, - tensorflow_metadata.proto.v0.schema_pb2) -> Dict: - ... - ``` - where the returned Dict has the following key-values. - - - `estimator`: an instance of `tf.estimator.Estimator` - - `train_spec`: an instance of `tf.estimator.TrainSpec` - - `eval_spec`: an instance of `tf.estimator.EvalSpec` - - `eval_input_receiver_fn`: an instance of tfma `EvalInputReceiver`. - Exactly one of `module_file` or `run_fn` must be supplied if Trainer uses GenericExecutor (default). Use of a [RuntimeParameter][tfx.v1.dsl.experimental.RuntimeParameter] for this argument is experimental. @@ -144,11 +127,6 @@ def trainer_fn(trainer.fn_args_utils.FnArgs, trainer. See 'module_file' for details. Exactly one of 'module_file' or 'run_fn' must be supplied if Trainer uses GenericExecutor (default). Use of a [RuntimeParameter][tfx.v1.dsl.experimental.RuntimeParameter] for this argument is experimental. - trainer_fn: A python path to UDF model definition function for estimator - based trainer. See 'module_file' for the required signature of the UDF. - Exactly one of 'module_file' or 'trainer_fn' must be supplied if Trainer - uses Estimator based Executor. Use of a [RuntimeParameter][tfx.v1.dsl.experimental.RuntimeParameter] for this - argument is experimental. train_args: A proto.TrainArgs instance, containing args used for training Currently only splits and num_steps are available. Default behavior (when splits is empty) is train on `train` split. @@ -162,17 +140,15 @@ def trainer_fn(trainer.fn_args_utils.FnArgs, Raises: ValueError: - - When both or neither of `module_file` and user function - (e.g., trainer_fn and run_fn) is supplied. + - When both or neither of `module_file` and `run_fn` is supplied. - When both or neither of `examples` and `transformed_examples` is supplied. - When `transformed_examples` is supplied but `transform_graph` is not supplied. 
""" - if [bool(module_file), bool(run_fn), bool(trainer_fn)].count(True) != 1: + if [bool(module_file), bool(run_fn)].count(True) != 1: raise ValueError( - "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be " - "supplied.") + "Exactly one of 'module_file', or 'run_fn' must be supplied.") if bool(examples) == bool(transformed_examples): raise ValueError( @@ -203,7 +179,6 @@ def trainer_fn(trainer.fn_args_utils.FnArgs, eval_args=eval_args or trainer_pb2.EvalArgs(), module_file=module_file, run_fn=run_fn, - trainer_fn=trainer_fn, custom_config=(custom_config if isinstance(custom_config, data_types.RuntimeParameter) else json_utils.dumps(custom_config)), diff --git a/tfx/components/trainer/component_test.py b/tfx/components/trainer/component_test.py index 0d5dda7438..de9ea0fe9a 100644 --- a/tfx/components/trainer/component_test.py +++ b/tfx/components/trainer/component_test.py @@ -78,19 +78,6 @@ def testConstructWithParameter(self): str(trainer.spec.exec_properties[ standard_component_specs.MODULE_FILE_KEY])) - def testConstructFromTrainerFn(self): - trainer_fn = 'path.to.my_trainer_fn' - trainer = component.Trainer( - trainer_fn=trainer_fn, - examples=self.examples, - transform_graph=self.transform_graph, - train_args=self.train_args, - eval_args=self.eval_args) - self._verify_outputs(trainer) - self.assertEqual( - trainer_fn, - trainer.spec.exec_properties[standard_component_specs.TRAINER_FN_KEY]) - def testConstructFromRunFn(self): run_fn = 'path.to.my_run_fn' trainer = component.Trainer( @@ -147,16 +134,6 @@ def testConstructMissingUserModule(self): eval_args=self.eval_args) def testConstructDuplicateUserModule(self): - with self.assertRaises(ValueError): - _ = component.Trainer( - module_file='/path/to/module/file', - trainer_fn='path.to.my_trainer_fn', - examples=self.examples, - transform_graph=self.transform_graph, - schema=self.schema, - train_args=self.train_args, - eval_args=self.eval_args) - with self.assertRaises(ValueError): _ = component.Trainer( module_file='/path/to/module/file', @@ -169,7 +146,7 @@ def testConstructDuplicateUserModule(self): def testConstructWithHParams(self): trainer = component.Trainer( - trainer_fn='path.to.my_trainer_fn', + module_file='/path/to/module/file', examples=self.examples, transform_graph=self.transform_graph, schema=self.schema, @@ -193,7 +170,7 @@ def testConstructWithRuntimeParam(self): ptype=str, ) trainer = component.Trainer( - trainer_fn='path.to.my_trainer_fn', + module_file='/path/to/module/file', examples=self.examples, train_args=self.train_args, eval_args=eval_args, diff --git a/tfx/components/trainer/executor.py b/tfx/components/trainer/executor.py index 0fe867a052..0d086fc295 100644 --- a/tfx/components/trainer/executor.py +++ b/tfx/components/trainer/executor.py @@ -18,8 +18,6 @@ from typing import Any, Dict, List import absl -from tensorflow import estimator as tf_estimator -import tensorflow_model_analysis as tfma from tfx import types from tfx.components.trainer import constants from tfx.components.trainer import fn_args_utils @@ -33,7 +31,6 @@ from tfx.utils import path_utils from tensorflow.python.lib.io import file_io # pylint: disable=g-direct-tensorflow-import -from tensorflow_metadata.proto.v0 import schema_pb2 TrainerFnArgs = deprecation_utils.deprecated_alias( # pylint: disable=invalid-name @@ -185,118 +182,3 @@ def Do(self, input_dict: Dict[str, List[types.Artifact]], absl.logging.info( 'Training complete. Model written to %s. 
ModelRun written to %s', fn_args.serving_model_dir, fn_args.model_run_dir) - - -class Executor(GenericExecutor): - """Local estimator based trainer executor used by the TFX Trainer component. - - How to create a trainer callback function to be used by this Trainer executor: - An estimator can be executed by TFX by first creating a trainer_fn callback - method that returns an estimator and some additional parameters, similar to - https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_utils.py#L285. - This becomes the basis of the new Executor for Trainer. This Executor will - then train and evaluate this estimator using the - tf.estimator.train_and_evaluate API to train locally. - """ - - def Do(self, input_dict: Dict[str, List[types.Artifact]], - output_dict: Dict[str, List[types.Artifact]], - exec_properties: Dict[str, Any]) -> None: - """Uses a user-supplied tf.estimator to train a TensorFlow model locally. - - The Trainer Executor invokes a training_fn callback function provided by - the user via the module_file parameter. With the tf.estimator returned by - this function, the Trainer Executor then builds a TensorFlow model using the - user-provided tf.estimator. - - Args: - input_dict: Input dict from input key to a list of ML-Metadata Artifacts. - - examples: Examples used for training, must include 'train' and 'eval' - if custom splits is not specified in train_args and eval_args. - - transform_graph: Optional input transform graph. - - schema: Schema of the data. - output_dict: Output dict from output key to a list of Artifacts. - - model: Exported model. - - model_run: Model training related outputs (e.g., Tensorboard logs) - exec_properties: A dict of execution properties. - - train_args: JSON string of trainer_pb2.TrainArgs instance, providing - args for training. - - eval_args: JSON string of trainer_pb2.EvalArgs instance, providing - args for eval. - - module_file: Python module file containing UDF model definition. - Exactly one of `module_file`, `module_path` and `trainer_fn` should - be passed. - - module_path: Python module path containing UDF model definition. - Exactly one of `module_file`, `module_path` and `trainer_fn` should - be passed. - - trainer_fn: Python module path to the trainer function. - Exactly one of `module_file`, `module_path` and `trainer_fn` should - be passed. - - warm_starting: Whether or not we need to do warm starting. - - warm_start_from: Optional. If warm_starting is True, this is the - directory to find previous model to warm start on. - - custom_config: Optional. JSON-serialized dict of additional parameters - to pass to trainer function. - - Returns: - None - - Raises: - ValueError: When not exactly one of `module_file`, `module_path` and - `trainer_fn` are present in `exec_properties`. - """ - self._log_startup(input_dict, output_dict, exec_properties) - - fn_args = self._GetFnArgs(input_dict, output_dict, exec_properties) - trainer_fn = udf_utils.get_fn(exec_properties, 'trainer_fn') - - schema = io_utils.parse_pbtxt_file(fn_args.schema_file, schema_pb2.Schema()) - - # TODO(b/160795287): Deprecate estimator based executor. - # Provide user with a modified fn_args, with model_run given as - # the working directory. Executor will then copy user models to - # model artifact directory. 
- serving_dest = fn_args.serving_model_dir - eval_dest = fn_args.eval_model_dir - - working_dir = fn_args.model_run_dir - fn_args.serving_model_dir = path_utils.serving_model_dir(working_dir) - fn_args.eval_model_dir = path_utils.eval_model_dir(working_dir) - - training_spec = trainer_fn(fn_args, schema) - - # Train the model - absl.logging.info('Training model.') - tf_estimator.train_and_evaluate(training_spec['estimator'], - training_spec['train_spec'], - training_spec['eval_spec']) - - absl.logging.info( - 'Training complete. Model written to %s. ModelRun written to %s', - fn_args.serving_model_dir, fn_args.model_run_dir) - - # Export an eval savedmodel for TFMA. If distributed training, it must only - # be written by the chief worker, as would be done for serving savedmodel. - if _is_chief(): - absl.logging.info('Exporting eval_savedmodel for TFMA.') - tfma.export.export_eval_savedmodel( - estimator=training_spec['estimator'], - export_dir_base=fn_args.eval_model_dir, - eval_input_receiver_fn=training_spec['eval_input_receiver_fn']) - - absl.logging.info('Exported eval_savedmodel to %s.', - fn_args.eval_model_dir) - - # TODO(b/160795287): Deprecate estimator based executor. - # Copy serving and eval model from model_run to model artifact directory. - serving_source = path_utils.serving_model_path(fn_args.model_run_dir) - io_utils.copy_dir(serving_source, serving_dest) - absl.logging.info('Serving model copied to: %s.', serving_dest) - - eval_source = path_utils.eval_model_path(fn_args.model_run_dir) - io_utils.copy_dir(eval_source, eval_dest) - absl.logging.info('Eval model copied to: %s.', eval_dest) - - else: - absl.logging.info( - 'Model export is skipped because this is not the chief worker.') diff --git a/tfx/components/trainer/executor_test.py b/tfx/components/trainer/executor_test.py index b11f278b52..d3d0dd57af 100644 --- a/tfx/components/trainer/executor_test.py +++ b/tfx/components/trainer/executor_test.py @@ -16,10 +16,8 @@ import copy import json import os -from unittest import mock import tensorflow as tf -from tfx.components.testdata.module_file import trainer_module from tfx.components.trainer import executor from tfx.dsl.io import fileio from tfx.proto import trainer_pb2 @@ -27,7 +25,6 @@ from tfx.types import standard_artifacts from tfx.types import standard_component_specs from tfx.utils import io_utils -from tfx.utils import name_utils from tfx.utils import path_utils from tfx.utils import proto_utils @@ -94,22 +91,14 @@ def setUp(self): self._module_file = os.path.join(self._source_data_dir, standard_component_specs.MODULE_FILE_KEY, 'trainer_module.py') - self._trainer_fn = name_utils.get_full_name(trainer_module.trainer_fn) - # Executors for test. - self._trainer_executor = executor.Executor() - self._generic_trainer_executor = executor.GenericExecutor() + # Executor for test. 
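+    # (Only the generic executor remains now that the Estimator-based
+    # executor.Executor has been removed.)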
+ self._executor = executor.GenericExecutor() def _verify_model_exports(self): - self.assertTrue( - fileio.exists(path_utils.eval_model_dir(self._model_exports.uri))) self.assertTrue( fileio.exists(path_utils.serving_model_dir(self._model_exports.uri))) - def _verify_no_eval_model_exports(self): - self.assertFalse( - fileio.exists(path_utils.eval_model_dir(self._model_exports.uri))) - def _verify_model_run_exports(self): self.assertTrue(fileio.exists(os.path.dirname(self._model_run_exports.uri))) @@ -119,49 +108,13 @@ def _do(self, test_executor): output_dict=self._output_dict, exec_properties=self._exec_properties) - def testGenericExecutor(self): - self._exec_properties[ - standard_component_specs.MODULE_FILE_KEY] = self._module_file - self._do(self._generic_trainer_executor) - self._verify_model_exports() - self._verify_model_run_exports() - - @mock.patch('tfx.components.trainer.executor._is_chief') - def testDoChief(self, mock_is_chief): - mock_is_chief.return_value = True + def testDo(self): self._exec_properties[ standard_component_specs.MODULE_FILE_KEY] = self._module_file - self._do(self._trainer_executor) + self._do(self._executor) self._verify_model_exports() self._verify_model_run_exports() - @mock.patch('tfx.components.trainer.executor._is_chief') - def testDoNonChief(self, mock_is_chief): - mock_is_chief.return_value = False - self._exec_properties[ - standard_component_specs.MODULE_FILE_KEY] = self._module_file - self._do(self._trainer_executor) - self._verify_no_eval_model_exports() - self._verify_model_run_exports() - - def testDoWithModuleFile(self): - self._exec_properties[ - standard_component_specs.MODULE_FILE_KEY] = self._module_file - self._do(self._trainer_executor) - self._verify_model_exports() - self._verify_model_run_exports() - - def testDoWithTrainerFn(self): - self._exec_properties[ - standard_component_specs.TRAINER_FN_KEY] = self._trainer_fn - self._do(self._trainer_executor) - self._verify_model_exports() - self._verify_model_run_exports() - - def testDoWithNoTrainerFn(self): - with self.assertRaises(ValueError): - self._do(self._trainer_executor) - def testDoWithHyperParameters(self): hp_artifact = standard_artifacts.HyperParameters() hp_artifact.uri = os.path.join(self._output_data_dir, 'hyperparameters/') @@ -181,7 +134,7 @@ def testDoWithHyperParameters(self): self._exec_properties[ standard_component_specs.MODULE_FILE_KEY] = self._module_file - self._do(self._trainer_executor) + self._do(self._executor) self._verify_model_exports() self._verify_model_run_exports() @@ -190,36 +143,6 @@ def testMultipleArtifacts(self): standard_component_specs.EXAMPLES_KEY] = self._multiple_artifacts self._exec_properties[ standard_component_specs.MODULE_FILE_KEY] = self._module_file - self._do(self._generic_trainer_executor) - self._verify_model_exports() - self._verify_model_run_exports() - - def testDoWithCustomSplits(self): - # Update input dict. 
- io_utils.copy_dir( - os.path.join(self._source_data_dir, - 'transform/transformed_examples/Split-train'), - os.path.join(self._output_data_dir, 'data/Split-training')) - io_utils.copy_dir( - os.path.join(self._source_data_dir, - 'transform/transformed_examples/Split-eval'), - os.path.join(self._output_data_dir, 'data/Split-evaluating')) - examples = standard_artifacts.Examples() - examples.uri = os.path.join(self._output_data_dir, 'data') - examples.split_names = artifact_utils.encode_split_names( - ['training', 'evaluating']) - self._input_dict[standard_component_specs.EXAMPLES_KEY] = [examples] - - # Update exec properties skeleton with custom splits. - self._exec_properties[ - standard_component_specs.TRAIN_ARGS_KEY] = proto_utils.proto_to_json( - trainer_pb2.TrainArgs(splits=['training'], num_steps=1000)) - self._exec_properties[ - standard_component_specs.EVAL_ARGS_KEY] = proto_utils.proto_to_json( - trainer_pb2.EvalArgs(splits=['evaluating'], num_steps=500)) - - self._exec_properties[ - standard_component_specs.MODULE_FILE_KEY] = self._module_file - self._do(self._trainer_executor) + self._do(self._executor) self._verify_model_exports() self._verify_model_run_exports() diff --git a/tfx/components/trainer/rewriting/README.md b/tfx/components/trainer/rewriting/README.md deleted file mode 100644 index 10568ff0e4..0000000000 --- a/tfx/components/trainer/rewriting/README.md +++ /dev/null @@ -1,75 +0,0 @@ -# Model Rewriting Library - -The TFX model rewriting library makes it simple to make post-training -modifications (i.e. rewrites) to models within TFX. These modifications can vary -from small-scale edits (e.g. signature changes) to wholesale model conversions -from one type to another (e.g. from SavedModel to -[TFLite](https://www.tensorflow.org/lite)). - -The library is invoked from user code in the Trainer. We both make it simple to -create custom rewrites and provide a set of commonly-used ones. For example, -the -[TFLiteRewriter](https://github.com/tensorflow/tfx/blob/master/tfx/components/trainer/rewriting/tflite_rewriter.py) -converts SavedModels to TFLite. - -## Using rewriters -To instantiate a rewriter, use the rewriter factory. - -```python -from tfx.components.trainer.rewriting import rewriter_factory - -... - -tfrw = rewriter_factory.create_rewriter( - rewriter_factory.TFLITE_REWRITER, name='my_rewriter') -``` - -Then use the appropriate converter (`RewritingExporter` for Estimators or -`rewrite_saved_model` for Keras) to rewrite your model. - -When using Estimators, we recommend you invoke these converters in the -`trainer_fn` definition in the utils file of your pipeline. For example, in the -chicago taxi pipeline, this would be the taxi_utils.py -[file](https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_utils.py) -and the changes would be as follows: - -```python -import tensorflow as tf -from tfx.components.trainer.rewriting import converters - -... - -base_exporter = tf.estimator.FinalExporter('chicago-taxi', serving_receiver_fn) -rewriting_exporter = converters.RewritingExporter(base_exporter, tfrw) -eval_spec = tf.estimator.EvalSpec( - eval_input_fn, - steps=trainer_fn_args.eval_steps, - exporters=[rewriting_exporter], - name='chicago-taxi-eval') -``` -For Keras, we recommend you invoke these converters in the `run_fn` definition -in the utils file of your pipeline. 
For example, for the MNIST pipeline, this -would be the mnist_utils_native_keras_lite.py -[file](https://github.com/tensorflow/tfx/blob/master/tfx/examples/mnist/mnist_utils_native_keras_lite.py) -and the changes would be as follows: - -```python -import tensorflow as tf -from tfx.components.trainer.rewriting import converters - -... - -model.save('/path/to/model', save_format='tf', signatures=signatures) -converters.rewrite_saved_model('/path/to/model', '/path/to/rewritten/model', - tfrw) -``` -A complete end-to-end pipeline that uses the TFLite rewriter can be found [here](https://github.com/tensorflow/tfx/blob/master/tfx/examples/mnist/mnist_pipeline_native_keras.py). - - -## Creating new rewriters - -To create new rewriters, simply take the following steps: - -* Define a rewriter that inherits from `BaseRewriter` in rewriter.py. - -* Import the rewriter and add a constant to rewriter_factory.py. diff --git a/tfx/components/trainer/rewriting/converters.py b/tfx/components/trainer/rewriting/converters.py deleted file mode 100644 index 5b743c0f5b..0000000000 --- a/tfx/components/trainer/rewriting/converters.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2020 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Converters rewrite models using the provided rewriters.""" - -import os -import time - - -from tensorflow import estimator as tf_estimator -from tfx.components.trainer.rewriting import rewriter -from tfx.dsl.io import fileio - - -def _invoke_rewriter(src: str, dst: str, rewriter_inst: rewriter.BaseRewriter, - src_model_type: rewriter.ModelType, - dst_model_type: rewriter.ModelType): - """Converts the provided model by invoking the specified rewriters. - - Args: - src: Path to the source model. - dst: Path where the destination model is to be written. - rewriter_inst: instance of the rewriter to invoke. - src_model_type: the `rewriter.ModelType` of the source model. - dst_model_type: the `rewriter.ModelType` of the destination model. - - Raises: - ValueError: if the source path is the same as the destination path. - """ - - if src == dst: - raise ValueError('Source path and destination path cannot match.') - - original_model = rewriter.ModelDescription(src_model_type, src) - rewritten_model = rewriter.ModelDescription(dst_model_type, dst) - - rewriter_inst.perform_rewrite(original_model, rewritten_model) - - -class RewritingExporter(tf_estimator.Exporter): - """This class invokes the base exporter and a series of rewriters.""" - - def __init__(self, base_exporter: tf_estimator.Exporter, - rewriter_inst: rewriter.BaseRewriter): - """Initializes the rewriting exporter. - - Args: - base_exporter: The exporter of the original model. - rewriter_inst: The rewriter instance to invoke. Must inherit from - `rewriter.BaseRewriter`. 
- """ - self._base_exporter = base_exporter - self._rewriter_inst = rewriter_inst - - @property - def name(self): - """Name of the exporter.""" - return self._base_exporter.name - - def export(self, estimator, export_path, checkpoint_path, eval_result, - is_the_final_export): - """Exports the given `Estimator` to a specific format. - - Performs the export as defined by the base_exporter and invokes all of the - specified rewriters. - - Args: - estimator: the `Estimator` to export. - export_path: A string containing a directory where to write the export. - checkpoint_path: The checkpoint path to export. - eval_result: The output of `Estimator.evaluate` on this checkpoint. - is_the_final_export: This boolean is True when this is an export in the - end of training. It is False for the intermediate exports during the - training. When passing `Exporter` to `tf.estimator.train_and_evaluate` - `is_the_final_export` is always False if `TrainSpec.max_steps` is - `None`. - - Returns: - The string path to the base exported directory or `None` if export is - skipped. - - Raises: - RuntimeError: Unable to create a temporary rewrite directory. - """ - base_path = self._base_exporter.export(estimator, export_path, - checkpoint_path, eval_result, - is_the_final_export) - if not base_path: - return None - - tmp_rewrite_folder = 'tmp-rewrite-' + str(int(time.time())) - tmp_rewrite_path = os.path.join(export_path, tmp_rewrite_folder) - if fileio.exists(tmp_rewrite_path): - raise RuntimeError('Unable to create a unique temporary rewrite path.') - fileio.makedirs(tmp_rewrite_path) - - _invoke_rewriter(base_path, tmp_rewrite_path, self._rewriter_inst, - rewriter.ModelType.SAVED_MODEL, - rewriter.ModelType.ANY_MODEL) - - fileio.rmtree(base_path) - fileio.rename(tmp_rewrite_path, base_path) - return base_path - - -def rewrite_saved_model( - src: str, - dst: str, - rewriter_inst: rewriter.BaseRewriter, - dst_model_type: rewriter.ModelType = rewriter.ModelType.SAVED_MODEL): - """Rewrites the provided SavedModel. - - Args: - src: location of the saved_model to rewrite. - dst: location of the rewritten saved_model. - rewriter_inst: the rewriter instance to invoke. Must inherit from - `rewriter.BaseRewriter`. - dst_model_type: the `rewriter.ModelType` of the destination model. - """ - _invoke_rewriter(src, dst, rewriter_inst, rewriter.ModelType.SAVED_MODEL, - dst_model_type) diff --git a/tfx/components/trainer/rewriting/converters_test.py b/tfx/components/trainer/rewriting/converters_test.py deleted file mode 100644 index f75a7414b0..0000000000 --- a/tfx/components/trainer/rewriting/converters_test.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright 2020 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Tests for third_party.tfx.components.trainer.rewriting.converters.""" - -import os -import tempfile - -from absl.testing.absltest import mock - -import tensorflow as tf - -from tensorflow import estimator as tf_estimator -from tfx.components.trainer.rewriting import converters -from tfx.components.trainer.rewriting import rewriter -from tfx.dsl.io import fileio - -BASE_EXPORT_SUBDIR = 'export_1' -ORIGINAL_SAVED_MODEL = 'saved_model.pbtxt' -ORIGINAL_VOCAB = 'vocab' -REWRITTEN_SAVED_MODEL = 'rewritten_model.pbtxt' -REWRITTEN_VOCAB = 'rewritten_vocab' - - -def _export_fn(estimator, export_path, checkpoint_path, eval_result, - is_the_final_export): - del estimator, checkpoint_path, eval_result, is_the_final_export - path = os.path.join(export_path, BASE_EXPORT_SUBDIR) - fileio.makedirs(path) - with fileio.open(os.path.join(path, ORIGINAL_SAVED_MODEL), 'w') as f: - f.write(str(ORIGINAL_SAVED_MODEL)) - - assets_path = os.path.join(path, tf.saved_model.ASSETS_DIRECTORY) - fileio.makedirs(assets_path) - with fileio.open(os.path.join(assets_path, ORIGINAL_VOCAB), 'w') as f: - f.write(str(ORIGINAL_VOCAB)) - - return path - - -class RewritingExporterTest(tf.test.TestCase): - - class _TestRewriter(rewriter.BaseRewriter): - - def __init__(self, rewrite_raises_error): - """Initializes the MyRewriter class. - - Args: - rewrite_raises_error: Boolean specifying whether to raise a ValueError. - """ - self._rewrite_raises_error = rewrite_raises_error - self.rewrite_called = False - - @property - def name(self): - return 'test_rewriter' - - def _pre_rewrite_validate(self, original_model): - pass - - def _rewrite(self, original_model, rewritten_model): - self.rewrite_called = True - assert fileio.exists( - os.path.join(original_model.path, ORIGINAL_SAVED_MODEL)) - assert fileio.exists( - os.path.join(original_model.path, tf.saved_model.ASSETS_DIRECTORY, - ORIGINAL_VOCAB)) - with fileio.open( - os.path.join(rewritten_model.path, REWRITTEN_SAVED_MODEL), 'w') as f: - f.write(str(REWRITTEN_SAVED_MODEL)) - assets_path = os.path.join(rewritten_model.path, - tf.saved_model.ASSETS_DIRECTORY) - fileio.makedirs(assets_path) - with fileio.open(os.path.join(assets_path, REWRITTEN_VOCAB), 'w') as f: - f.write(str(REWRITTEN_VOCAB)) - if self._rewrite_raises_error: - raise ValueError('rewrite-error') - - def _post_rewrite_validate(self, rewritten_model): - pass - - def setUp(self): - super().setUp() - self._estimator = 'estimator' - self._export_path = tempfile.mkdtemp() - self._checkpoint_path = 'checkpoint_path' - self._eval_result = 'eval_result' - self._is_the_final_export = True - self._base_exporter = tf_estimator.FinalExporter( - name='base_exporter', serving_input_receiver_fn=lambda: None) - - @mock.patch.object(tf_estimator.FinalExporter, 'export') - def testRewritingExporterSucceeds(self, base_exporter_mock): - - base_exporter_mock.side_effect = _export_fn - - tr = self._TestRewriter(False) - r_e = converters.RewritingExporter(self._base_exporter, tr) - final_path = r_e.export(self._estimator, self._export_path, - self._checkpoint_path, self._eval_result, - self._is_the_final_export) - self.assertEqual(final_path, - os.path.join(self._export_path, BASE_EXPORT_SUBDIR)) - self.assertTrue( - fileio.exists(os.path.join(final_path, REWRITTEN_SAVED_MODEL))) - self.assertTrue( - fileio.exists( - os.path.join(final_path, tf.saved_model.ASSETS_DIRECTORY, - REWRITTEN_VOCAB))) - - base_exporter_mock.assert_called_once_with(self._estimator, - self._export_path, - self._checkpoint_path, - self._eval_result, - 
self._is_the_final_export) - - @mock.patch.object(tf_estimator.FinalExporter, 'export') - def testRewritingHandlesNoBaseExport(self, base_exporter_mock): - - base_exporter_mock.return_value = None - - tr = self._TestRewriter(False) - r_e = converters.RewritingExporter(self._base_exporter, tr) - final_path = r_e.export(self._estimator, self._export_path, - self._checkpoint_path, self._eval_result, - self._is_the_final_export) - self.assertIsNone(final_path) - self.assertFalse(tr.rewrite_called) - - base_exporter_mock.assert_called_once_with(self._estimator, - self._export_path, - self._checkpoint_path, - self._eval_result, - self._is_the_final_export) - - @mock.patch.object(tf_estimator.FinalExporter, 'export') - def testRewritingExporterHandlesError(self, base_exporter_mock): - - base_exporter_mock.side_effect = _export_fn - - tr = self._TestRewriter(True) - r_e = converters.RewritingExporter(self._base_exporter, tr) - with self.assertRaisesRegex(ValueError, '.*rewrite-error'): - r_e.export(self._estimator, self._export_path, self._checkpoint_path, - self._eval_result, self._is_the_final_export) - base_exporter_mock.assert_called_once_with(self._estimator, - self._export_path, - self._checkpoint_path, - self._eval_result, - self._is_the_final_export) - self.assertTrue(tr.rewrite_called) - - -class RewriteSavedModelTest(tf.test.TestCase): - - @mock.patch.object(converters, '_invoke_rewriter') - def testRewritingExporterSucceeds(self, invoke_rewriter_mock): - src = '/my/src' - dst = '/my/dst' - rewriter_inst = 'r1' - converters.rewrite_saved_model(src, dst, rewriter_inst) - invoke_rewriter_mock.assert_called_once_with(src, dst, rewriter_inst, - rewriter.ModelType.SAVED_MODEL, - rewriter.ModelType.SAVED_MODEL) diff --git a/tfx/components/transform/__init__.py b/tfx/components/transform/__init__.py index ca966a36bf..04bdba31bd 100644 --- a/tfx/components/transform/__init__.py +++ b/tfx/components/transform/__init__.py @@ -11,3 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from tfx.components.transform import executor +from tfx.components.transform import executor_utils +from tfx.components.transform import labels +from tfx.components.transform import stats_options_util + +__all__ = [ + "executor", + "executor_utils", + "labels", + "stats_options_util", +] diff --git a/tfx/components/transform/component.py b/tfx/components/transform/component.py index 1430917e1e..ab0c2cc04a 100644 --- a/tfx/components/transform/component.py +++ b/tfx/components/transform/component.py @@ -44,9 +44,8 @@ class Transform(base_beam_component.BaseBeamComponent): can define the optional `stats_options_updater_fn` within the module file. ## Providing a preprocessing function - The TFX executor will use the estimator provided in the `module_file` file - to train the model. The Transform executor will look specifically for the - `preprocessing_fn()` function within that file. + The Transform executor will look specifically for the `preprocessing_fn()` + function within that file. 
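A minimal sketch of the `preprocessing_fn()` contract described above; the feature name `'x'` and the single transform used are illustrative, not taken from this change:

```python
# Minimal sketch of the module-file function the Transform executor looks for.
# The feature name 'x' is illustrative; real pipelines derive names from their
# schema, as the taxi examples elsewhere in this change do.
import tensorflow_transform as tft


def preprocessing_fn(inputs):
  """tf.transform callback: maps raw feature tensors to transformed ones."""
  outputs = {}
  # Scale one dense float feature to a z-score; other features would be added
  # the same way (vocabularies, bucketizing, etc.).
  outputs['x_xf'] = tft.scale_to_z_score(inputs['x'])
  return outputs
```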
An example of `preprocessing_fn()` can be found in the [user-supplied code](https://github.com/tensorflow/tfx/blob/master/tfx/examples/chicago_taxi_pipeline/taxi_utils.py) diff --git a/tfx/components/transform/executor_test.py b/tfx/components/transform/executor_test.py index 1b71798a4b..dd18941c06 100644 --- a/tfx/components/transform/executor_test.py +++ b/tfx/components/transform/executor_test.py @@ -20,7 +20,6 @@ import tempfile from unittest import mock - from absl.testing import parameterized import apache_beam as beam import tensorflow as tf @@ -28,7 +27,7 @@ import tensorflow_transform as tft from tensorflow_transform.beam import tft_unit from tfx import types -from tfx.components.testdata.module_file import transform_module +from tfx.components.testdata.module_file import trainer_module from tfx.components.transform import executor from tfx.dsl.io import fileio from tfx.proto import example_gen_pb2 @@ -59,11 +58,11 @@ class ExecutorTest(tft_unit.TransformTestCase): _FILE_FORMAT = None _PAYLOAD_FORMAT = example_gen_pb2.FORMAT_TF_EXAMPLE - _PREPROCESSING_FN = transform_module.preprocessing_fn - _STATS_OPTIONS_UPDATER_FN = transform_module.stats_options_updater_fn + _PREPROCESSING_FN = trainer_module.preprocessing_fn + _STATS_OPTIONS_UPDATER_FN = trainer_module.stats_options_updater_fn _SCHEMA_ARTIFACT_DIR = 'schema_gen' - _MODULE_FILE = 'module_file/transform_module.py' + _MODULE_FILE = 'module_file/trainer_module.py' _TEST_COUNTERS = { 'num_instances': 24909, diff --git a/tfx/examples/airflow_workshop/taxi/setup/dags/taxi_pipeline.py b/tfx/examples/airflow_workshop/taxi/setup/dags/taxi_pipeline.py index e790e20745..0c6f81bfe2 100644 --- a/tfx/examples/airflow_workshop/taxi/setup/dags/taxi_pipeline.py +++ b/tfx/examples/airflow_workshop/taxi/setup/dags/taxi_pipeline.py @@ -132,9 +132,7 @@ def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str, # perform quality validation of a candidate model (compared to a baseline). eval_config = tfma.EvalConfig( # Step 6 model_specs=[ # Step 6 - # This assumes a serving model with signature 'serving_default'. If - # using estimator based EvalSavedModel, add signature_name: 'eval' and - # remove the label_key. + # This assumes a serving model with signature 'serving_default'. tfma.ModelSpec( # Step 6 signature_name='serving_default', # Step 6 label_key='tips', # Step 6 diff --git a/tfx/examples/bigquery_ml/taxi_utils_bqml.py b/tfx/examples/bigquery_ml/taxi_utils_bqml.py index 74e8958dcd..4fdc7550e6 100644 --- a/tfx/examples/bigquery_ml/taxi_utils_bqml.py +++ b/tfx/examples/bigquery_ml/taxi_utils_bqml.py @@ -11,32 +11,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Python source file include taxi pipeline functions and necessary utils. +"""Python source file include taxi pipeline functions and necesasry utils. -For a TFX pipeline to successfully run, a preprocessing_fn and a -_build_estimator function needs to be provided. This file contains both. - -This file is equivalent to examples/chicago_taxi/trainer/model.py and -examples/chicago_taxi/preprocess.py. +The utilities in this file are used to build a model with native Keras. +This module file will be used in Transform and generic Trainer. 
""" -from typing import List +from typing import Optional +from absl import logging import tensorflow as tf -from tensorflow import estimator as tf_estimator -import tensorflow_model_analysis as tfma import tensorflow_transform as tft -from tensorflow_transform.tf_metadata import schema_utils -from tfx.components.trainer.fn_args_utils import DataAccessor +from tfx.components.trainer import fn_args_utils from tfx_bsl.tfxio import dataset_options # Categorical features are assumed to each have a maximum value in the dataset. -_MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 12] +_MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 13] _CATEGORICAL_FEATURE_KEYS = [ - 'trip_start_hour', 'trip_start_day', 'trip_start_month', - 'pickup_census_tract', 'dropoff_census_tract', 'pickup_community_area', - 'dropoff_community_area' + 'trip_start_hour', + 'trip_start_day', + 'trip_start_month', + 'pickup_census_tract', + 'dropoff_census_tract', + 'pickup_community_area', + 'dropoff_community_area', ] _DENSE_FLOAT_FEATURE_KEYS = ['trip_miles', 'fare', 'trip_seconds'] @@ -45,8 +44,10 @@ _FEATURE_BUCKET_COUNT = 10 _BUCKET_FEATURE_KEYS = [ - 'pickup_latitude', 'pickup_longitude', 'dropoff_latitude', - 'dropoff_longitude' + 'pickup_latitude', + 'pickup_longitude', + 'dropoff_latitude', + 'dropoff_longitude', ] # Number of vocabulary terms used for encoding VOCAB_FEATURES by tf.transform @@ -73,37 +74,198 @@ def _transformed_names(keys): return [_transformed_name(key) for key in keys] -# Tf.Transform considers these features as "raw" -def _get_raw_feature_spec(schema): - return schema_utils.schema_as_feature_spec(schema).feature_spec - - def _fill_in_missing(x): - """Replace missing values in a SparseTensors. + """Replace missing values in a SparseTensor. - If x is a SparseTensors, fills in missing values of `x` with '' or 0, and - converts to a dense tensor. Otherwise it returns x as is. + Fills in missing values of `x` with '' or 0, and converts to a dense tensor. Args: - x: A `SparseTensor` of rank 2 or a tensor that is not an instance of - `SparseTensor`. If input is a `SparseTensor` its dense shape should have - size at most 1 in the second dimension. + x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1 + in the second dimension. Returns: - A rank 1 tensor where missing values of `x` have been filled in, or x as is - if x is not an instance of `SparseTensor` + A rank 1 tensor where missing values of `x` have been filled in. 
""" - if not isinstance(x, tf.SparseTensor): + if not isinstance(x, tf.sparse.SparseTensor): return x default_value = '' if x.dtype == tf.string else 0 - return tf.squeeze( - tf.sparse.to_dense( - tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]), - default_value), - axis=1) + dense_tensor = tf.sparse.to_dense( + tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]), + default_value, + ) + return dense_tensor + + +def _get_tf_examples_serving_signature(model, tf_transform_output): + """Returns a serving signature that accepts `tensorflow.Example`.""" + model.tft_layer_inference = tf_transform_output.transform_features_layer() + + @tf.function( + input_signature=[ + tf.TensorSpec(shape=[None], dtype=tf.string, name='examples') + ] + ) + def serve_tf_examples_fn(serialized_tf_example): + raw_feature_spec = tf_transform_output.raw_feature_spec() + raw_feature_spec.pop(_LABEL_KEY) + raw_features = tf.io.parse_example(serialized_tf_example, raw_feature_spec) + transformed_features = model.tft_layer_inference(raw_features) + logging.info('serve_transformed_features = %s', transformed_features) + + outputs = model(transformed_features) + return {'outputs': outputs} + + return serve_tf_examples_fn + + +def _get_transform_features_signature(model, tf_transform_output): + """Returns a serving signature that accepts `tensorflow.Example`.""" + model.tft_layer_eval = tf_transform_output.transform_features_layer() + + @tf.function( + input_signature=[ + tf.TensorSpec(shape=[None], dtype=tf.string, name='examples') + ] + ) + def transform_features_fn(serialized_tf_example): + raw_feature_spec = tf_transform_output.raw_feature_spec() + raw_features = tf.io.parse_example(serialized_tf_example, raw_feature_spec) + transformed_features = model.tft_layer_eval(raw_features) + logging.info('eval_transformed_features = %s', transformed_features) + return transformed_features + + return transform_features_fn + + +def _input_fn( + file_pattern: list[str], + data_accessor: fn_args_utils.DataAccessor, + tf_transform_output: tft.TFTransformOutput, + batch_size: int = 200, +) -> tf.data.Dataset: + """Generates features and label for tuning/training. + + Args: + file_pattern: List of paths or patterns of input tfrecord files. + data_accessor: fn_args_utils.DataAccessor for converting input to + RecordBatch. + tf_transform_output: A TFTransformOutput. + batch_size: representing the number of consecutive elements of returned + dataset to combine in a single batch + + Returns: + A dataset that contains (features, indices) tuple where features is a + dictionary of Tensors, and indices is a single Tensor of label indices. + """ + return data_accessor.tf_dataset_factory( + file_pattern, + dataset_options.TensorFlowDatasetOptions( + batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY) + ), + tf_transform_output.transformed_metadata.schema, + ).repeat() + +def _build_keras_model( + hidden_units: Optional[list[int]] = None, +) -> tf.keras.Model: + """Creates a DNN Keras model for classifying taxi data. + Args: + hidden_units: [int], the layer sizes of the DNN (input layer first). + + Returns: + A Wide and Deep keras Model. + """ + # Following values are hard coded for simplicity in this example, + # However prefarably they should be passsed in as hparams. + + # Keras needs the feature definitions at compile time. 
+ deep_input = { + colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype=tf.float32) + for colname in _transformed_names(_DENSE_FLOAT_FEATURE_KEYS) + } + wide_vocab_input = { + colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32') + for colname in _transformed_names(_VOCAB_FEATURE_KEYS) + } + wide_bucket_input = { + colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32') + for colname in _transformed_names(_BUCKET_FEATURE_KEYS) + } + wide_categorical_input = { + colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32') + for colname in _transformed_names(_CATEGORICAL_FEATURE_KEYS) + } + input_layers = { + **deep_input, + **wide_vocab_input, + **wide_bucket_input, + **wide_categorical_input, + } + + # TODO(b/161952382): Replace with Keras premade models and + # Keras preprocessing layers. + deep = tf.keras.layers.concatenate( + [tf.keras.layers.Normalization()(layer) for layer in deep_input.values()] + ) + for numnodes in (hidden_units or [100, 70, 50, 25]): + deep = tf.keras.layers.Dense(numnodes)(deep) + + wide_layers = [] + for key in _transformed_names(_VOCAB_FEATURE_KEYS): + wide_layers.append( + tf.keras.layers.CategoryEncoding(num_tokens=_VOCAB_SIZE + _OOV_SIZE)( + input_layers[key] + ) + ) + for key in _transformed_names(_BUCKET_FEATURE_KEYS): + wide_layers.append( + tf.keras.layers.CategoryEncoding(num_tokens=_FEATURE_BUCKET_COUNT)( + input_layers[key] + ) + ) + for key, num_tokens in zip( + _transformed_names(_CATEGORICAL_FEATURE_KEYS), + _MAX_CATEGORICAL_FEATURE_VALUES, + ): + wide_layers.append( + tf.keras.layers.CategoryEncoding(num_tokens=num_tokens)( + input_layers[key] + ) + ) + wide = tf.keras.layers.concatenate(wide_layers) + + output = tf.keras.layers.Dense(1, activation='sigmoid')( + tf.keras.layers.concatenate([deep, wide]) + ) + output = tf.squeeze(output, -1) + + model = tf.keras.Model(input_layers, output) + model.compile( + loss='binary_crossentropy', + optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), + metrics=[tf.keras.metrics.BinaryAccuracy()], + ) + model.summary(print_fn=logging.info) + return model + + +def stats_options_updater_fn(unused_stats_type, stats_options): + """Callback function for setting pre and post-transform stats options. + + Args: + unused_stats_type: a stats_options_util.StatsType object. + stats_options: a tfdv.StatsOptions object. + + Returns: + An updated tfdv.StatsOptions object. + """ + return stats_options + + +# TFX Transform will call this function. def preprocessing_fn(inputs): """tf.transform's callback function for preprocessing inputs. @@ -117,18 +279,21 @@ def preprocessing_fn(inputs): for key in _DENSE_FLOAT_FEATURE_KEYS: # If sparse make it dense, setting nan's to 0 or '', and apply zscore. outputs[_transformed_name(key)] = tft.scale_to_z_score( - _fill_in_missing(inputs[key])) + _fill_in_missing(inputs[key]) + ) for key in _VOCAB_FEATURE_KEYS: # Build a vocabulary for this feature. outputs[_transformed_name(key)] = tft.compute_and_apply_vocabulary( _fill_in_missing(inputs[key]), top_k=_VOCAB_SIZE, - num_oov_buckets=_OOV_SIZE) + num_oov_buckets=_OOV_SIZE, + ) for key in _BUCKET_FEATURE_KEYS: outputs[_transformed_name(key)] = tft.bucketize( - _fill_in_missing(inputs[key]), _FEATURE_BUCKET_COUNT) + _fill_in_missing(inputs[key]), _FEATURE_BUCKET_COUNT + ) for key in _CATEGORICAL_FEATURE_KEYS: outputs[_transformed_name(key)] = _fill_in_missing(inputs[key]) @@ -136,226 +301,68 @@ def preprocessing_fn(inputs): # Was this passenger a big tipper? 
taxi_fare = _fill_in_missing(inputs[_FARE_KEY]) tips = _fill_in_missing(inputs[_LABEL_KEY]) - outputs[_transformed_name(_LABEL_KEY)] = tf.compat.v1.where( + outputs[_transformed_name(_LABEL_KEY)] = tf.where( tf.math.is_nan(taxi_fare), tf.cast(tf.zeros_like(taxi_fare), tf.int64), # Test if the tip was > 20% of the fare. tf.cast( - tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))), tf.int64)) + tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))), tf.int64 + ), + ) return outputs -def _build_estimator(config, hidden_units=None, warm_start_from=None): - """Build an estimator for predicting the tipping behavior of taxi riders. - - Args: - config: tf.estimator.RunConfig defining the runtime environment for the - estimator (including model_dir). - hidden_units: [int], the layer sizes of the DNN (input layer first) - warm_start_from: Optional directory to warm start from. - - Returns: - A dict of the following: - - estimator: The estimator that will be used for training and eval. - - train_spec: Spec for training. - - eval_spec: Spec for eval. - - eval_input_receiver_fn: Input function for eval. - """ - real_valued_columns = [ - tf.feature_column.numeric_column(key, shape=()) - for key in _transformed_names(_DENSE_FLOAT_FEATURE_KEYS) - ] - categorical_columns = [ - tf.feature_column.categorical_column_with_identity( - key, num_buckets=_VOCAB_SIZE + _OOV_SIZE, default_value=0) - for key in _transformed_names(_VOCAB_FEATURE_KEYS) - ] - categorical_columns += [ - tf.feature_column.categorical_column_with_identity( - key, num_buckets=_FEATURE_BUCKET_COUNT, default_value=0) - for key in _transformed_names(_BUCKET_FEATURE_KEYS) - ] - categorical_columns += [ - tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension - key, - num_buckets=num_buckets, - default_value=0) for key, num_buckets in zip( - _transformed_names(_CATEGORICAL_FEATURE_KEYS), - _MAX_CATEGORICAL_FEATURE_VALUES) - ] - return tf_estimator.DNNLinearCombinedClassifier( - config=config, - linear_feature_columns=categorical_columns, - dnn_feature_columns=real_valued_columns, - dnn_hidden_units=hidden_units or [100, 70, 50, 25], - warm_start_from=warm_start_from) - - -def _flat_input_serving_receiver_fn(tf_transform_output, schema): - """Build the serving function for flat list of Dense tensors as input. - - Args: - tf_transform_output: A TFTransformOutput. - schema: the schema of the input data. - - Returns: - Tensorflow graph which parses examples, applying tf-transform to them. - """ - raw_feature_spec = _get_raw_feature_spec(schema) - raw_feature_spec.pop(_LABEL_KEY) - - raw_input_fn = tf_estimator.export.build_parsing_serving_input_receiver_fn( - raw_feature_spec, default_batch_size=None) - serving_input_receiver = raw_input_fn() - - transformed_features = tf_transform_output.transform_raw_features( - serving_input_receiver.features) - - # We construct a receiver function that receives flat list of Dense tensors as - # features. This is as per BigQuery ML serving requirements. - return tf_estimator.export.ServingInputReceiver( - transformed_features, serving_input_receiver.features) - - -def _eval_input_receiver_fn(tf_transform_output, schema): - """Build everything needed for the tf-model-analysis to run the model. +# TFX Trainer will call this function. +def run_fn(fn_args: fn_args_utils.FnArgs): + """Train the model based on given args. Args: - tf_transform_output: A TFTransformOutput. - schema: the schema of the input data. 
- - Returns: - EvalInputReceiver function, which contains: - - Tensorflow graph which parses raw untransformed features, applies the - tf-transform preprocessing operators. - - Set of raw, untransformed features. - - Label against which predictions will be compared. - """ - # Notice that the inputs are raw features, not transformed features here. - raw_feature_spec = _get_raw_feature_spec(schema) - - serialized_tf_example = tf.compat.v1.placeholder( - dtype=tf.string, shape=[None], name='input_example_tensor') - - # Add a parse_example operator to the tensorflow graph, which will parse - # raw, untransformed, tf examples. - features = tf.io.parse_example( - serialized=serialized_tf_example, features=raw_feature_spec) - - # Now that we have our raw examples, process them through the tf-transform - # function computed during the preprocessing step. - transformed_features = tf_transform_output.transform_raw_features(features) - - # The key name MUST be 'examples'. - receiver_tensors = {'examples': serialized_tf_example} - - # NOTE: Model is driven by transformed features (since training works on the - # materialized output of TFT, but slicing will happen on raw features. - features.update(transformed_features) - - return tfma.export.EvalInputReceiver( - features=features, - receiver_tensors=receiver_tensors, - labels=transformed_features[_transformed_name(_LABEL_KEY)]) - - -def _input_fn(file_pattern: List[str], - data_accessor: DataAccessor, - tf_transform_output: tft.TFTransformOutput, - batch_size: int = 200) -> tf.data.Dataset: - """Generates features and label for tuning/training. - - Args: - file_pattern: List of paths or patterns of input tfrecord files. - data_accessor: DataAccessor for converting input to RecordBatch. - tf_transform_output: A TFTransformOutput. - batch_size: representing the number of consecutive elements of returned - dataset to combine in a single batch - - Returns: - A dataset that contains (features, indices) tuple where features is a - dictionary of Tensors, and indices is a single Tensor of label indices. - """ - return data_accessor.tf_dataset_factory( - file_pattern, - dataset_options.TensorFlowDatasetOptions( - batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY)), - tf_transform_output.transformed_metadata.schema) - - -# TFX will call this function -def trainer_fn(trainer_fn_args, schema): - """Build the estimator using the high level API. - - Args: - trainer_fn_args: Holds args used to train the model as name/value pairs. - schema: Holds the schema of the training examples. - - Returns: - A dict of the following: - - estimator: The estimator that will be used for training and eval. - - train_spec: Spec for training. - - eval_spec: Spec for eval. - - eval_input_receiver_fn: Input function for eval. + fn_args: Holds args used to train the model as name/value pairs. 
""" # Number of nodes in the first layer of the DNN first_dnn_layer_size = 100 num_dnn_layers = 4 dnn_decay_factor = 0.7 - train_batch_size = 40 - eval_batch_size = 40 - - tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output) - - train_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda - trainer_fn_args.train_files, - trainer_fn_args.data_accessor, - tf_transform_output, - batch_size=train_batch_size) - - eval_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda - trainer_fn_args.eval_files, - trainer_fn_args.data_accessor, - tf_transform_output, - batch_size=eval_batch_size) - - train_spec = tf_estimator.TrainSpec( # pylint: disable=g-long-lambda - train_input_fn, - max_steps=trainer_fn_args.train_steps) - - serving_receiver_fn = lambda: _flat_input_serving_receiver_fn( # pylint: disable=g-long-lambda - tf_transform_output, schema) - - exporter = tf_estimator.FinalExporter('chicago-taxi', serving_receiver_fn) - eval_spec = tf_estimator.EvalSpec( - eval_input_fn, - steps=trainer_fn_args.eval_steps, - exporters=[exporter], - name='chicago-taxi-eval') - - run_config = tf_estimator.RunConfig( - save_checkpoints_steps=999, keep_checkpoint_max=1) - - run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir) - - estimator = _build_estimator( - # Construct layers sizes with exponential decay - hidden_units=[ - max(2, int(first_dnn_layer_size * dnn_decay_factor**i)) - for i in range(num_dnn_layers) - ], - config=run_config, - warm_start_from=trainer_fn_args.base_model) - - # Create an input receiver for TFMA processing - receiver_fn = lambda: _eval_input_receiver_fn( # pylint: disable=g-long-lambda - tf_transform_output, schema) - - return { - 'estimator': estimator, - 'train_spec': train_spec, - 'eval_spec': eval_spec, - 'eval_input_receiver_fn': receiver_fn + tf_transform_output = tft.TFTransformOutput(fn_args.transform_graph_path) + + train_dataset = _input_fn( + fn_args.train_files, fn_args.data_accessor, tf_transform_output, 40 + ) + eval_dataset = _input_fn( + fn_args.eval_files, fn_args.data_accessor, tf_transform_output, 40 + ) + + mirrored_strategy = tf.distribute.MirroredStrategy() + with mirrored_strategy.scope(): + model = _build_keras_model( + # Construct layers sizes with exponetial decay + hidden_units=[ + max(2, int(first_dnn_layer_size * dnn_decay_factor**i)) + for i in range(num_dnn_layers) + ] + ) + + # Write logs to path + tensorboard_callback = tf.keras.callbacks.TensorBoard( + log_dir=fn_args.model_run_dir, update_freq='epoch' + ) + + model.fit( + train_dataset, + steps_per_epoch=fn_args.train_steps, + validation_data=eval_dataset, + validation_steps=fn_args.eval_steps, + callbacks=[tensorboard_callback], + ) + + signatures = { + 'serving_default': _get_tf_examples_serving_signature( + model, tf_transform_output + ), + 'transform_features': _get_transform_features_signature( + model, tf_transform_output + ), } + model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures) diff --git a/tfx/examples/bigquery_ml/taxi_utils_bqml_test.py b/tfx/examples/bigquery_ml/taxi_utils_bqml_test.py deleted file mode 100644 index 2b6c7ef70b..0000000000 --- a/tfx/examples/bigquery_ml/taxi_utils_bqml_test.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright 2019 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for taxi_utils_bqml.py.""" - -import os -import types - -import apache_beam as beam -import tensorflow as tf -from tensorflow import estimator as tf_estimator -import tensorflow_model_analysis as tfma -import tensorflow_transform as tft -from tensorflow_transform import beam as tft_beam -from tensorflow_transform.tf_metadata import dataset_metadata -from tensorflow_transform.tf_metadata import schema_utils -from tfx.components.trainer import executor as trainer_executor -from tfx.components.trainer.fn_args_utils import DataAccessor -from tfx.components.util import tfxio_utils -from tfx.dsl.io import fileio -from tfx.examples.bigquery_ml import taxi_utils_bqml -from tfx.types import standard_artifacts -from tfx.utils import io_utils -from tfx.utils import path_utils - -from tfx_bsl.tfxio import tf_example_record -from tensorflow_metadata.proto.v0 import schema_pb2 - - -class TaxiUtilsTest(tf.test.TestCase): - - def setUp(self): - super().setUp() - self._testdata_path = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(__file__))), - 'components/testdata') - - def testUtils(self): - key = 'fare' - xfm_key = taxi_utils_bqml._transformed_name(key) - self.assertEqual(xfm_key, 'fare_xf') - - def testPreprocessingFn(self): - schema_file = os.path.join(self._testdata_path, 'schema_gen/schema.pbtxt') - schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema()) - feature_spec = taxi_utils_bqml._get_raw_feature_spec(schema) - working_dir = self.get_temp_dir() - transform_graph_path = os.path.join(working_dir, 'transform_graph') - transformed_examples_path = os.path.join( - working_dir, 'transformed_examples') - - # Run very simplified version of executor logic. - # TODO(kestert): Replace with tft_unit.assertAnalyzeAndTransformResults. - # Generate legacy `DatasetMetadata` object. Future version of Transform - # will accept the `Schema` proto directly. - legacy_metadata = dataset_metadata.DatasetMetadata( - schema_utils.schema_from_feature_spec(feature_spec)) - tfxio = tf_example_record.TFExampleRecord( - file_pattern=os.path.join(self._testdata_path, - 'csv_example_gen/Split-train/*'), - telemetry_descriptors=['Tests'], - schema=legacy_metadata.schema) - with beam.Pipeline() as p: - with tft_beam.Context(temp_dir=os.path.join(working_dir, 'tmp')): - examples = p | 'ReadTrainData' >> tfxio.BeamSource() - (transformed_examples, transformed_metadata), transform_fn = ( - (examples, tfxio.TensorAdapterConfig()) - | 'AnalyzeAndTransform' >> tft_beam.AnalyzeAndTransformDataset( - taxi_utils_bqml.preprocessing_fn)) - - # WriteTransformFn writes transform_fn and metadata to subdirectories - # tensorflow_transform.SAVED_MODEL_DIR and - # tensorflow_transform.TRANSFORMED_METADATA_DIR respectively. 
- # pylint: disable=expression-not-assigned - (transform_fn - | 'WriteTransformFn' >> tft_beam.WriteTransformFn( - transform_graph_path)) - - encoder = tft.coders.ExampleProtoCoder(transformed_metadata.schema) - (transformed_examples - | 'EncodeTrainData' >> beam.Map(encoder.encode) - | 'WriteTrainData' >> beam.io.WriteToTFRecord( - os.path.join(transformed_examples_path, - 'Split-train/transformed_examples.gz'), - coder=beam.coders.BytesCoder())) - # pylint: enable=expression-not-assigned - - # Verify the output matches golden output. - # NOTE: we don't verify that transformed examples match golden output. - expected_transformed_schema = io_utils.parse_pbtxt_file( - os.path.join( - self._testdata_path, - 'transform/transform_graph/transformed_metadata/schema.pbtxt'), - schema_pb2.Schema()) - transformed_schema = io_utils.parse_pbtxt_file( - os.path.join(transform_graph_path, - 'transformed_metadata/schema.pbtxt'), - schema_pb2.Schema()) - # Clear annotations so we only have to test main schema. - for feature in transformed_schema.feature: - feature.ClearField('annotation') - transformed_schema.ClearField('annotation') - self.assertEqual(transformed_schema, expected_transformed_schema) - - def testTrainerFn(self): - temp_dir = os.path.join( - os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()), - self._testMethodName) - - schema_file = os.path.join(self._testdata_path, 'schema_gen/schema.pbtxt') - trainer_fn_args = trainer_executor.TrainerFnArgs( - train_files=os.path.join( - self._testdata_path, - 'transform/transformed_examples/Split-train/*.gz'), - transform_output=os.path.join(self._testdata_path, - 'transform/transform_graph/'), - serving_model_dir=os.path.join(temp_dir, 'serving_model_dir'), - eval_files=os.path.join( - self._testdata_path, - 'transform/transformed_examples/Split-eval/*.gz'), - schema_file=schema_file, - train_steps=1, - eval_steps=1, - base_model=os.path.join(self._testdata_path, - 'trainer/previous/Format-Serving'), - data_accessor=DataAccessor( - tf_dataset_factory=tfxio_utils.get_tf_dataset_factory_from_artifact( - [standard_artifacts.Examples()], []), - record_batch_factory=None, - data_view_decode_fn=None)) - schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema()) - training_spec = taxi_utils_bqml.trainer_fn(trainer_fn_args, schema) - - estimator = training_spec['estimator'] - train_spec = training_spec['train_spec'] - eval_spec = training_spec['eval_spec'] - eval_input_receiver_fn = training_spec['eval_input_receiver_fn'] - - self.assertIsInstance(estimator, tf_estimator.Estimator) - self.assertIsInstance(train_spec, tf_estimator.TrainSpec) - self.assertIsInstance(eval_spec, tf_estimator.EvalSpec) - self.assertIsInstance(eval_input_receiver_fn, types.FunctionType) - - # Train for one step, then eval for one step. - eval_result, exports = tf_estimator.train_and_evaluate( - estimator, train_spec, eval_spec) - print(eval_result, exports) - self.assertGreater(eval_result['loss'], 0.0) - self.assertEqual(len(exports), 1) - self.assertGreaterEqual(len(fileio.listdir(exports[0])), 1) - - # Export the eval saved model. - eval_savedmodel_path = tfma.export.export_eval_savedmodel( - estimator=estimator, - export_dir_base=path_utils.eval_model_dir(temp_dir), - eval_input_receiver_fn=eval_input_receiver_fn) - self.assertGreaterEqual(len(fileio.listdir(eval_savedmodel_path)), 1) - - # Test exported serving graph. 
- with tf.compat.v1.Session() as sess: - metagraph_def = tf.compat.v1.saved_model.loader.load( - sess, [tf.saved_model.SERVING], exports[0]) - self.assertIsInstance(metagraph_def, tf.compat.v1.MetaGraphDef) diff --git a/tfx/examples/chicago_taxi_pipeline/README.md b/tfx/examples/chicago_taxi_pipeline/README.md index f930fc954d..8173c60ce9 100644 --- a/tfx/examples/chicago_taxi_pipeline/README.md +++ b/tfx/examples/chicago_taxi_pipeline/README.md @@ -16,7 +16,7 @@ performance, and serve it. This example uses the following * [Transform](https://github.com/tensorflow/tfx/blob/master/docs/guide/transform.md) performs feature engineering on the dataset. * [Trainer](https://github.com/tensorflow/tfx/blob/master/docs/guide/trainer.md) - trains the model using TensorFlow [Estimators](https://www.tensorflow.org/guide/estimators) + trains the model using native Keras. or [Keras](https://www.tensorflow.org/guide/keras). * [Evaluator](https://github.com/tensorflow/tfx/blob/master/docs/guide/evaluator.md) performs deep analysis of the training results. diff --git a/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_local.py b/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_local.py deleted file mode 100644 index 8f8628bd51..0000000000 --- a/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_local.py +++ /dev/null @@ -1,196 +0,0 @@ -# Copyright 2019 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Chicago taxi example using TFX.""" - -import os -from typing import List - -import absl -import tensorflow_model_analysis as tfma -from tfx.components import CsvExampleGen -from tfx.components import Evaluator -from tfx.components import ExampleValidator -from tfx.components import Pusher -from tfx.components import SchemaGen -from tfx.components import StatisticsGen -from tfx.components import Trainer -from tfx.components import Transform -from tfx.components.trainer.executor import Executor -from tfx.dsl.components.base import executor_spec -from tfx.dsl.components.common import resolver -from tfx.dsl.experimental import latest_artifacts_resolver -from tfx.dsl.experimental import latest_blessed_model_resolver -from tfx.orchestration import metadata -from tfx.orchestration import pipeline -from tfx.orchestration.local.local_dag_runner import LocalDagRunner -from tfx.proto import pusher_pb2 -from tfx.proto import trainer_pb2 -from tfx.types import Channel -from tfx.types.standard_artifacts import Model -from tfx.types.standard_artifacts import ModelBlessing - -_pipeline_name = 'chicago_taxi_beam' - -# This example assumes that the taxi data is stored in ~/taxi/data and the -# taxi utility function is in ~/taxi. Feel free to customize this as needed. -_taxi_root = os.path.join(os.environ['HOME'], 'taxi') -_data_root = os.path.join(_taxi_root, 'data', 'simple') -# Python module file to inject customized logic into the TFX components. The -# Transform and Trainer both require user-defined functions to run successfully. 
-_module_file = os.path.join(_taxi_root, 'taxi_utils.py') -# Path which can be listened to by the model server. Pusher will output the -# trained model here. -_serving_model_dir = os.path.join(_taxi_root, 'serving_model', _pipeline_name) - -# Directory and data locations. This example assumes all of the chicago taxi -# example code and metadata library is relative to $HOME, but you can store -# these files anywhere on your local filesystem. -_tfx_root = os.path.join(os.environ['HOME'], 'tfx') -_pipeline_root = os.path.join(_tfx_root, 'pipelines', _pipeline_name) -# Sqlite ML-metadata db path. -_metadata_path = os.path.join(_tfx_root, 'metadata', _pipeline_name, - 'metadata.db') - -# Pipeline arguments for Beam powered Components. -_beam_pipeline_args = [ - '--direct_running_mode=multi_processing', - # 0 means auto-detect based on on the number of CPUs available - # during execution time. - '--direct_num_workers=0', -] - - -# TODO(b/137289334): rename this as simple after DAG visualization is done. -def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str, - module_file: str, serving_model_dir: str, - metadata_path: str, - beam_pipeline_args: List[str]) -> pipeline.Pipeline: - """Implements the chicago taxi pipeline with TFX.""" - - # Brings data into the pipeline or otherwise joins/converts training data. - example_gen = CsvExampleGen(input_base=data_root) - - # Computes statistics over data for visualization and example validation. - statistics_gen = StatisticsGen(examples=example_gen.outputs['examples']) - - # Generates schema based on statistics files. - schema_gen = SchemaGen( - statistics=statistics_gen.outputs['statistics'], - infer_feature_shape=False) - - # Performs anomaly detection based on statistics and data schema. - example_validator = ExampleValidator( - statistics=statistics_gen.outputs['statistics'], - schema=schema_gen.outputs['schema']) - - # Performs transformations and feature engineering in training and serving. - transform = Transform( - examples=example_gen.outputs['examples'], - schema=schema_gen.outputs['schema'], - module_file=module_file) - - # Get the latest model so that we can warm start from the model. - latest_model_resolver = resolver.Resolver( - strategy_class=latest_artifacts_resolver.LatestArtifactsResolver, - latest_model=Channel(type=Model)).with_id('latest_model_resolver') - - # Uses user-provided Python function that implements a model. - trainer = Trainer( - module_file=module_file, - custom_executor_spec=executor_spec.ExecutorClassSpec(Executor), - transformed_examples=transform.outputs['transformed_examples'], - schema=schema_gen.outputs['schema'], - base_model=latest_model_resolver.outputs['latest_model'], - transform_graph=transform.outputs['transform_graph'], - train_args=trainer_pb2.TrainArgs(num_steps=10000), - eval_args=trainer_pb2.EvalArgs(num_steps=5000)) - - # Get the latest blessed model for model validation. - model_resolver = resolver.Resolver( - strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver, - model=Channel(type=Model), - model_blessing=Channel( - type=ModelBlessing)).with_id('latest_blessed_model_resolver') - - # Uses TFMA to compute a evaluation statistics over features of a model and - # perform quality validation of a candidate model (compared to a baseline). 
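For contrast with the Estimator-oriented `EvalConfig` removed just below (which pointed at `signature_name='eval'`), a Keras SavedModel is evaluated by naming the label key directly, as the other configs in this change do. A sketch with illustrative thresholds:

```python
import tensorflow_model_analysis as tfma

# EvalConfig for a Keras SavedModel: name the label key on the model spec;
# no signature_name='eval' is needed. Slices mirror the taxi examples in this
# change; the threshold value is illustrative.
eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='tips')],
    slicing_specs=[
        tfma.SlicingSpec(),
        tfma.SlicingSpec(feature_keys=['trip_start_hour']),
    ],
    metrics_specs=[
        tfma.MetricsSpec(thresholds={
            'binary_accuracy': tfma.MetricThreshold(
                value_threshold=tfma.GenericValueThreshold(
                    lower_bound={'value': 0.6})),
        })
    ],
)
```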
- eval_config = tfma.EvalConfig( - model_specs=[tfma.ModelSpec(signature_name='eval')], - slicing_specs=[ - tfma.SlicingSpec(), - tfma.SlicingSpec(feature_keys=['trip_start_hour']) - ], - metrics_specs=[ - tfma.MetricsSpec( - thresholds={ - 'accuracy': - tfma.MetricThreshold( - value_threshold=tfma.GenericValueThreshold( - lower_bound={'value': 0.6}), - change_threshold=tfma.GenericChangeThreshold( - direction=tfma.MetricDirection.HIGHER_IS_BETTER, - absolute={'value': -1e-10})) - }) - ]) - evaluator = Evaluator( - examples=example_gen.outputs['examples'], - model=trainer.outputs['model'], - baseline_model=model_resolver.outputs['model'], - # Change threshold will be ignored if there is no baseline (first run). - eval_config=eval_config) - - # Checks whether the model passed the validation steps and pushes the model - # to a file destination if check passed. - pusher = Pusher( - model=trainer.outputs['model'], - model_blessing=evaluator.outputs['blessing'], - push_destination=pusher_pb2.PushDestination( - filesystem=pusher_pb2.PushDestination.Filesystem( - base_directory=serving_model_dir))) - - return pipeline.Pipeline( - pipeline_name=pipeline_name, - pipeline_root=pipeline_root, - components=[ - example_gen, - statistics_gen, - schema_gen, - example_validator, - transform, - latest_model_resolver, - trainer, - model_resolver, - evaluator, - pusher, - ], - enable_cache=True, - metadata_connection_config=metadata.sqlite_metadata_connection_config( - metadata_path), - beam_pipeline_args=beam_pipeline_args) - - -# To run this pipeline from the python CLI: -# $python taxi_pipeline_beam.py -if __name__ == '__main__': - absl.logging.set_verbosity(absl.logging.INFO) - - LocalDagRunner().run( - _create_pipeline( - pipeline_name=_pipeline_name, - pipeline_root=_pipeline_root, - data_root=_data_root, - module_file=_module_file, - serving_model_dir=_serving_model_dir, - metadata_path=_metadata_path, - beam_pipeline_args=_beam_pipeline_args)) diff --git a/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_local_e2e_test.py b/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_local_e2e_test.py deleted file mode 100644 index 4e5953fd15..0000000000 --- a/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_local_e2e_test.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2019 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""E2E Tests for tfx.examples.chicago_taxi_pipeline.taxi_pipeline_local.""" - -import os - -from absl.testing import parameterized -import tensorflow as tf -from tfx.dsl.io import fileio -from tfx.examples.chicago_taxi_pipeline import taxi_pipeline_local -from tfx.orchestration import metadata -from tfx.orchestration.local.local_dag_runner import LocalDagRunner - -import pytest - - -@pytest.mark.e2e -class TaxiPipelineLocalEndToEndTest(tf.test.TestCase, parameterized.TestCase): - - def setUp(self): - super().setUp() - self._test_dir = os.path.join( - os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()), - self._testMethodName) - - self._pipeline_name = 'beam_test' - self._data_root = os.path.join(os.path.dirname(__file__), 'data', 'simple') - self._module_file = os.path.join(os.path.dirname(__file__), 'taxi_utils.py') - self._serving_model_dir = os.path.join(self._test_dir, 'serving_model') - self._pipeline_root = os.path.join(self._test_dir, 'tfx', 'pipelines', - self._pipeline_name) - self._metadata_path = os.path.join(self._test_dir, 'tfx', 'metadata', - self._pipeline_name, 'metadata.db') - - def assertExecutedOnce(self, component: str) -> None: - """Check the component is executed exactly once.""" - component_path = os.path.join(self._pipeline_root, component) - self.assertTrue(fileio.exists(component_path)) - outputs = fileio.listdir(component_path) - - self.assertIn('.system', outputs) - outputs.remove('.system') - system_paths = [ - os.path.join('.system', path) - for path in fileio.listdir(os.path.join(component_path, '.system')) - ] - self.assertNotEmpty(system_paths) - self.assertIn('.system/executor_execution', system_paths) - outputs.extend(system_paths) - self.assertNotEmpty(outputs) - for output in outputs: - execution = fileio.listdir(os.path.join(component_path, output)) - if output == '.system/stateful_working_dir': - self.assertEmpty(execution) - else: - self.assertLen(execution, 1) - - def assertPipelineExecution(self) -> None: - self.assertExecutedOnce('CsvExampleGen') - self.assertExecutedOnce('Evaluator') - self.assertExecutedOnce('ExampleValidator') - self.assertExecutedOnce('Pusher') - self.assertExecutedOnce('SchemaGen') - self.assertExecutedOnce('StatisticsGen') - self.assertExecutedOnce('Trainer') - self.assertExecutedOnce('Transform') - - def testTaxiPipelineBeam(self): - LocalDagRunner().run( - taxi_pipeline_local._create_pipeline( - pipeline_name=self._pipeline_name, - data_root=self._data_root, - module_file=self._module_file, - serving_model_dir=self._serving_model_dir, - pipeline_root=self._pipeline_root, - metadata_path=self._metadata_path, - beam_pipeline_args=[])) - - self.assertTrue(fileio.exists(self._serving_model_dir)) - self.assertTrue(fileio.exists(self._metadata_path)) - metadata_config = metadata.sqlite_metadata_connection_config( - self._metadata_path) - with metadata.Metadata(metadata_config) as m: - artifact_count = len(m.store.get_artifacts()) - execution_count = len(m.store.get_executions()) - self.assertGreaterEqual(artifact_count, execution_count) - self.assertEqual(10, execution_count) - - self.assertPipelineExecution() diff --git a/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_simple.py b/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_simple.py index 0e2fc26249..5e5faf18ef 100644 --- a/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_simple.py +++ b/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_simple.py @@ -26,8 +26,6 @@ from tfx.components import StatisticsGen from tfx.components import Trainer from 
tfx.components import Transform -from tfx.components.trainer.executor import Executor -from tfx.dsl.components.base import executor_spec from tfx.dsl.components.common import resolver from tfx.dsl.experimental import latest_blessed_model_resolver from tfx.orchestration import data_types @@ -116,7 +114,6 @@ def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str, # Uses user-provided Python function that implements a model. trainer = Trainer( module_file=module_file, - custom_executor_spec=executor_spec.ExecutorClassSpec(Executor), transformed_examples=transform.outputs['transformed_examples'], schema=schema_gen.outputs['schema'], transform_graph=transform.outputs['transform_graph'], diff --git a/tfx/examples/chicago_taxi_pipeline/taxi_utils.py b/tfx/examples/chicago_taxi_pipeline/taxi_utils.py index 4a6ade3b4b..42ee24ce23 100644 --- a/tfx/examples/chicago_taxi_pipeline/taxi_utils.py +++ b/tfx/examples/chicago_taxi_pipeline/taxi_utils.py @@ -13,27 +13,30 @@ # limitations under the License. """Python source file include taxi pipeline functions and necesasry utils. -For a TFX pipeline to successfully run, a preprocessing_fn and a -trainer_fn function needs to be provided. This file contains both. +The utilities in this file are used to build a model with native Keras. +This module file will be used in Transform and generic Trainer. """ -from typing import List +from typing import Optional +from absl import logging import tensorflow as tf -from tensorflow import estimator as tf_estimator -import tensorflow_model_analysis as tfma import tensorflow_transform as tft from tensorflow_transform.tf_metadata import schema_utils -from tfx.components.trainer.fn_args_utils import DataAccessor +from tfx.components.trainer import fn_args_utils from tfx_bsl.tfxio import dataset_options # Categorical features are assumed to each have a maximum value in the dataset. -_MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 12] +_MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 13] _CATEGORICAL_FEATURE_KEYS = [ - 'trip_start_hour', 'trip_start_day', 'trip_start_month', - 'pickup_census_tract', 'dropoff_census_tract', 'pickup_community_area', - 'dropoff_community_area' + 'trip_start_hour', + 'trip_start_day', + 'trip_start_month', + 'pickup_census_tract', + 'dropoff_census_tract', + 'pickup_community_area', + 'dropoff_community_area', ] _DENSE_FLOAT_FEATURE_KEYS = ['trip_miles', 'fare', 'trip_seconds'] @@ -42,8 +45,10 @@ _FEATURE_BUCKET_COUNT = 10 _BUCKET_FEATURE_KEYS = [ - 'pickup_latitude', 'pickup_longitude', 'dropoff_latitude', - 'dropoff_longitude' + 'pickup_latitude', + 'pickup_longitude', + 'dropoff_latitude', + 'dropoff_longitude', ] # Number of vocabulary terms used for encoding VOCAB_FEATURES by tf.transform @@ -81,23 +86,192 @@ def _fill_in_missing(x): Fills in missing values of `x` with '' or 0, and converts to a dense tensor. Args: - x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1 + x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1 in the second dimension. Returns: - A rank 1 tensor where missing values of `x` have been filled in. + A rank 1 tensor where missing values of `x` have been filled in. 
""" if not isinstance(x, tf.sparse.SparseTensor): return x default_value = '' if x.dtype == tf.string else 0 - return tf.squeeze( - tf.sparse.to_dense( - tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]), - default_value), - axis=1) + dense_tensor = tf.sparse.to_dense( + tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]), + default_value, + ) + return dense_tensor + + +def _get_tf_examples_serving_signature(model, tf_transform_output): + """Returns a serving signature that accepts `tensorflow.Example`.""" + model.tft_layer_inference = tf_transform_output.transform_features_layer() + + @tf.function( + input_signature=[ + tf.TensorSpec(shape=[None], dtype=tf.string, name='examples') + ] + ) + def serve_tf_examples_fn(serialized_tf_example): + raw_feature_spec = tf_transform_output.raw_feature_spec() + raw_feature_spec.pop(_LABEL_KEY) + raw_features = tf.io.parse_example(serialized_tf_example, raw_feature_spec) + transformed_features = model.tft_layer_inference(raw_features) + logging.info('serve_transformed_features = %s', transformed_features) + + outputs = model(transformed_features) + return {'outputs': outputs} + + return serve_tf_examples_fn + + +def _get_transform_features_signature(model, tf_transform_output): + """Returns a serving signature that accepts `tensorflow.Example`.""" + model.tft_layer_eval = tf_transform_output.transform_features_layer() + + @tf.function( + input_signature=[ + tf.TensorSpec(shape=[None], dtype=tf.string, name='examples') + ] + ) + def transform_features_fn(serialized_tf_example): + raw_feature_spec = tf_transform_output.raw_feature_spec() + raw_features = tf.io.parse_example(serialized_tf_example, raw_feature_spec) + transformed_features = model.tft_layer_eval(raw_features) + logging.info('eval_transformed_features = %s', transformed_features) + return transformed_features + + return transform_features_fn + + +def _input_fn( + file_pattern: list[str], + data_accessor: fn_args_utils.DataAccessor, + tf_transform_output: tft.TFTransformOutput, + batch_size: int = 200, +) -> tf.data.Dataset: + """Generates features and label for tuning/training. + + Args: + file_pattern: List of paths or patterns of input tfrecord files. + data_accessor: fn_args_utils.DataAccessor for converting input to + RecordBatch. + tf_transform_output: A TFTransformOutput. + batch_size: representing the number of consecutive elements of returned + dataset to combine in a single batch + + Returns: + A dataset that contains (features, indices) tuple where features is a + dictionary of Tensors, and indices is a single Tensor of label indices. + """ + return data_accessor.tf_dataset_factory( + file_pattern, + dataset_options.TensorFlowDatasetOptions( + batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY) + ), + tf_transform_output.transformed_metadata.schema, + ).repeat() + + +def _build_keras_model( + hidden_units: Optional[list[int]] = None, +) -> tf.keras.Model: + """Creates a DNN Keras model for classifying taxi data. + + Args: + hidden_units: [int], the layer sizes of the DNN (input layer first). + + Returns: + A Wide and Deep keras Model. + """ + # Following values are hard coded for simplicity in this example, + # However prefarably they should be passsed in as hparams. + # Keras needs the feature definitions at compile time. 
+ deep_input = { + colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype=tf.float32) + for colname in _transformed_names(_DENSE_FLOAT_FEATURE_KEYS) + } + wide_vocab_input = { + colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32') + for colname in _transformed_names(_VOCAB_FEATURE_KEYS) + } + wide_bucket_input = { + colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32') + for colname in _transformed_names(_BUCKET_FEATURE_KEYS) + } + wide_categorical_input = { + colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32') + for colname in _transformed_names(_CATEGORICAL_FEATURE_KEYS) + } + input_layers = { + **deep_input, + **wide_vocab_input, + **wide_bucket_input, + **wide_categorical_input, + } + # TODO(b/161952382): Replace with Keras premade models and + # Keras preprocessing layers. + deep = tf.keras.layers.concatenate( + [tf.keras.layers.Normalization()(layer) for layer in deep_input.values()] + ) + for numnodes in (hidden_units or [100, 70, 50, 25]): + deep = tf.keras.layers.Dense(numnodes)(deep) + + wide_layers = [] + for key in _transformed_names(_VOCAB_FEATURE_KEYS): + wide_layers.append( + tf.keras.layers.CategoryEncoding(num_tokens=_VOCAB_SIZE + _OOV_SIZE)( + input_layers[key] + ) + ) + for key in _transformed_names(_BUCKET_FEATURE_KEYS): + wide_layers.append( + tf.keras.layers.CategoryEncoding(num_tokens=_FEATURE_BUCKET_COUNT)( + input_layers[key] + ) + ) + for key, num_tokens in zip( + _transformed_names(_CATEGORICAL_FEATURE_KEYS), + _MAX_CATEGORICAL_FEATURE_VALUES, + ): + wide_layers.append( + tf.keras.layers.CategoryEncoding(num_tokens=num_tokens)( + input_layers[key] + ) + ) + wide = tf.keras.layers.concatenate(wide_layers) + + output = tf.keras.layers.Dense(1, activation='sigmoid')( + tf.keras.layers.concatenate([deep, wide]) + ) + output = tf.squeeze(output, -1) + + model = tf.keras.Model(input_layers, output) + model.compile( + loss='binary_crossentropy', + optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), + metrics=[tf.keras.metrics.BinaryAccuracy()], + ) + model.summary(print_fn=logging.info) + return model + + +def stats_options_updater_fn(unused_stats_type, stats_options): + """Callback function for setting pre and post-transform stats options. + + Args: + unused_stats_type: a stats_options_util.StatsType object. + stats_options: a tfdv.StatsOptions object. + + Returns: + An updated tfdv.StatsOptions object. + """ + return stats_options + + +# TFX Transform will call this function. def preprocessing_fn(inputs): """tf.transform's callback function for preprocessing inputs. @@ -111,18 +285,21 @@ def preprocessing_fn(inputs): for key in _DENSE_FLOAT_FEATURE_KEYS: # If sparse make it dense, setting nan's to 0 or '', and apply zscore. outputs[_transformed_name(key)] = tft.scale_to_z_score( - _fill_in_missing(inputs[key])) + _fill_in_missing(inputs[key]) + ) for key in _VOCAB_FEATURE_KEYS: # Build a vocabulary for this feature. outputs[_transformed_name(key)] = tft.compute_and_apply_vocabulary( _fill_in_missing(inputs[key]), top_k=_VOCAB_SIZE, - num_oov_buckets=_OOV_SIZE) + num_oov_buckets=_OOV_SIZE, + ) for key in _BUCKET_FEATURE_KEYS: outputs[_transformed_name(key)] = tft.bucketize( - _fill_in_missing(inputs[key]), _FEATURE_BUCKET_COUNT) + _fill_in_missing(inputs[key]), _FEATURE_BUCKET_COUNT + ) for key in _CATEGORICAL_FEATURE_KEYS: outputs[_transformed_name(key)] = _fill_in_missing(inputs[key]) @@ -130,229 +307,68 @@ def preprocessing_fn(inputs): # Was this passenger a big tipper? 
taxi_fare = _fill_in_missing(inputs[_FARE_KEY]) tips = _fill_in_missing(inputs[_LABEL_KEY]) - outputs[_transformed_name(_LABEL_KEY)] = tf.compat.v1.where( + outputs[_transformed_name(_LABEL_KEY)] = tf.where( tf.math.is_nan(taxi_fare), tf.cast(tf.zeros_like(taxi_fare), tf.int64), # Test if the tip was > 20% of the fare. tf.cast( - tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))), tf.int64)) + tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))), tf.int64 + ), + ) return outputs -def _build_estimator(config, hidden_units=None, warm_start_from=None): - """Build an estimator for predicting the tipping behavior of taxi riders. - - Args: - config: tf.estimator.RunConfig defining the runtime environment for the - estimator (including model_dir). - hidden_units: [int], the layer sizes of the DNN (input layer first) - warm_start_from: Optional directory to warm start from. - - Returns: - A dict of the following: - - estimator: The estimator that will be used for training and eval. - - train_spec: Spec for training. - - eval_spec: Spec for eval. - - eval_input_receiver_fn: Input function for eval. - """ - real_valued_columns = [ - tf.feature_column.numeric_column(key, shape=()) - for key in _transformed_names(_DENSE_FLOAT_FEATURE_KEYS) - ] - categorical_columns = [ - tf.feature_column.categorical_column_with_identity( - key, num_buckets=_VOCAB_SIZE + _OOV_SIZE, default_value=0) - for key in _transformed_names(_VOCAB_FEATURE_KEYS) - ] - categorical_columns += [ - tf.feature_column.categorical_column_with_identity( - key, num_buckets=_FEATURE_BUCKET_COUNT, default_value=0) - for key in _transformed_names(_BUCKET_FEATURE_KEYS) - ] - categorical_columns += [ - tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension - key, - num_buckets=num_buckets, - default_value=0) for key, num_buckets in zip( - _transformed_names(_CATEGORICAL_FEATURE_KEYS), - _MAX_CATEGORICAL_FEATURE_VALUES) - ] - return tf_estimator.DNNLinearCombinedClassifier( - config=config, - linear_feature_columns=categorical_columns, - dnn_feature_columns=real_valued_columns, - dnn_hidden_units=hidden_units or [100, 70, 50, 25], - warm_start_from=warm_start_from) - - -def _example_serving_receiver_fn(tf_transform_output, schema): - """Build the serving in inputs. +# TFX Trainer will call this function. +def run_fn(fn_args: fn_args_utils.FnArgs): + """Train the model based on given args. Args: - tf_transform_output: A TFTransformOutput. - schema: the schema of the input data. - - Returns: - Tensorflow graph which parses examples, applying tf-transform to them. - """ - raw_feature_spec = _get_raw_feature_spec(schema) - raw_feature_spec.pop(_LABEL_KEY) - - raw_input_fn = tf_estimator.export.build_parsing_serving_input_receiver_fn( - raw_feature_spec, default_batch_size=None) - serving_input_receiver = raw_input_fn() - - transformed_features = tf_transform_output.transform_raw_features( - serving_input_receiver.features) - - return tf_estimator.export.ServingInputReceiver( - transformed_features, serving_input_receiver.receiver_tensors) - - -def _eval_input_receiver_fn(tf_transform_output, schema): - """Build everything needed for the tf-model-analysis to run the model. - - Args: - tf_transform_output: A TFTransformOutput. - schema: the schema of the input data. - - Returns: - EvalInputReceiver function, which contains: - - Tensorflow graph which parses raw untransformed features, applies the - tf-transform preprocessing operators. - - Set of raw, untransformed features. 
- - Label against which predictions will be compared. - """ - # Notice that the inputs are raw features, not transformed features here. - raw_feature_spec = _get_raw_feature_spec(schema) - - serialized_tf_example = tf.compat.v1.placeholder( - dtype=tf.string, shape=[None], name='input_example_tensor') - - # Add a parse_example operator to the tensorflow graph, which will parse - # raw, untransformed, tf examples. - features = tf.io.parse_example( - serialized=serialized_tf_example, features=raw_feature_spec) - - # Now that we have our raw examples, process them through the tf-transform - # function computed during the preprocessing step. - transformed_features = tf_transform_output.transform_raw_features( - features) - - # The key name MUST be 'examples'. - receiver_tensors = {'examples': serialized_tf_example} - - # NOTE: Model is driven by transformed features (since training works on the - # materialized output of TFT, but slicing will happen on raw features. - features.update(transformed_features) - - return tfma.export.EvalInputReceiver( - features=features, - receiver_tensors=receiver_tensors, - labels=transformed_features[_transformed_name(_LABEL_KEY)]) - - -def _input_fn(file_pattern: List[str], - data_accessor: DataAccessor, - tf_transform_output: tft.TFTransformOutput, - batch_size: int = 200) -> tf.data.Dataset: - """Generates features and label for tuning/training. - - Args: - file_pattern: List of paths or patterns of input tfrecord files. - data_accessor: DataAccessor for converting input to RecordBatch. - tf_transform_output: A TFTransformOutput. - batch_size: representing the number of consecutive elements of returned - dataset to combine in a single batch - - Returns: - A dataset that contains (features, indices) tuple where features is a - dictionary of Tensors, and indices is a single Tensor of label indices. - """ - return data_accessor.tf_dataset_factory( - file_pattern, - dataset_options.TensorFlowDatasetOptions( - batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY)), - tf_transform_output.transformed_metadata.schema) - - -# TFX will call this function -def trainer_fn(trainer_fn_args, schema): - """Build the estimator using the high level API. - - Args: - trainer_fn_args: Holds args used to train the model as name/value pairs. - schema: Holds the schema of the training examples. - - Returns: - A dict of the following: - - estimator: The estimator that will be used for training and eval. - - train_spec: Spec for training. - - eval_spec: Spec for eval. - - eval_input_receiver_fn: Input function for eval. + fn_args: Holds args used to train the model as name/value pairs. 
""" # Number of nodes in the first layer of the DNN first_dnn_layer_size = 100 num_dnn_layers = 4 dnn_decay_factor = 0.7 - train_batch_size = 40 - eval_batch_size = 40 - - tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output) - - train_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda - trainer_fn_args.train_files, - trainer_fn_args.data_accessor, - tf_transform_output, - batch_size=train_batch_size) - - eval_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda - trainer_fn_args.eval_files, - trainer_fn_args.data_accessor, - tf_transform_output, - batch_size=eval_batch_size) - - train_spec = tf_estimator.TrainSpec( # pylint: disable=g-long-lambda - train_input_fn, - max_steps=trainer_fn_args.train_steps) - - serving_receiver_fn = lambda: _example_serving_receiver_fn( # pylint: disable=g-long-lambda - tf_transform_output, schema) - - exporter = tf_estimator.FinalExporter('chicago-taxi', serving_receiver_fn) - eval_spec = tf_estimator.EvalSpec( - eval_input_fn, - steps=trainer_fn_args.eval_steps, - exporters=[exporter], - name='chicago-taxi-eval') - - # Keep multiple checkpoint files for distributed training, note that - # keep_max_checkpoint should be greater or equal to the number of replicas to - # avoid race condition. - run_config = tf_estimator.RunConfig( - save_checkpoints_steps=999, keep_checkpoint_max=5) - - run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir) - warm_start_from = trainer_fn_args.base_model - - estimator = _build_estimator( - # Construct layers sizes with exponetial decay - hidden_units=[ - max(2, int(first_dnn_layer_size * dnn_decay_factor**i)) - for i in range(num_dnn_layers) - ], - config=run_config, - warm_start_from=warm_start_from) - - # Create an input receiver for TFMA processing - receiver_fn = lambda: _eval_input_receiver_fn( # pylint: disable=g-long-lambda - tf_transform_output, schema) - - return { - 'estimator': estimator, - 'train_spec': train_spec, - 'eval_spec': eval_spec, - 'eval_input_receiver_fn': receiver_fn + tf_transform_output = tft.TFTransformOutput(fn_args.transform_graph_path) + + train_dataset = _input_fn( + fn_args.train_files, fn_args.data_accessor, tf_transform_output, 40 + ) + eval_dataset = _input_fn( + fn_args.eval_files, fn_args.data_accessor, tf_transform_output, 40 + ) + + mirrored_strategy = tf.distribute.MirroredStrategy() + with mirrored_strategy.scope(): + model = _build_keras_model( + # Construct layers sizes with exponetial decay + hidden_units=[ + max(2, int(first_dnn_layer_size * dnn_decay_factor**i)) + for i in range(num_dnn_layers) + ] + ) + + # Write logs to path + tensorboard_callback = tf.keras.callbacks.TensorBoard( + log_dir=fn_args.model_run_dir, update_freq='epoch' + ) + + model.fit( + train_dataset, + steps_per_epoch=fn_args.train_steps, + validation_data=eval_dataset, + validation_steps=fn_args.eval_steps, + callbacks=[tensorboard_callback], + ) + + signatures = { + 'serving_default': _get_tf_examples_serving_signature( + model, tf_transform_output + ), + 'transform_features': _get_transform_features_signature( + model, tf_transform_output + ), } + model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures) diff --git a/tfx/examples/chicago_taxi_pipeline/taxi_utils_test.py b/tfx/examples/chicago_taxi_pipeline/taxi_utils_test.py index 931328e13c..ac123fc27d 100644 --- a/tfx/examples/chicago_taxi_pipeline/taxi_utils_test.py +++ b/tfx/examples/chicago_taxi_pipeline/taxi_utils_test.py @@ -14,24 +14,15 @@ """Tests for 
tfx.examples.chicago_taxi_pipeline.taxi_utils.""" import os -import types import apache_beam as beam import tensorflow as tf -from tensorflow import estimator as tf_estimator -import tensorflow_model_analysis as tfma import tensorflow_transform as tft from tensorflow_transform import beam as tft_beam from tensorflow_transform.tf_metadata import dataset_metadata from tensorflow_transform.tf_metadata import schema_utils -from tfx.components.trainer import executor as trainer_executor -from tfx.components.trainer.fn_args_utils import DataAccessor -from tfx.components.util import tfxio_utils -from tfx.dsl.io import fileio from tfx.examples.chicago_taxi_pipeline import taxi_utils -from tfx.types import standard_artifacts from tfx.utils import io_utils -from tfx.utils import path_utils from tfx_bsl.tfxio import tf_example_record from tensorflow_metadata.proto.v0 import schema_pb2 @@ -110,66 +101,3 @@ def testPreprocessingFn(self): for feature in transformed_schema.feature: feature.ClearField('annotation') self.assertEqual(transformed_schema, expected_transformed_schema) - - def testTrainerFn(self): - temp_dir = os.path.join( - os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()), - self._testMethodName) - - schema_file = os.path.join(self._testdata_path, 'schema_gen/schema.pbtxt') - data_accessor = DataAccessor( - tf_dataset_factory=tfxio_utils.get_tf_dataset_factory_from_artifact( - [standard_artifacts.Examples()], []), - record_batch_factory=None, - data_view_decode_fn=None) - trainer_fn_args = trainer_executor.TrainerFnArgs( - train_files=os.path.join( - self._testdata_path, - 'transform/transformed_examples/Split-train/*.gz'), - transform_output=os.path.join(self._testdata_path, - 'transform/transform_graph'), - serving_model_dir=os.path.join(temp_dir, 'serving_model_dir'), - eval_files=os.path.join( - self._testdata_path, - 'transform/transformed_examples/Split-eval/*.gz'), - schema_file=schema_file, - train_steps=1, - eval_steps=1, - base_model=None, - data_accessor=data_accessor) - schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema()) - training_spec = taxi_utils.trainer_fn(trainer_fn_args, schema) - - estimator = training_spec['estimator'] - train_spec = training_spec['train_spec'] - eval_spec = training_spec['eval_spec'] - eval_input_receiver_fn = training_spec['eval_input_receiver_fn'] - - self.assertIsInstance(estimator, - tf_estimator.DNNLinearCombinedClassifier) - self.assertIsInstance(train_spec, tf_estimator.TrainSpec) - self.assertIsInstance(eval_spec, tf_estimator.EvalSpec) - self.assertIsInstance(eval_input_receiver_fn, types.FunctionType) - - # Test keep_max_checkpoint in RunConfig - self.assertGreater(estimator._config.keep_checkpoint_max, 1) - - # Train for one step, then eval for one step. - eval_result, exports = tf_estimator.train_and_evaluate( - estimator, train_spec, eval_spec) - self.assertGreater(eval_result['loss'], 0.0) - self.assertEqual(len(exports), 1) - self.assertGreaterEqual(len(fileio.listdir(exports[0])), 1) - - # Export the eval saved model. - eval_savedmodel_path = tfma.export.export_eval_savedmodel( - estimator=estimator, - export_dir_base=path_utils.eval_model_dir(temp_dir), - eval_input_receiver_fn=eval_input_receiver_fn) - self.assertGreaterEqual(len(fileio.listdir(eval_savedmodel_path)), 1) - - # Test exported serving graph. 
- with tf.compat.v1.Session() as sess: - metagraph_def = tf.compat.v1.saved_model.loader.load( - sess, [tf.saved_model.SERVING], exports[0]) - self.assertIsInstance(metagraph_def, tf.compat.v1.MetaGraphDef) diff --git a/tfx/examples/cifar10/README.md b/tfx/examples/cifar10/README.md deleted file mode 100644 index 7a524c7a53..0000000000 --- a/tfx/examples/cifar10/README.md +++ /dev/null @@ -1,66 +0,0 @@ -# CIFAR-10 Transfer Learning and MLKit integration Example - -This example illustrates how to use Transfer Learning for image classification -with TFX, and use trained model to do object detection with -[MLKit](https://developers.google.com/ml-kit) - -## Instruction - -Create a Python 3 virtual environment for this example and activate the -`virtualenv`: - -``` -virtualenv -p python3.7 cifar10 -source ./cifar10/bin/activate -``` - -Then, clone the tfx repo and copy cifar10/ folder to home directory: - -``` -git clone https://github.com/tensorflow/tfx ~/tfx-source && pushd ~/tfx-source -cp -r ~/tfx-source/tfx/examples/cifar10 ~/ -``` - -Next, install the dependencies required by the CIFAR-10 example (appropriate -version of TF2 will be installed automatically). - -``` -pip install -e cifar10/ -# The following is needed until tensorflow-model-analysis 0.23.0 is released -pip uinstall tensorflow-model-analysis -pip install git+https://github.com/tensorflow/model-analysis.git#egg=tensorflow_model_analysis -``` - -### Dataset - -There is a subset of CIFAR10 (128 images) available in the data folder. To -prepare the whole dataset, first create a script and run the following Python -code: `import tensorflow_datasets as tfds ds = tfds.load('cifar10', -data_dir='./cifar10/data/',split=['train', 'test'])` Then, create sub-folders -for different dataset splits and move different splits to corresponding folders. -`cd cifar10/data mkdir train_whole mkdir test_whole mv -cifar10/3.0.2/cifar10-train.tfrecord-00000-of-00001 train_whole mv -cifar10/3.0.2/cifar10-test.tfrecord-00000-of-00001 test_whole` You'll find the -final dataset under `train_whole` and `test_whole` folders. Finally, clean up -the data folder. `rm -r cifar10` - -### Train the model - -Execute the pipeline python file : `python -~/cifar10/cifar_pipeline_native_keras.py` The trained model is located at -`~/cifar10/serving_model_lite/tflite` - -This model is ready to be used for object detection with MLKit. Follow MLKit's -[documentation](https://developers.google.com/ml-kit/vision/object-detection/custom-models/android) -to set up an App and use it. - -## Acknowledge Data Source - -``` -@TECHREPORT{Krizhevsky09learningmultiple, - author = {Alex Krizhevsky}, - title = {Learning multiple layers of features from tiny images}, - institution = {}, - year = {2009} -} -``` diff --git a/tfx/examples/cifar10/__init__.py b/tfx/examples/cifar10/__init__.py deleted file mode 100644 index b179ecb83a..0000000000 --- a/tfx/examples/cifar10/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2020 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tfx/examples/cifar10/cifar10_pipeline_native_keras.py b/tfx/examples/cifar10/cifar10_pipeline_native_keras.py deleted file mode 100644 index da6b4b618f..0000000000 --- a/tfx/examples/cifar10/cifar10_pipeline_native_keras.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright 2019 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""CIFAR10 image classification example using TFX. - -This example demonstrates how to do data augmentation, transfer learning, -and inserting TFLite metadata with TFX. -The trained model can be pluged into MLKit for object detection. -""" - -import os -from typing import List - -import absl -import tensorflow_model_analysis as tfma -from tfx.components import Evaluator -from tfx.components import ExampleValidator -from tfx.components import ImportExampleGen -from tfx.components import Pusher -from tfx.components import SchemaGen -from tfx.components import StatisticsGen -from tfx.components import Trainer -from tfx.components import Transform -from tfx.dsl.components.common import resolver -from tfx.dsl.experimental import latest_blessed_model_resolver -from tfx.orchestration import metadata -from tfx.orchestration import pipeline -from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner -from tfx.proto import example_gen_pb2 -from tfx.proto import pusher_pb2 -from tfx.proto import trainer_pb2 -from tfx.types import Channel -from tfx.types.standard_artifacts import Model -from tfx.types.standard_artifacts import ModelBlessing - -_pipeline_name = 'cifar10_native_keras' - -# This example assumes that CIFAR10 train set data is stored in -# ~/cifar10/data/train, test set data is stored in ~/cifar10/data/test, and -# the utility function is in ~/cifar10. Feel free to customize as needed. -_cifar10_root = os.path.join(os.environ['HOME'], 'cifar10') -_data_root = os.path.join(_cifar10_root, 'data') -# Python module files to inject customized logic into the TFX components. The -# Transform and Trainer both require user-defined functions to run successfully. -_module_file = os.path.join(_cifar10_root, 'cifar10_utils_native_keras.py') -# Path which can be listened to by the model server. Pusher will output the -# trained model here. -_serving_model_dir_lite = os.path.join(_cifar10_root, 'serving_model_lite', - _pipeline_name) - -# Directory and data locations. This example assumes all of the images, -# example code, and metadata library is relative to $HOME, but you can store -# these files anywhere on your local filesystem. -_tfx_root = os.path.join(os.environ['HOME'], 'tfx') -_pipeline_root = os.path.join(_tfx_root, 'pipelines', _pipeline_name) -# Sqlite ML-metadata db path. -_metadata_path = os.path.join(_tfx_root, 'metadata', _pipeline_name, - 'metadata.db') -# Path to labels file for mapping model outputs. -_labels_path = os.path.join(_data_root, 'labels.txt') - - -# Pipeline arguments for Beam powered Components. 
-_beam_pipeline_args = [ - '--direct_running_mode=multi_processing', - # 0 means auto-detect based on on the number of CPUs available - # during execution time. - '--direct_num_workers=0', -] - - -def _create_pipeline(pipeline_name: str, - pipeline_root: str, - data_root: str, - module_file: str, - serving_model_dir_lite: str, - metadata_path: str, - labels_path: str, - beam_pipeline_args: List[str], - accuracy_threshold: float = 0.55) -> pipeline.Pipeline: - """Implements the CIFAR10 image classification pipeline using TFX.""" - # This is needed for datasets with pre-defined splits - # Change the pattern argument to train_whole/* and test_whole/* to train - # on the whole CIFAR-10 dataset - input_config = example_gen_pb2.Input(splits=[ - example_gen_pb2.Input.Split(name='train', pattern='train/*'), - example_gen_pb2.Input.Split(name='eval', pattern='test/*') - ]) - - # Brings data into the pipeline. - example_gen = ImportExampleGen( - input_base=data_root, input_config=input_config) - - # Computes statistics over data for visualization and example validation. - statistics_gen = StatisticsGen(examples=example_gen.outputs['examples']) - - # Generates schema based on statistics files. - schema_gen = SchemaGen( - statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True) - - # Performs anomaly detection based on statistics and data schema. - example_validator = ExampleValidator( - statistics=statistics_gen.outputs['statistics'], - schema=schema_gen.outputs['schema']) - - # Performs transformations and feature engineering in training and serving. - transform = Transform( - examples=example_gen.outputs['examples'], - schema=schema_gen.outputs['schema'], - module_file=module_file) - - # Uses user-provided Python function that trains a model. - # When traning on the whole dataset, use 18744 for train steps, 156 for eval - # steps. 18744 train steps correspond to 24 epochs on the whole train set, and - # 156 eval steps correspond to 1 epoch on the whole test set. The - # configuration below is for training on the dataset we provided in the data - # folder, which has 128 train and 128 test samples. The 160 train steps - # correspond to 40 epochs on this tiny train set, and 4 eval steps correspond - # to 1 epoch on this tiny test set. - trainer = Trainer( - module_file=module_file, - examples=transform.outputs['transformed_examples'], - transform_graph=transform.outputs['transform_graph'], - schema=schema_gen.outputs['schema'], - train_args=trainer_pb2.TrainArgs(num_steps=160), - eval_args=trainer_pb2.EvalArgs(num_steps=4), - custom_config={'labels_path': labels_path}) - - # Get the latest blessed model for model validation. - model_resolver = resolver.Resolver( - strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver, - model=Channel(type=Model), - model_blessing=Channel( - type=ModelBlessing)).with_id('latest_blessed_model_resolver') - - # Uses TFMA to compute evaluation statistics over features of a model and - # perform quality validation of a candidate model (compare to a baseline). - eval_config = tfma.EvalConfig( - model_specs=[tfma.ModelSpec(label_key='label_xf', model_type='tf_lite')], - slicing_specs=[tfma.SlicingSpec()], - metrics_specs=[ - tfma.MetricsSpec(metrics=[ - tfma.MetricConfig( - class_name='SparseCategoricalAccuracy', - threshold=tfma.MetricThreshold( - value_threshold=tfma.GenericValueThreshold( - lower_bound={'value': accuracy_threshold}), - # Change threshold will be ignored if there is no - # baseline model resolved from MLMD (first run). 
- change_threshold=tfma.GenericChangeThreshold( - direction=tfma.MetricDirection.HIGHER_IS_BETTER, - absolute={'value': -1e-3}))) - ]) - ]) - - # Uses TFMA to compute the evaluation statistics over features of a model. - # We evaluate using the materialized examples that are output by Transform - # because - # 1. the decoding_png function currently performed within Transform are not - # compatible with TFLite. - # 2. MLKit requires deserialized (float32) tensor image inputs - # Note that for deployment, the same logic that is performed within Transform - # must be reproduced client-side. - evaluator = Evaluator( - examples=transform.outputs['transformed_examples'], - model=trainer.outputs['model'], - baseline_model=model_resolver.outputs['model'], - eval_config=eval_config) - - # Checks whether the model passed the validation steps and pushes the model - # to a file destination if check passed. - pusher = Pusher( - model=trainer.outputs['model'], - model_blessing=evaluator.outputs['blessing'], - push_destination=pusher_pb2.PushDestination( - filesystem=pusher_pb2.PushDestination.Filesystem( - base_directory=serving_model_dir_lite))) - - components = [ - example_gen, statistics_gen, schema_gen, example_validator, transform, - trainer, model_resolver, evaluator, pusher - ] - - return pipeline.Pipeline( - pipeline_name=pipeline_name, - pipeline_root=pipeline_root, - components=components, - enable_cache=True, - metadata_connection_config=metadata.sqlite_metadata_connection_config( - metadata_path), - beam_pipeline_args=beam_pipeline_args) - - -# To run this pipeline from the python CLI: -# $python cifar_pipeline_native_keras.py -if __name__ == '__main__': - absl.logging.set_verbosity(absl.logging.INFO) - BeamDagRunner().run( - _create_pipeline( - pipeline_name=_pipeline_name, - pipeline_root=_pipeline_root, - data_root=_data_root, - module_file=_module_file, - serving_model_dir_lite=_serving_model_dir_lite, - metadata_path=_metadata_path, - labels_path=_labels_path, - beam_pipeline_args=_beam_pipeline_args)) diff --git a/tfx/examples/cifar10/cifar10_utils_native_keras.py b/tfx/examples/cifar10/cifar10_utils_native_keras.py deleted file mode 100644 index e0ca5478cf..0000000000 --- a/tfx/examples/cifar10/cifar10_utils_native_keras.py +++ /dev/null @@ -1,405 +0,0 @@ -# Copyright 2019 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Python source file includes CIFAR10 utils for Keras model. - -The utilities in this file are used to build a model with native Keras. -This module file will be used in Transform and generic Trainer. 
-""" - -import os -from typing import List -import absl -import flatbuffers -import tensorflow as tf -import tensorflow_transform as tft - -from tfx.components.trainer.fn_args_utils import DataAccessor -from tfx.components.trainer.fn_args_utils import FnArgs -from tfx.components.trainer.rewriting import converters -from tfx.components.trainer.rewriting import rewriter -from tfx.components.trainer.rewriting import rewriter_factory -from tfx.dsl.io import fileio -from tfx_bsl.tfxio import dataset_options - -from tflite_support import metadata_schema_py_generated as _metadata_fb -from tflite_support import metadata as _metadata - -# When training on the whole dataset use following constants instead. -# This setting should give ~91% accuracy on the whole test set -# _TRAIN_DATA_SIZE = 50000 -# _EVAL_DATA_SIZE = 10000 -# _TRAIN_BATCH_SIZE = 64 -# _EVAL_BATCH_SIZE = 64 -# _CLASSIFIER_LEARNING_RATE = 3e-4 -# _FINETUNE_LEARNING_RATE = 5e-5 -# _CLASSIFIER_EPOCHS = 12 - -_TRAIN_DATA_SIZE = 128 -_EVAL_DATA_SIZE = 128 -_TRAIN_BATCH_SIZE = 32 -_EVAL_BATCH_SIZE = 32 -_CLASSIFIER_LEARNING_RATE = 1e-3 -_FINETUNE_LEARNING_RATE = 7e-6 -_CLASSIFIER_EPOCHS = 30 - -_IMAGE_KEY = 'image' -_LABEL_KEY = 'label' - -_TFLITE_MODEL_NAME = 'tflite' - - -def _transformed_name(key): - return key + '_xf' - - -def _get_serve_image_fn(model): - """Returns a function that feeds the input tensor into the model.""" - - @tf.function - def serve_image_fn(image_tensor): - """Returns the output to be used in the serving signature. - - Args: - image_tensor: A tensor represeting input image. The image should have 3 - channels. - - Returns: - The model's predicton on input image tensor - """ - return model(image_tensor) - - return serve_image_fn - - -def _image_augmentation(image_features): - """Perform image augmentation on batches of images . - - Args: - image_features: a batch of image features - - Returns: - The augmented image features - """ - batch_size = tf.shape(image_features)[0] - image_features = tf.image.random_flip_left_right(image_features) - image_features = tf.image.resize_with_crop_or_pad(image_features, 250, 250) - image_features = tf.image.random_crop(image_features, - (batch_size, 224, 224, 3)) - return image_features - - -def _data_augmentation(feature_dict): - """Perform data augmentation on batches of data. - - Args: - feature_dict: a dict containing features of samples - - Returns: - The feature dict with augmented features - """ - image_features = feature_dict[_transformed_name(_IMAGE_KEY)] - image_features = _image_augmentation(image_features) - feature_dict[_transformed_name(_IMAGE_KEY)] = image_features - return feature_dict - - -def _input_fn(file_pattern: List[str], - data_accessor: DataAccessor, - tf_transform_output: tft.TFTransformOutput, - is_train: bool = False, - batch_size: int = 200) -> tf.data.Dataset: - """Generates features and label for tuning/training. - - Args: - file_pattern: List of paths or patterns of input tfrecord files. - data_accessor: DataAccessor for converting input to RecordBatch. - tf_transform_output: A TFTransformOutput. - is_train: Whether the input dataset is train split or not. - batch_size: representing the number of consecutive elements of returned - dataset to combine in a single batch - - Returns: - A dataset that contains (features, indices) tuple where features is a - dictionary of Tensors, and indices is a single Tensor of label indices. 
- """ - dataset = data_accessor.tf_dataset_factory( - file_pattern, - dataset_options.TensorFlowDatasetOptions( - batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY)), - tf_transform_output.transformed_metadata.schema) - # Apply data augmentation. We have to do data augmentation here because - # we need to apply data agumentation on-the-fly during training. If we put - # it in Transform, it will only be applied once on the whole dataset, which - # will lose the point of data augmentation. - if is_train: - dataset = dataset.map(lambda x, y: (_data_augmentation(x), y)) - - return dataset - - -def _freeze_model_by_percentage(model: tf.keras.Model, percentage: float): - """Freeze part of the model based on specified percentage. - - Args: - model: The keras model need to be partially frozen - percentage: the percentage of layers to freeze - - Raises: - ValueError: Invalid values. - """ - if percentage < 0 or percentage > 1: - raise ValueError('Freeze percentage should between 0.0 and 1.0') - - if not model.trainable: - raise ValueError( - 'The model is not trainable, please set model.trainable to True') - - num_layers = len(model.layers) - num_layers_to_freeze = int(num_layers * percentage) - for idx, layer in enumerate(model.layers): - if idx < num_layers_to_freeze: - layer.trainable = False - else: - layer.trainable = True - - -def _build_keras_model() -> tf.keras.Model: - """Creates a Image classification model with MobileNet backbone. - - Returns: - The image classifcation Keras Model and the backbone MobileNet model - """ - # We create a MobileNet model with weights pre-trained on ImageNet. - # We remove the top classification layer of the MobileNet, which was - # used for classifying ImageNet objects. We will add our own classification - # layer for CIFAR10 later. We use average pooling at the last convolution - # layer to get a 1D vector for classifcation, which is consistent with the - # origin MobileNet setup - base_model = tf.keras.applications.MobileNet( - input_shape=(224, 224, 3), - include_top=False, - weights='imagenet', - pooling='avg') - base_model.input_spec = None - - # We add a Dropout layer at the top of MobileNet backbone we just created to - # prevent overfiting, and then a Dense layer to classifying CIFAR10 objects - model = tf.keras.Sequential([ - tf.keras.layers.InputLayer( - input_shape=(224, 224, 3), name=_transformed_name(_IMAGE_KEY)), - base_model, - tf.keras.layers.Dropout(0.1), - tf.keras.layers.Dense(10) - ]) - - # Freeze the whole MobileNet backbone to first train the top classifer only - _freeze_model_by_percentage(base_model, 1.0) - - model.compile( - loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=tf.keras.optimizers.RMSprop(lr=_CLASSIFIER_LEARNING_RATE), - metrics=['sparse_categorical_accuracy']) - model.summary(print_fn=absl.logging.info) - - return model, base_model - - -# TFX Transform will call this function. -def preprocessing_fn(inputs): - """tf.transform's callback function for preprocessing inputs. - - Args: - inputs: map from feature keys to raw not-yet-transformed features. - - Returns: - Map from string feature key to transformed feature operations. - """ - outputs = {} - - # tf.io.decode_png function cannot be applied on a batch of data. 
- # We have to use tf.map_fn - image_features = tf.map_fn( - lambda x: tf.io.decode_png(x[0], channels=3), - inputs[_IMAGE_KEY], - dtype=tf.uint8) - # image_features = tf.cast(image_features, tf.float32) - image_features = tf.image.resize(image_features, [224, 224]) - image_features = tf.keras.applications.mobilenet.preprocess_input( - image_features) - - outputs[_transformed_name(_IMAGE_KEY)] = image_features - # TODO(b/157064428): Support label transformation for Keras. - # Do not apply label transformation as it will result in wrong evaluation. - outputs[_transformed_name(_LABEL_KEY)] = inputs[_LABEL_KEY] - - return outputs - - -def _write_metadata(model_path: str, label_map_path: str, mean: List[float], - std: List[float]): - """Add normalization option and label map TFLite metadata to the model. - - Args: - model_path: The path of the TFLite model - label_map_path: The path of the label map file - mean: The mean value used to normalize input image tensor - std: The standard deviation used to normalize input image tensor - """ - - # Creates flatbuffer for model information. - model_meta = _metadata_fb.ModelMetadataT() - - # Creates flatbuffer for model input metadata. - # Here we add the input normalization info to input metadata. - input_meta = _metadata_fb.TensorMetadataT() - input_normalization = _metadata_fb.ProcessUnitT() - input_normalization.optionsType = ( - _metadata_fb.ProcessUnitOptions.NormalizationOptions) - input_normalization.options = _metadata_fb.NormalizationOptionsT() - input_normalization.options.mean = mean - input_normalization.options.std = std - input_meta.processUnits = [input_normalization] - - # Creates flatbuffer for model output metadata. - # Here we add label file to output metadata. - output_meta = _metadata_fb.TensorMetadataT() - label_file = _metadata_fb.AssociatedFileT() - label_file.name = os.path.basename(label_map_path) - label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS - output_meta.associatedFiles = [label_file] - - # Creates subgraph to contain input and output information, - # and add subgraph to the model information. - subgraph = _metadata_fb.SubGraphMetadataT() - subgraph.inputTensorMetadata = [input_meta] - subgraph.outputTensorMetadata = [output_meta] - model_meta.subgraphMetadata = [subgraph] - - # Serialize the model metadata buffer we created above using flatbuffer - # builder. - b = flatbuffers.Builder(0) - b.Finish( - model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) - metadata_buf = b.Output() - - # Populates metadata and label file to the model file. - populator = _metadata.MetadataPopulator.with_model_file(model_path) - populator.load_metadata_buffer(metadata_buf) - populator.load_associated_files([label_map_path]) - populator.populate() - - -# TFX Trainer will call this function. -def run_fn(fn_args: FnArgs): - """Train the model based on given args. - - Args: - fn_args: Holds args used to train the model as name/value pairs. - - Raises: - ValueError: if invalid inputs. 
- """ - tf_transform_output = tft.TFTransformOutput(fn_args.transform_output) - - train_dataset = _input_fn( - fn_args.train_files, - fn_args.data_accessor, - tf_transform_output, - is_train=True, - batch_size=_TRAIN_BATCH_SIZE) - eval_dataset = _input_fn( - fn_args.eval_files, - fn_args.data_accessor, - tf_transform_output, - is_train=False, - batch_size=_EVAL_BATCH_SIZE) - - model, base_model = _build_keras_model() - - absl.logging.info('Tensorboard logging to {}'.format(fn_args.model_run_dir)) - # Write logs to path - tensorboard_callback = tf.keras.callbacks.TensorBoard( - log_dir=fn_args.model_run_dir, update_freq='epoch') - - # Our training regime has two phases: we first freeze the backbone and train - # the newly added classifier only, then unfreeze part of the backbone and - # fine-tune with classifier jointly. - steps_per_epoch = int(_TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE) - total_epochs = int(fn_args.train_steps / steps_per_epoch) - if _CLASSIFIER_EPOCHS > total_epochs: - raise ValueError('Classifier epochs is greater than the total epochs') - - absl.logging.info('Start training the top classifier') - model.fit( - train_dataset, - epochs=_CLASSIFIER_EPOCHS, - steps_per_epoch=steps_per_epoch, - validation_data=eval_dataset, - validation_steps=fn_args.eval_steps, - callbacks=[tensorboard_callback]) - - absl.logging.info('Start fine-tuning the model') - # Unfreeze the top MobileNet layers and do joint fine-tuning - _freeze_model_by_percentage(base_model, 0.9) - - # We need to recompile the model because layer properties have changed - model.compile( - loss='sparse_categorical_crossentropy', - optimizer=tf.keras.optimizers.RMSprop(lr=_FINETUNE_LEARNING_RATE), - metrics=['sparse_categorical_accuracy']) - model.summary(print_fn=absl.logging.info) - - model.fit( - train_dataset, - initial_epoch=_CLASSIFIER_EPOCHS, - epochs=total_epochs, - steps_per_epoch=steps_per_epoch, - validation_data=eval_dataset, - validation_steps=fn_args.eval_steps, - callbacks=[tensorboard_callback]) - - # Prepare the TFLite model used for serving in MLKit - signatures = { - 'serving_default': - _get_serve_image_fn(model).get_concrete_function( - tf.TensorSpec( - shape=[None, 224, 224, 3], - dtype=tf.float32, - name=_transformed_name(_IMAGE_KEY))) - } - - temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp') - model.save(temp_saving_model_dir, save_format='tf', signatures=signatures) - - tfrw = rewriter_factory.create_rewriter( - rewriter_factory.TFLITE_REWRITER, - name='tflite_rewriter') - converters.rewrite_saved_model(temp_saving_model_dir, - fn_args.serving_model_dir, tfrw, - rewriter.ModelType.TFLITE_MODEL) - - # Add necessary TFLite metadata to the model in order to use it within MLKit - # TODO(dzats@): Handle label map file path more properly, currently - # hard-coded. - tflite_model_path = os.path.join(fn_args.serving_model_dir, - _TFLITE_MODEL_NAME) - # TODO(dzats@): Extend the TFLite rewriter to be able to add TFLite metadata - #@ to the model. 
- _write_metadata( - model_path=tflite_model_path, - label_map_path=fn_args.custom_config['labels_path'], - mean=[127.5], - std=[127.5]) - - fileio.rmtree(temp_saving_model_dir) diff --git a/tfx/examples/cifar10/data/labels.txt b/tfx/examples/cifar10/data/labels.txt deleted file mode 100644 index fa30c22b95..0000000000 --- a/tfx/examples/cifar10/data/labels.txt +++ /dev/null @@ -1,10 +0,0 @@ -airplane -automobile -bird -cat -deer -dog -frog -horse -ship -truck diff --git a/tfx/examples/cifar10/data/test/cifar10_test.tfrecord b/tfx/examples/cifar10/data/test/cifar10_test.tfrecord deleted file mode 100644 index 3fe6a73d85..0000000000 Binary files a/tfx/examples/cifar10/data/test/cifar10_test.tfrecord and /dev/null differ diff --git a/tfx/examples/cifar10/data/train/cifar10_train.tfrecord b/tfx/examples/cifar10/data/train/cifar10_train.tfrecord deleted file mode 100644 index 68399e97fc..0000000000 Binary files a/tfx/examples/cifar10/data/train/cifar10_train.tfrecord and /dev/null differ diff --git a/tfx/examples/custom_components/slack/example/taxi_pipeline_slack_kubeflow.py b/tfx/examples/custom_components/slack/example/taxi_pipeline_slack_kubeflow.py index 8f0175a67f..bcdd21ee11 100644 --- a/tfx/examples/custom_components/slack/example/taxi_pipeline_slack_kubeflow.py +++ b/tfx/examples/custom_components/slack/example/taxi_pipeline_slack_kubeflow.py @@ -53,7 +53,7 @@ # Python module file to inject customized logic into the TFX components. The # Transform and Trainer both require user-defined functions to run successfully. -_taxi_trainer_func = 'example.taxi_utils_slack.trainer_fn' +_taxi_module_file = os.path.join(_taxi_root, 'taxi_utils_slack.py') _taxi_transformer_func = 'example.taxi_utils_slack.preprocessing_fn' # Path which can be listened to by the model server. Pusher will output the # trained model here. @@ -104,7 +104,7 @@ def _create_pipeline(): # Uses user-provided Python function that implements a model. trainer = Trainer( - trainer_fn=_taxi_trainer_func, + module_file=_taxi_module_file, examples=transform.outputs['transformed_examples'], schema=schema_gen.outputs['schema'], transform_graph=transform.outputs['transform_graph'], diff --git a/tfx/examples/custom_components/slack/example/taxi_utils_slack.py b/tfx/examples/custom_components/slack/example/taxi_utils_slack.py index 253b25001c..4fdc7550e6 100644 --- a/tfx/examples/custom_components/slack/example/taxi_utils_slack.py +++ b/tfx/examples/custom_components/slack/example/taxi_utils_slack.py @@ -13,29 +13,29 @@ # limitations under the License. """Python source file include taxi pipeline functions and necesasry utils. -For a TFX pipeline to successfully run, a preprocessing_fn and a -_build_estimator function needs to be provided. This file contains both. - -This file is equivalent to examples/chicago_taxi/trainer/model.py and -examples/chicago_taxi/preprocess.py. +The utilities in this file are used to build a model with native Keras. +This module file will be used in Transform and generic Trainer. """ -from typing import List +from typing import Optional + +from absl import logging import tensorflow as tf -from tensorflow import estimator as tf_estimator -import tensorflow_model_analysis as tfma import tensorflow_transform as tft -from tensorflow_transform.tf_metadata import schema_utils -from tfx.components.trainer.fn_args_utils import DataAccessor +from tfx.components.trainer import fn_args_utils from tfx_bsl.tfxio import dataset_options # Categorical features are assumed to each have a maximum value in the dataset. 
-_MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 12] +_MAX_CATEGORICAL_FEATURE_VALUES = [24, 31, 13] _CATEGORICAL_FEATURE_KEYS = [ - 'trip_start_hour', 'trip_start_day', 'trip_start_month', - 'pickup_census_tract', 'dropoff_census_tract', 'pickup_community_area', - 'dropoff_community_area' + 'trip_start_hour', + 'trip_start_day', + 'trip_start_month', + 'pickup_census_tract', + 'dropoff_census_tract', + 'pickup_community_area', + 'dropoff_community_area', ] _DENSE_FLOAT_FEATURE_KEYS = ['trip_miles', 'fare', 'trip_seconds'] @@ -44,8 +44,10 @@ _FEATURE_BUCKET_COUNT = 10 _BUCKET_FEATURE_KEYS = [ - 'pickup_latitude', 'pickup_longitude', 'dropoff_latitude', - 'dropoff_longitude' + 'pickup_latitude', + 'pickup_longitude', + 'dropoff_latitude', + 'dropoff_longitude', ] # Number of vocabulary terms used for encoding VOCAB_FEATURES by tf.transform @@ -72,33 +74,198 @@ def _transformed_names(keys): return [_transformed_name(key) for key in keys] -# Tf.Transform considers these features as "raw" -def _get_raw_feature_spec(schema): - return schema_utils.schema_as_feature_spec(schema).feature_spec - - def _fill_in_missing(x): """Replace missing values in a SparseTensor. Fills in missing values of `x` with '' or 0, and converts to a dense tensor. Args: - x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1 + x: A `SparseTensor` of rank 2. Its dense shape should have size at most 1 in the second dimension. Returns: - A rank 1 tensor where missing values of `x` have been filled in. + A rank 1 tensor where missing values of `x` have been filled in. """ if not isinstance(x, tf.sparse.SparseTensor): return x default_value = '' if x.dtype == tf.string else 0 - return tf.squeeze( - tf.compat.v1.sparse_to_dense(x.indices, [x.dense_shape[0], 1], x.values, - default_value), - axis=1) + dense_tensor = tf.sparse.to_dense( + tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1]), + default_value, + ) + return dense_tensor + + +def _get_tf_examples_serving_signature(model, tf_transform_output): + """Returns a serving signature that accepts `tensorflow.Example`.""" + model.tft_layer_inference = tf_transform_output.transform_features_layer() + + @tf.function( + input_signature=[ + tf.TensorSpec(shape=[None], dtype=tf.string, name='examples') + ] + ) + def serve_tf_examples_fn(serialized_tf_example): + raw_feature_spec = tf_transform_output.raw_feature_spec() + raw_feature_spec.pop(_LABEL_KEY) + raw_features = tf.io.parse_example(serialized_tf_example, raw_feature_spec) + transformed_features = model.tft_layer_inference(raw_features) + logging.info('serve_transformed_features = %s', transformed_features) + + outputs = model(transformed_features) + return {'outputs': outputs} + + return serve_tf_examples_fn + + +def _get_transform_features_signature(model, tf_transform_output): + """Returns a serving signature that accepts `tensorflow.Example`.""" + model.tft_layer_eval = tf_transform_output.transform_features_layer() + + @tf.function( + input_signature=[ + tf.TensorSpec(shape=[None], dtype=tf.string, name='examples') + ] + ) + def transform_features_fn(serialized_tf_example): + raw_feature_spec = tf_transform_output.raw_feature_spec() + raw_features = tf.io.parse_example(serialized_tf_example, raw_feature_spec) + transformed_features = model.tft_layer_eval(raw_features) + logging.info('eval_transformed_features = %s', transformed_features) + return transformed_features + + return transform_features_fn + + +def _input_fn( + file_pattern: list[str], + data_accessor: fn_args_utils.DataAccessor, + 
tf_transform_output: tft.TFTransformOutput,
+    batch_size: int = 200,
+) -> tf.data.Dataset:
+  """Generates features and label for tuning/training.
+
+  Args:
+    file_pattern: List of paths or patterns of input tfrecord files.
+    data_accessor: fn_args_utils.DataAccessor for converting input to
+      RecordBatch.
+    tf_transform_output: A TFTransformOutput.
+    batch_size: representing the number of consecutive elements of returned
+      dataset to combine in a single batch
+  Returns:
+    A dataset that contains (features, indices) tuple where features is a
+    dictionary of Tensors, and indices is a single Tensor of label indices.
+  """
+  return data_accessor.tf_dataset_factory(
+      file_pattern,
+      dataset_options.TensorFlowDatasetOptions(
+          batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY)
+      ),
+      tf_transform_output.transformed_metadata.schema,
+  ).repeat()
+
+
+def _build_keras_model(
+    hidden_units: Optional[list[int]] = None,
+) -> tf.keras.Model:
+  """Creates a DNN Keras model for classifying taxi data.
+
+  Args:
+    hidden_units: [int], the layer sizes of the DNN (input layer first).
+  Returns:
+    A Wide and Deep keras Model.
+  """
+  # Following values are hard coded for simplicity in this example,
+  # However, preferably they should be passed in as hparams.
+
+  # Keras needs the feature definitions at compile time.
+  deep_input = {
+      colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype=tf.float32)
+      for colname in _transformed_names(_DENSE_FLOAT_FEATURE_KEYS)
+  }
+  wide_vocab_input = {
+      colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
+      for colname in _transformed_names(_VOCAB_FEATURE_KEYS)
+  }
+  wide_bucket_input = {
+      colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
+      for colname in _transformed_names(_BUCKET_FEATURE_KEYS)
+  }
+  wide_categorical_input = {
+      colname: tf.keras.layers.Input(name=colname, shape=(1,), dtype='int32')
+      for colname in _transformed_names(_CATEGORICAL_FEATURE_KEYS)
+  }
+  input_layers = {
+      **deep_input,
+      **wide_vocab_input,
+      **wide_bucket_input,
+      **wide_categorical_input,
+  }
+
+  # TODO(b/161952382): Replace with Keras premade models and
+  # Keras preprocessing layers.
+ deep = tf.keras.layers.concatenate( + [tf.keras.layers.Normalization()(layer) for layer in deep_input.values()] + ) + for numnodes in (hidden_units or [100, 70, 50, 25]): + deep = tf.keras.layers.Dense(numnodes)(deep) + + wide_layers = [] + for key in _transformed_names(_VOCAB_FEATURE_KEYS): + wide_layers.append( + tf.keras.layers.CategoryEncoding(num_tokens=_VOCAB_SIZE + _OOV_SIZE)( + input_layers[key] + ) + ) + for key in _transformed_names(_BUCKET_FEATURE_KEYS): + wide_layers.append( + tf.keras.layers.CategoryEncoding(num_tokens=_FEATURE_BUCKET_COUNT)( + input_layers[key] + ) + ) + for key, num_tokens in zip( + _transformed_names(_CATEGORICAL_FEATURE_KEYS), + _MAX_CATEGORICAL_FEATURE_VALUES, + ): + wide_layers.append( + tf.keras.layers.CategoryEncoding(num_tokens=num_tokens)( + input_layers[key] + ) + ) + wide = tf.keras.layers.concatenate(wide_layers) + + output = tf.keras.layers.Dense(1, activation='sigmoid')( + tf.keras.layers.concatenate([deep, wide]) + ) + output = tf.squeeze(output, -1) + + model = tf.keras.Model(input_layers, output) + model.compile( + loss='binary_crossentropy', + optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), + metrics=[tf.keras.metrics.BinaryAccuracy()], + ) + model.summary(print_fn=logging.info) + return model + + +def stats_options_updater_fn(unused_stats_type, stats_options): + """Callback function for setting pre and post-transform stats options. + + Args: + unused_stats_type: a stats_options_util.StatsType object. + stats_options: a tfdv.StatsOptions object. + + Returns: + An updated tfdv.StatsOptions object. + """ + return stats_options + + +# TFX Transform will call this function. def preprocessing_fn(inputs): """tf.transform's callback function for preprocessing inputs. @@ -112,18 +279,21 @@ def preprocessing_fn(inputs): for key in _DENSE_FLOAT_FEATURE_KEYS: # If sparse make it dense, setting nan's to 0 or '', and apply zscore. outputs[_transformed_name(key)] = tft.scale_to_z_score( - _fill_in_missing(inputs[key])) + _fill_in_missing(inputs[key]) + ) for key in _VOCAB_FEATURE_KEYS: # Build a vocabulary for this feature. outputs[_transformed_name(key)] = tft.compute_and_apply_vocabulary( _fill_in_missing(inputs[key]), top_k=_VOCAB_SIZE, - num_oov_buckets=_OOV_SIZE) + num_oov_buckets=_OOV_SIZE, + ) for key in _BUCKET_FEATURE_KEYS: outputs[_transformed_name(key)] = tft.bucketize( - _fill_in_missing(inputs[key]), _FEATURE_BUCKET_COUNT) + _fill_in_missing(inputs[key]), _FEATURE_BUCKET_COUNT + ) for key in _CATEGORICAL_FEATURE_KEYS: outputs[_transformed_name(key)] = _fill_in_missing(inputs[key]) @@ -131,223 +301,68 @@ def preprocessing_fn(inputs): # Was this passenger a big tipper? taxi_fare = _fill_in_missing(inputs[_FARE_KEY]) tips = _fill_in_missing(inputs[_LABEL_KEY]) - outputs[_transformed_name(_LABEL_KEY)] = tf.compat.v1.where( + outputs[_transformed_name(_LABEL_KEY)] = tf.where( tf.math.is_nan(taxi_fare), tf.cast(tf.zeros_like(taxi_fare), tf.int64), # Test if the tip was > 20% of the fare. tf.cast( - tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))), tf.int64)) + tf.greater(tips, tf.multiply(taxi_fare, tf.constant(0.2))), tf.int64 + ), + ) return outputs -def _build_estimator(config, hidden_units=None, warm_start_from=None): - """Build an estimator for predicting the tipping behavior of taxi riders. - - Args: - config: tf.contrib.learn.RunConfig defining the runtime environment for the - estimator (including model_dir). 
- hidden_units: [int], the layer sizes of the DNN (input layer first) - warm_start_from: Optional directory to warm start from. - - Returns: - A dict of the following: - - estimator: The estimator that will be used for training and eval. - - train_spec: Spec for training. - - eval_spec: Spec for eval. - - eval_input_receiver_fn: Input function for eval. - """ - real_valued_columns = [ - tf.feature_column.numeric_column(key, shape=()) - for key in _transformed_names(_DENSE_FLOAT_FEATURE_KEYS) - ] - categorical_columns = [ - tf.feature_column.categorical_column_with_identity( - key, num_buckets=_VOCAB_SIZE + _OOV_SIZE, default_value=0) - for key in _transformed_names(_VOCAB_FEATURE_KEYS) - ] - categorical_columns += [ - tf.feature_column.categorical_column_with_identity( - key, num_buckets=_FEATURE_BUCKET_COUNT, default_value=0) - for key in _transformed_names(_BUCKET_FEATURE_KEYS) - ] - categorical_columns += [ - tf.feature_column.categorical_column_with_identity( # pylint: disable=g-complex-comprehension - key, - num_buckets=num_buckets, - default_value=0) for key, num_buckets in zip( - _transformed_names(_CATEGORICAL_FEATURE_KEYS), - _MAX_CATEGORICAL_FEATURE_VALUES) - ] - return tf_estimator.DNNLinearCombinedClassifier( - config=config, - linear_feature_columns=categorical_columns, - dnn_feature_columns=real_valued_columns, - dnn_hidden_units=hidden_units or [100, 70, 50, 25], - warm_start_from=warm_start_from) - - -def _example_serving_receiver_fn(transform_output, schema): - """Build the serving in inputs. +# TFX Trainer will call this function. +def run_fn(fn_args: fn_args_utils.FnArgs): + """Train the model based on given args. Args: - transform_output: a `tft.TFTransformOutput` object. - schema: the schema of the input data. - - Returns: - Tensorflow graph which parses examples, applying tf-transform to them. - """ - raw_feature_spec = _get_raw_feature_spec(schema) - raw_feature_spec.pop(_LABEL_KEY) - - raw_input_fn = tf_estimator.export.build_parsing_serving_input_receiver_fn( - raw_feature_spec, default_batch_size=None) - serving_input_receiver = raw_input_fn() - - _, transformed_features = transform_output.transform_raw_features( - serving_input_receiver.features, drop_unused_features=True) - - return tf_estimator.export.ServingInputReceiver( - transformed_features, serving_input_receiver.receiver_tensors) - - -def _eval_input_receiver_fn(transform_output, schema): - """Build everything needed for the tf-model-analysis to run the model. - - Args: - transform_output: a `tft.TFTransformOutput` object. - schema: the schema of the input data. - - Returns: - EvalInputReceiver function, which contains: - - Tensorflow graph which parses raw untransformed features, applies the - tf-transform preprocessing operators. - - Set of raw, untransformed features. - - Label against which predictions will be compared. - """ - # Notice that the inputs are raw features, not transformed features here. - raw_feature_spec = _get_raw_feature_spec(schema) - - serialized_tf_example = tf.compat.v1.placeholder( - dtype=tf.string, shape=[None], name='input_example_tensor') - - # Add a parse_example operator to the tensorflow graph, which will parse - # raw, untransformed, tf examples. - features = tf.io.parse_example( - serialized=serialized_tf_example, features=raw_feature_spec) - - # Now that we have our raw examples, process them through the tf-transform - # function computed during the preprocessing step. 
- _, transformed_features = transform_output.transform_raw_features( - features, drop_unused_features=True) - - # The key name MUST be 'examples'. - receiver_tensors = {'examples': serialized_tf_example} - - # NOTE: Model is driven by transformed features (since training works on the - # materialized output of TFT, but slicing will happen on raw features. - features.update(transformed_features) - - return tfma.export.EvalInputReceiver( - features=features, - receiver_tensors=receiver_tensors, - labels=transformed_features[_transformed_name(_LABEL_KEY)]) - - -def _input_fn(file_pattern: List[str], - data_accessor: DataAccessor, - tf_transform_output: tft.TFTransformOutput, - batch_size: int = 200) -> tf.data.Dataset: - """Generates features and label for tuning/training. - - Args: - file_pattern: List of paths or patterns of input tfrecord files. - data_accessor: DataAccessor for converting input to RecordBatch. - tf_transform_output: A TFTransformOutput. - batch_size: representing the number of consecutive elements of returned - dataset to combine in a single batch - - Returns: - A dataset that contains (features, indices) tuple where features is a - dictionary of Tensors, and indices is a single Tensor of label indices. - """ - return data_accessor.tf_dataset_factory( - file_pattern, - dataset_options.TensorFlowDatasetOptions( - batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY)), - tf_transform_output.transformed_metadata.schema) - - -# TFX will call this function -def trainer_fn(trainer_fn_args, schema): - """Build the estimator using the high level API. - - Args: - trainer_fn_args: Holds args used to train the model as name/value pairs. - schema: Holds the schema of the training examples. - - Returns: - A dict of the following: - - estimator: The estimator that will be used for training and eval. - - train_spec: Spec for training. - - eval_spec: Spec for eval. - - eval_input_receiver_fn: Input function for eval. + fn_args: Holds args used to train the model as name/value pairs. 
""" # Number of nodes in the first layer of the DNN first_dnn_layer_size = 100 num_dnn_layers = 4 dnn_decay_factor = 0.7 - train_batch_size = 40 - eval_batch_size = 40 - - tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output) - - train_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda - trainer_fn_args.train_files, - trainer_fn_args.data_accessor, - tf_transform_output, - batch_size=train_batch_size) - - eval_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda - trainer_fn_args.eval_files, - trainer_fn_args.data_accessor, - tf_transform_output, - batch_size=eval_batch_size) - - train_spec = tf_estimator.TrainSpec( - train_input_fn, max_steps=trainer_fn_args.train_steps) - - serving_receiver_fn = ( - lambda: _example_serving_receiver_fn(tf_transform_output, schema)) - - exporter = tf_estimator.FinalExporter('chicago-taxi', serving_receiver_fn) - eval_spec = tf_estimator.EvalSpec( - eval_input_fn, - steps=trainer_fn_args.eval_steps, - exporters=[exporter], - name='chicago-taxi-eval') - - run_config = tf_estimator.RunConfig( - save_checkpoints_steps=999, keep_checkpoint_max=1) - - run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir) - - estimator = _build_estimator( - # Construct layers sizes with exponetial decay - hidden_units=[ - max(2, int(first_dnn_layer_size * dnn_decay_factor**i)) - for i in range(num_dnn_layers) - ], - config=run_config, - warm_start_from=trainer_fn_args.base_model) - - # Create an input receiver for TFMA processing - receiver_fn = lambda: _eval_input_receiver_fn(tf_transform_output, schema) - - return { - 'estimator': estimator, - 'train_spec': train_spec, - 'eval_spec': eval_spec, - 'eval_input_receiver_fn': receiver_fn + tf_transform_output = tft.TFTransformOutput(fn_args.transform_graph_path) + + train_dataset = _input_fn( + fn_args.train_files, fn_args.data_accessor, tf_transform_output, 40 + ) + eval_dataset = _input_fn( + fn_args.eval_files, fn_args.data_accessor, tf_transform_output, 40 + ) + + mirrored_strategy = tf.distribute.MirroredStrategy() + with mirrored_strategy.scope(): + model = _build_keras_model( + # Construct layers sizes with exponetial decay + hidden_units=[ + max(2, int(first_dnn_layer_size * dnn_decay_factor**i)) + for i in range(num_dnn_layers) + ] + ) + + # Write logs to path + tensorboard_callback = tf.keras.callbacks.TensorBoard( + log_dir=fn_args.model_run_dir, update_freq='epoch' + ) + + model.fit( + train_dataset, + steps_per_epoch=fn_args.train_steps, + validation_data=eval_dataset, + validation_steps=fn_args.eval_steps, + callbacks=[tensorboard_callback], + ) + + signatures = { + 'serving_default': _get_tf_examples_serving_signature( + model, tf_transform_output + ), + 'transform_features': _get_transform_features_signature( + model, tf_transform_output + ), } + model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures) diff --git a/tfx/examples/mnist/mnist_pipeline_native_keras.py b/tfx/examples/mnist/mnist_pipeline_native_keras.py index 78ba19f82e..d584cab3b6 100644 --- a/tfx/examples/mnist/mnist_pipeline_native_keras.py +++ b/tfx/examples/mnist/mnist_pipeline_native_keras.py @@ -41,14 +41,10 @@ # Python module files to inject customized logic into the TFX components. The # Transform and Trainer both require user-defined functions to run successfully. 
_module_file = os.path.join(_mnist_root, 'mnist_utils_native_keras.py') -_module_file_lite = os.path.join( - _mnist_root, 'mnist_utils_native_keras_lite.py') # Path which can be listened to by the model server. Pusher will output the # trained model here. _serving_model_dir = os.path.join(_mnist_root, 'serving_model', _pipeline_name) -_serving_model_dir_lite = os.path.join( - _mnist_root, 'serving_model_lite', _pipeline_name) # Directory and data locations. This example assumes all of the images, # example code, and metadata library is relative to $HOME, but you can store @@ -69,8 +65,8 @@ def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str, - module_file: str, module_file_lite: str, - serving_model_dir: str, serving_model_dir_lite: str, + module_file: str, + serving_model_dir: str, metadata_path: str, beam_pipeline_args: List[str], accuracy_threshold: float = 0.8) -> pipeline.Pipeline: @@ -108,9 +104,6 @@ def _create_trainer(module_file, component_id): # Uses user-provided Python function that trains a Keras model. trainer = _create_trainer(module_file, 'Trainer.mnist') - # Trains the same model as the one above, but converts it into a TFLite one. - trainer_lite = _create_trainer(module_file_lite, 'Trainer.mnist_lite') - # TODO(b/150949276): Add resolver back once it supports two trainers. # Uses TFMA to compute evaluation statistics over features of a model and @@ -128,24 +121,12 @@ def _create_trainer(module_file, component_id): ]) ]) - eval_config_lite = tfma.EvalConfig() - eval_config_lite.CopyFrom(eval_config) - # Informs the evaluator that the model is a TFLite model. - eval_config_lite.model_specs[0].model_type = 'tf_lite' - # Uses TFMA to compute the evaluation statistics over features of a model. evaluator = Evaluator( examples=example_gen.outputs['examples'], model=trainer.outputs['model'], eval_config=eval_config).with_id('Evaluator.mnist') - # Uses TFMA to compute the evaluation statistics over features of a TFLite - # model. - evaluator_lite = Evaluator( - examples=example_gen.outputs['examples'], - model=trainer_lite.outputs['model'], - eval_config=eval_config_lite).with_id('Evaluator.mnist_lite') - # Checks whether the model passed the validation steps and pushes the model # to a file destination if check passed. pusher = Pusher( @@ -155,16 +136,6 @@ def _create_trainer(module_file, component_id): filesystem=pusher_pb2.PushDestination.Filesystem( base_directory=serving_model_dir))).with_id('Pusher.mnist') - # Checks whether the TFLite model passed the validation steps and pushes the - # model to a file destination if check passed. 
- pusher_lite = Pusher( - model=trainer_lite.outputs['model'], - model_blessing=evaluator_lite.outputs['blessing'], - push_destination=pusher_pb2.PushDestination( - filesystem=pusher_pb2.PushDestination.Filesystem( - base_directory=serving_model_dir_lite))).with_id( - 'Pusher.mnist_lite') - return pipeline.Pipeline( pipeline_name=pipeline_name, pipeline_root=pipeline_root, @@ -175,11 +146,8 @@ def _create_trainer(module_file, component_id): example_validator, transform, trainer, - trainer_lite, evaluator, - evaluator_lite, pusher, - pusher_lite, ], enable_cache=True, metadata_connection_config=metadata.sqlite_metadata_connection_config( @@ -197,8 +165,6 @@ def _create_trainer(module_file, component_id): pipeline_root=_pipeline_root, data_root=_data_root, module_file=_module_file, - module_file_lite=_module_file_lite, serving_model_dir=_serving_model_dir, - serving_model_dir_lite=_serving_model_dir_lite, metadata_path=_metadata_path, beam_pipeline_args=_beam_pipeline_args)) diff --git a/tfx/examples/mnist/mnist_pipeline_native_keras_e2e_test.py b/tfx/examples/mnist/mnist_pipeline_native_keras_e2e_test.py index 4f97725896..3edb7fd957 100644 --- a/tfx/examples/mnist/mnist_pipeline_native_keras_e2e_test.py +++ b/tfx/examples/mnist/mnist_pipeline_native_keras_e2e_test.py @@ -38,11 +38,7 @@ def setUp(self): self._data_root = os.path.join(os.path.dirname(__file__), 'data') self._module_file = os.path.join( os.path.dirname(__file__), 'mnist_utils_native_keras.py') - self._module_file_lite = os.path.join( - os.path.dirname(__file__), 'mnist_utils_native_keras_lite.py') self._serving_model_dir = os.path.join(self._test_dir, 'serving_model') - self._serving_model_dir_lite = os.path.join( - self._test_dir, 'serving_model_lite') self._pipeline_root = os.path.join(self._test_dir, 'tfx', 'pipelines', self._pipeline_name) self._metadata_path = os.path.join(self._test_dir, 'tfx', 'metadata', @@ -73,14 +69,11 @@ def assertExecutedOnce(self, component: str) -> None: def assertPipelineExecution(self) -> None: self.assertExecutedOnce('ImportExampleGen') self.assertExecutedOnce('Evaluator.mnist') - self.assertExecutedOnce('Evaluator.mnist_lite') self.assertExecutedOnce('ExampleValidator') self.assertExecutedOnce('Pusher.mnist') - self.assertExecutedOnce('Pusher.mnist_lite') self.assertExecutedOnce('SchemaGen') self.assertExecutedOnce('StatisticsGen') self.assertExecutedOnce('Trainer.mnist') - self.assertExecutedOnce('Trainer.mnist_lite') self.assertExecutedOnce('Transform') def testMNISTPipelineNativeKeras(self): @@ -91,20 +84,17 @@ def testMNISTPipelineNativeKeras(self): pipeline_name=self._pipeline_name, data_root=self._data_root, module_file=self._module_file, - module_file_lite=self._module_file_lite, serving_model_dir=self._serving_model_dir, - serving_model_dir_lite=self._serving_model_dir_lite, pipeline_root=self._pipeline_root, metadata_path=self._metadata_path, beam_pipeline_args=[], accuracy_threshold=0.5)) # Use a low value to make test stable. 
self.assertTrue(fileio.exists(self._serving_model_dir))
-    self.assertTrue(fileio.exists(self._serving_model_dir_lite))
     self.assertTrue(fileio.exists(self._metadata_path))
 
     metadata_config = metadata.sqlite_metadata_connection_config(
         self._metadata_path)
-    expected_execution_count = 11
+    expected_execution_count = 8
     with metadata.Metadata(metadata_config) as m:
       artifact_count = len(m.store.get_artifacts())
       execution_count = len(m.store.get_executions())
@@ -119,9 +109,7 @@ def testMNISTPipelineNativeKeras(self):
             pipeline_name=self._pipeline_name,
             data_root=self._data_root,
             module_file=self._module_file,
-            module_file_lite=self._module_file_lite,
             serving_model_dir=self._serving_model_dir,
-            serving_model_dir_lite=self._serving_model_dir_lite,
             pipeline_root=self._pipeline_root,
             metadata_path=self._metadata_path,
             beam_pipeline_args=[],
diff --git a/tfx/examples/mnist/mnist_utils_native_keras_base.py b/tfx/examples/mnist/mnist_utils_native_keras_base.py
index ce44c9e0d0..d580a1b10f 100644
--- a/tfx/examples/mnist/mnist_utils_native_keras_base.py
+++ b/tfx/examples/mnist/mnist_utils_native_keras_base.py
@@ -13,8 +13,7 @@
 # limitations under the License.
 """Base Python source file for MNIST utils.
 
-This file is used by both mnist_utils_native_keras and
-mnist_util_native_keras_lite to build Keras and TFLite models, respectively.
+This file is used by mnist_utils_native_keras to build Keras models.
 """
 
 from typing import List
diff --git a/tfx/examples/mnist/mnist_utils_native_keras_lite.py b/tfx/examples/mnist/mnist_utils_native_keras_lite.py
deleted file mode 100644
index 9734cf4226..0000000000
--- a/tfx/examples/mnist/mnist_utils_native_keras_lite.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# Copyright 2020 Google LLC. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Python source file includes MNIST utils for TFLite model.
-
-The utilities in this file are used to build a TFLite model.
-This module file will be used in Transform and generic Trainer.
-"""
-
-import os
-
-import tensorflow as tf
-import tensorflow_transform as tft
-
-from tfx import v1 as tfx
-from tfx.components.trainer.rewriting import converters
-from tfx.components.trainer.rewriting import rewriter
-from tfx.components.trainer.rewriting import rewriter_factory
-from tfx.examples.mnist import mnist_utils_native_keras_base as base
-
-
-def _get_serve_tf_examples_fn(model, tf_transform_output):
-  """Returns a function that feeds the input tensor into the model."""
-
-  model.tft_layer = tf_transform_output.transform_features_layer()
-
-  @tf.function
-  def serve_tf_examples_fn(image_tensor):
-    """Returns the output to be used in the serving signature."""
-    transformed_features = model.tft_layer({base.IMAGE_KEY: image_tensor})
-    return model(transformed_features)
-
-  return serve_tf_examples_fn
-
-
-# TFX Transform will call this function.
-def preprocessing_fn(inputs):
-  """tf.transform's callback function for preprocessing inputs.
-
-  Args:
-    inputs: map from feature keys to raw not-yet-transformed features.
- - Returns: - Map from string feature key to transformed feature operations. - """ - return base.preprocessing_fn(inputs) - - -# TFX Trainer will call this function. -def run_fn(fn_args: tfx.components.FnArgs): - """Train the model based on given args. - - Args: - fn_args: Holds args used to train the model as name/value pairs. - """ - tf_transform_output = tft.TFTransformOutput(fn_args.transform_output) - - train_dataset = base.input_fn(fn_args.train_files, fn_args.data_accessor, - tf_transform_output, 40) - eval_dataset = base.input_fn(fn_args.eval_files, fn_args.data_accessor, - tf_transform_output, 40) - - mirrored_strategy = tf.distribute.MirroredStrategy() - with mirrored_strategy.scope(): - model = base.build_keras_model() - - # Write logs to path - tensorboard_callback = tf.keras.callbacks.TensorBoard( - log_dir=fn_args.model_run_dir, update_freq='epoch') - - model.fit( - train_dataset, - steps_per_epoch=fn_args.train_steps, - validation_data=eval_dataset, - validation_steps=fn_args.eval_steps, - callbacks=[tensorboard_callback]) - - signatures = { - 'serving_default': - _get_serve_tf_examples_fn( - model, tf_transform_output).get_concrete_function( - tf.TensorSpec( - shape=[None, 784], - dtype=tf.float32, - name='image_floats')) - } - temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp') - model.save(temp_saving_model_dir, save_format='tf', signatures=signatures) - - tfrw = rewriter_factory.create_rewriter( - rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter') - converters.rewrite_saved_model(temp_saving_model_dir, - fn_args.serving_model_dir, - tfrw, - rewriter.ModelType.TFLITE_MODEL) - - tfx.dsl.io.fileio.rmtree(temp_saving_model_dir) diff --git a/tfx/examples/tfjs_next_page_prediction/README.md b/tfx/examples/tfjs_next_page_prediction/README.md index 08f9d8a2b2..ed94ebf2be 100644 --- a/tfx/examples/tfjs_next_page_prediction/README.md +++ b/tfx/examples/tfjs_next_page_prediction/README.md @@ -5,10 +5,6 @@ This example demonstrates: * How Apache Beam can be used to convert Google Analytics events into data used for training (see `bigquery_beam_data_generation.py`). - * How to construct a TFX pipeline that trains a TFJS - model for predicting the next page the user will - visit (see `tfjs_next_page_prediction_pipeline.py` - which shows how to setup such a pipeline). diff --git a/tfx/examples/tfjs_next_page_prediction/tfjs_next_page_prediction_e2e_test.py b/tfx/examples/tfjs_next_page_prediction/tfjs_next_page_prediction_e2e_test.py deleted file mode 100644 index d55dc19015..0000000000 --- a/tfx/examples/tfjs_next_page_prediction/tfjs_next_page_prediction_e2e_test.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2021 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""E2E Tests for tfx.examples.tfjs_next_page_prediction.tfjs_next_page_prediction_pipeline.""" - -import os -import unittest - -import tensorflow as tf - -from tfx.dsl.io import fileio -from tfx.examples.tfjs_next_page_prediction import tfjs_next_page_prediction_pipeline -from tfx.orchestration import metadata -from tfx.orchestration.local.local_dag_runner import LocalDagRunner - -try: - import tensorflowjs # pylint: disable=g-import-not-at-top -except ImportError: - tensorflowjs = None - -import pytest - - -@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. " -"If all tests pass, please remove this mark.") -@pytest.mark.e2e -@unittest.skipIf(tensorflowjs is None, - 'Cannot import required modules. This can happen when' - ' tensorflowjs is not available.') -class TFJSNextPagePredictionPipelineEndToEndTest(tf.test.TestCase): - - def setUp(self): - super().setUp() - self._test_dir = os.path.join( - os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()), - self._testMethodName) - - self._pipeline_name = 'page_prediction_test' - self._data_root = os.path.join(os.path.dirname(__file__), 'data') - self._module_file = os.path.join( - os.path.dirname(__file__), 'tfjs_next_page_prediction_util.py') - self._serving_model_dir = os.path.join(self._test_dir, 'serving_model') - self._pipeline_root = os.path.join(self._test_dir, 'tfx', 'pipelines', - self._pipeline_name) - self._metadata_path = os.path.join(self._test_dir, 'tfx', 'metadata', - self._pipeline_name, 'metadata.db') - - def assertExecutedOnce(self, component: str) -> None: - """Check the component is executed exactly once.""" - component_path = os.path.join(self._pipeline_root, component) - self.assertTrue(fileio.exists(component_path)) - outputs = fileio.listdir(component_path) - - self.assertIn('.system', outputs) - outputs.remove('.system') - system_paths = [ - os.path.join('.system', path) - for path in fileio.listdir(os.path.join(component_path, '.system')) - ] - self.assertNotEmpty(system_paths) - self.assertIn('.system/executor_execution', system_paths) - outputs.extend(system_paths) - for output in outputs: - execution = fileio.listdir(os.path.join(component_path, output)) - self.assertLen(execution, 1) - - def assertPipelineExecution(self) -> None: - self.assertExecutedOnce('ImportExampleGen') - self.assertExecutedOnce('Evaluator') - self.assertExecutedOnce('ExampleValidator') - self.assertExecutedOnce('Pusher') - self.assertExecutedOnce('SchemaGen') - self.assertExecutedOnce('StatisticsGen') - self.assertExecutedOnce('Trainer') - self.assertExecutedOnce('Transform') - - def testTFJSPagePredictionPipeline(self): - if not tf.executing_eagerly(): - self.skipTest('The test requires TF2.') - pipeline = tfjs_next_page_prediction_pipeline._create_pipeline( - pipeline_name=self._pipeline_name, - data_root=self._data_root, - module_file=self._module_file, - serving_model_dir=self._serving_model_dir, - pipeline_root=self._pipeline_root, - metadata_path=self._metadata_path, - beam_pipeline_args=[]) - - LocalDagRunner().run(pipeline) - - self.assertTrue(fileio.exists(self._serving_model_dir)) - self.assertTrue(fileio.exists(self._metadata_path)) - expected_execution_count = 9 # 8 components + 1 resolver - metadata_config = metadata.sqlite_metadata_connection_config( - self._metadata_path) - with metadata.Metadata(metadata_config) as m: - artifact_count = len(m.store.get_artifacts()) - execution_count = len(m.store.get_executions()) - self.assertGreaterEqual(artifact_count, 
execution_count) - self.assertEqual(expected_execution_count, execution_count) - - self.assertPipelineExecution() diff --git a/tfx/examples/tfjs_next_page_prediction/tfjs_next_page_prediction_pipeline.py b/tfx/examples/tfjs_next_page_prediction/tfjs_next_page_prediction_pipeline.py deleted file mode 100644 index dab2a97c41..0000000000 --- a/tfx/examples/tfjs_next_page_prediction/tfjs_next_page_prediction_pipeline.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright 2021 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TFX/TFJS Page Prediction Pipeline.""" - -import os -from typing import List - -import absl -import tensorflow_model_analysis as tfma -from tfx.v1 import dsl -from tfx.v1 import orchestration -from tfx.v1 import proto -from tfx.v1 import types -from tfx.v1.components import Evaluator -from tfx.v1.components import ExampleValidator -from tfx.v1.components import ImportExampleGen -from tfx.v1.components import Pusher -from tfx.v1.components import SchemaGen -from tfx.v1.components import StatisticsGen -from tfx.v1.components import Trainer -from tfx.v1.components import Transform - - -_pipeline_name = 'tfx_tfjs_page_prediction' - -# This example assumes that train set data is stored in -# ~/tfx_tfjs_page_prediction/data/. Feel free to customize and use -# google cloud storage paths if needed. -_page_prediction_root = os.path.join(os.environ['HOME'], - 'tfx_tfjs_page_prediction') -_data_root = os.path.join(_page_prediction_root, 'data') - -# Python module file to inject customized logic into the TFX components. The -# Transform and Trainer both require user-defined functions to run successfully. -_module_file = os.path.join(_page_prediction_root, - 'tfjs_next_page_prediction_util.py') -# Path which can be listened to by the model server. Pusher will output the -# trained model here. -_serving_model_dir = os.path.join(_page_prediction_root, 'serving_model', - _pipeline_name) - -# Directory and data locations. This example assumes all of the -# example code and metadata library is relative to $HOME, but you can store -# these files anywhere on your local filesystem. -_tfx_root = os.path.join(os.environ['HOME'], 'tfx') -_pipeline_root = os.path.join(_tfx_root, 'pipelines', _pipeline_name) -# Sqlite ML-metadata db path. -_metadata_path = os.path.join( - os.getenv('HOME'), 'metadata', _pipeline_name, 'metadata.db') - -# Pipeline arguments for Beam powered Components. -_beam_pipeline_args = [ - '--direct_running_mode=multi_processing', - # 0 means auto-detect based on on the number of CPUs available - # during execution time. 
- '--direct_num_workers=0', -] - - -def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str, - module_file: str, serving_model_dir: str, - metadata_path: str, - beam_pipeline_args: List[str]) -> dsl.Pipeline: - """Implements the page prediction pipline with TFX.""" - input_config = proto.Input( - splits=[proto.Input.Split(name='input', pattern='*.tfrecord.gz')]) - output_config = proto.Output( - split_config=proto.SplitConfig(splits=[ - proto.SplitConfig.Split(name='train', hash_buckets=9), - proto.SplitConfig.Split(name='eval', hash_buckets=1) - ])) - - # Brings data in to the pipline - example_gen = ImportExampleGen( - input_base=data_root, - input_config=input_config, - output_config=output_config) - - # Computes statistics over data for visualization and example validation. - statistics_gen = StatisticsGen( - examples=example_gen.outputs['examples']) - - # Generates schema based on statistics files. - schema_gen = SchemaGen( - statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True) - - # Performs anomaly detection based on statistics and data schema. - example_validator = ExampleValidator( - statistics=statistics_gen.outputs['statistics'], - schema=schema_gen.outputs['schema']) - - # Performs transformations and feature engineering in training and serving. - transform = Transform( - examples=example_gen.outputs['examples'], - schema=schema_gen.outputs['schema'], - module_file=module_file) - - # Uses user-provided Python function that trains a model. - trainer = Trainer( - module_file=module_file, - examples=transform.outputs['transformed_examples'], - transform_graph=transform.outputs['transform_graph'], - schema=schema_gen.outputs['schema'], - train_args=proto.TrainArgs(num_steps=100000), - eval_args=proto.EvalArgs(num_steps=200)) - - # Get the latest blessed model for model validation. - model_resolver = dsl.Resolver( - strategy_class=dsl.experimental.LatestBlessedModelStrategy, - model=dsl.Channel(type=types.standard_artifacts.Model), - model_blessing=dsl.Channel( - type=types.standard_artifacts.ModelBlessing)).with_id( - 'latest_blessed_model_resolver') - - # Uses TFMA to compute evaluation statistics over features of a model and - # perform quality validation of a candidate model (compared to a baseline). - eval_config = tfma.EvalConfig( - # Directly evaluates the tfjs model. - model_specs=[tfma.ModelSpec(label_key='label', model_type='tf_js')], - slicing_specs=[tfma.SlicingSpec()], - metrics_specs=[ - tfma.MetricsSpec(metrics=[ - tfma.MetricConfig( - class_name='SparseCategoricalAccuracy', - threshold=tfma.MetricThreshold( - value_threshold=tfma.GenericValueThreshold( - # Increase this threshold when training on complete - # dataset. - lower_bound={'value': 0.01}), - # Change threshold will be ignored if there is no - # baseline model resolved from MLMD (first run). - change_threshold=tfma.GenericChangeThreshold( - direction=tfma.MetricDirection.HIGHER_IS_BETTER, - absolute={'value': -1e-2}))) - ]) - ]) - - evaluator = Evaluator( - examples=transform.outputs['transformed_examples'], - model=trainer.outputs['model'], - baseline_model=model_resolver.outputs['model'], - eval_config=eval_config) - - # Checks whether the model passed the validation steps and pushes the model - # to a file destination if check passed. 
- pusher = Pusher( - model=trainer.outputs['model'], - model_blessing=evaluator.outputs['blessing'], - push_destination=proto.PushDestination( - filesystem=proto.PushDestination.Filesystem( - base_directory=serving_model_dir))) - - components = [ - example_gen, - statistics_gen, - schema_gen, - example_validator, - transform, - trainer, - model_resolver, - evaluator, - pusher, - ] - return dsl.Pipeline( - pipeline_name=pipeline_name, - pipeline_root=pipeline_root, - components=components, - metadata_connection_config=orchestration.metadata - .sqlite_metadata_connection_config(metadata_path), - enable_cache=True, - beam_pipeline_args=beam_pipeline_args) - - -# To run this pipeline from the python CLI: -# $python imdb_pipeline_native_keras.py -if __name__ == '__main__': - absl.logging.set_verbosity(absl.logging.INFO) - orchestration.LocalDagRunner().run( - _create_pipeline( - pipeline_name=_pipeline_name, - pipeline_root=_pipeline_root, - data_root=_data_root, - module_file=_module_file, - serving_model_dir=_serving_model_dir, - metadata_path=_metadata_path, - beam_pipeline_args=_beam_pipeline_args)) diff --git a/tfx/examples/tfjs_next_page_prediction/tfjs_next_page_prediction_util.py b/tfx/examples/tfjs_next_page_prediction/tfjs_next_page_prediction_util.py deleted file mode 100644 index 7b8bbe919e..0000000000 --- a/tfx/examples/tfjs_next_page_prediction/tfjs_next_page_prediction_util.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2021 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Python source file includes pipeline functions and necessary utils.""" - -import os -from typing import List - -import absl -import tensorflow as tf -from tensorflow import keras -import tensorflow_transform as tft - -from tfx.components.trainer.rewriting import converters -from tfx.components.trainer.rewriting import rewriter -from tfx.components.trainer.rewriting import rewriter_factory -from tfx.dsl.io import fileio - -from tfx import v1 as tfx # pylint: disable=g-bad-import-order - -from tfx_bsl.public import tfxio - -_CUR_PAGE_FEATURE_KEY = 'cur_page' -_SESSION_INDEX_FEATURE_KEY = 'session_index' -_LABEL_KEY = 'label' -_VOCAB_FILENAME = 'vocab' - -_TOP_K = 100 -_EMBEDDING_DIM = 10 -_UNITS = 50 - -_TRAIN_BATCH_SIZE = 32 -_EVAL_BATCH_SIZE = 16 - - -# TFX Transform will call this function. -def preprocessing_fn(inputs): - """Callback function for preprocessing inputs. - - Args: - inputs: map from feature keys to raw not-yet-transformed features. - - Returns: - Map from string feature key to transformed feature operations. - """ - outputs = inputs.copy() - - # Compute a vocabulary based on the TOP-K current pages and labels seen in - # the dataset. - vocab = tft.vocabulary( - tf.concat([inputs[_CUR_PAGE_FEATURE_KEY], inputs[_LABEL_KEY]], axis=0), - top_k=_TOP_K, - vocab_filename=_VOCAB_FILENAME) - - # Apply the vocabulary to both the current page feature and the label, - # converting the strings into integers. 
- for k in [_CUR_PAGE_FEATURE_KEY, _LABEL_KEY]: - # Out-of-vocab strings will be assigned the _TOP_K value. - outputs[k] = tft.apply_vocabulary(inputs[k], vocab, default_value=_TOP_K) - return outputs - - -def _input_fn(file_pattern: List[str], - data_accessor: tfx.components.DataAccessor, - tf_transform_output: tft.TFTransformOutput, - batch_size: int = 200) -> tf.data.Dataset: - """Generates features and label for tuning/training. - - Args: - file_pattern: List of paths or patterns of input tfrecord files. - data_accessor: DataAccessor for converting input to RecordBatch. - tf_transform_output: A TFTransformOutput. - batch_size: representing the number of consecutive elements of returned - dataset to combine in a single batch. - - Returns: - A dataset that contains (features, indices) tuple where features is a - dictionary of Tensors, and indices is a single Tensor of label indices. - """ - dataset = data_accessor.tf_dataset_factory( - file_pattern, - tfxio.TensorFlowDatasetOptions( - batch_size=batch_size, label_key=_LABEL_KEY), - tf_transform_output.transformed_metadata.schema) - - return dataset.repeat() - - -def _build_keras_model() -> keras.Model: - """Creates a Keras model for predicting the next page. - - Returns: - A Keras Model. - """ - # This model has two inputs: (i) current page and (ii) session index. - cur_page_input = keras.Input(shape=(), name=_CUR_PAGE_FEATURE_KEY) - session_index_input = keras.Input(shape=(1,), name=_SESSION_INDEX_FEATURE_KEY) - inputs = [cur_page_input, session_index_input] - - # Create an embedding for the current page. - cur_page_emb = keras.layers.Embedding( - _TOP_K + 1, _EMBEDDING_DIM, input_length=1)( - cur_page_input) - x = keras.layers.Concatenate()([cur_page_emb, session_index_input]) - x = keras.layers.Dense(_UNITS, activation='relu')(x) - outputs = keras.layers.Dense(_TOP_K + 1)(x) - model = keras.Model(inputs=inputs, outputs=outputs) - model.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(0.0001), - metrics=[ - 'sparse_categorical_accuracy', 'sparse_top_k_categorical_accuracy' - ]) - - model.summary(print_fn=absl.logging.info) - return model - - -# The inference function assumes that the mapping from string to integer for -# the current page has been done outside of the model. We store the vocabulary -# file with the output tfjs model to simplify this process. -def _get_inference_fn(model, tf_transform_output): - """Defines the function used for inference.""" - model.tft_layer = tf_transform_output.transform_features_layer() - - @tf.function - def inference_fn(cur_page, session_index): - """Returns the output to be used in the serving signature.""" - return model({ - _CUR_PAGE_FEATURE_KEY: cur_page, - _SESSION_INDEX_FEATURE_KEY: session_index - }) - - return inference_fn - - -# TFX Trainer will call this function. -def run_fn(fn_args: tfx.components.FnArgs): - """Train the model based on given args. - - Args: - fn_args: Holds args used to train the model as name/value pairs. 
- """ - tf_transform_output = tft.TFTransformOutput(fn_args.transform_output) - - train_dataset = _input_fn( - fn_args.train_files, - fn_args.data_accessor, - tf_transform_output, - batch_size=_TRAIN_BATCH_SIZE) - - eval_dataset = _input_fn( - fn_args.eval_files, - fn_args.data_accessor, - tf_transform_output, - batch_size=_EVAL_BATCH_SIZE) - - mirrored_strategy = tf.distribute.MirroredStrategy() - with mirrored_strategy.scope(): - model = _build_keras_model() - - model.fit( - train_dataset, - steps_per_epoch=fn_args.train_steps, - validation_data=eval_dataset, - validation_steps=fn_args.eval_steps, - verbose=2) - - signatures = { - 'serving_default': - _get_inference_fn(model, tf_transform_output).get_concrete_function( - tf.TensorSpec( - shape=[None], dtype=tf.int64, name=_CUR_PAGE_FEATURE_KEY), - tf.TensorSpec( - shape=[None], dtype=tf.int64, - name=_SESSION_INDEX_FEATURE_KEY)), - } - - # Create the saved_model in a temporary directory. - temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp') - model.save(temp_saving_model_dir, save_format='tf', signatures=signatures) - - # Convert the saved_model to a tfjs model and store it in the final directory. - tfrw = rewriter_factory.create_rewriter( - rewriter_factory.TFJS_REWRITER, name='tfjs_rewriter') - converters.rewrite_saved_model(temp_saving_model_dir, - fn_args.serving_model_dir, tfrw, - rewriter.ModelType.TFJS_MODEL) - - # Copy the vocabulary computed by transform to the final directory. - # The vocabulary is not included in the original savedmodel because vocab - # lookups are currently not supported in TFJS and are expected to be done - # independently by client code. - fileio.copy( - tf_transform_output.vocabulary_file_by_name(_VOCAB_FILENAME), - os.path.join(fn_args.serving_model_dir, _VOCAB_FILENAME)) - - fileio.rmtree(temp_saving_model_dir) diff --git a/tfx/experimental/pipeline_testing/examples/chicago_taxi_pipeline/taxi_pipeline_regression_e2e_test.py b/tfx/experimental/pipeline_testing/examples/chicago_taxi_pipeline/taxi_pipeline_regression_e2e_test.py deleted file mode 100644 index a2be633d54..0000000000 --- a/tfx/experimental/pipeline_testing/examples/chicago_taxi_pipeline/taxi_pipeline_regression_e2e_test.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright 2020 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""E2E Tests for taxi pipeline beam with stub executors.""" - -import os - -from absl import logging -import tensorflow as tf -from tfx.dsl.compiler import compiler -from tfx.dsl.io import fileio -from tfx.examples.chicago_taxi_pipeline import taxi_pipeline_local -from tfx.experimental.pipeline_testing import executor_verifier_utils -from tfx.experimental.pipeline_testing import pipeline_mock -from tfx.experimental.pipeline_testing import pipeline_recorder_utils -from tfx.orchestration import metadata -from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner - -from ml_metadata.proto import metadata_store_pb2 - -import pytest - - -@pytest.mark.xfail(run=False, reason="PR 6889 This class contains tests that fail and needs to be fixed. " -"If all tests pass, please remove this mark.") -@pytest.mark.e2e -class TaxiPipelineRegressionEndToEndTest(tf.test.TestCase): - - def setUp(self): - super().setUp() - self._test_dir = os.path.join( - os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()), - self._testMethodName) - self._pipeline_name = 'beam_stub_test' - # This example assumes that the taxi data and taxi utility function are - # stored in tfx/examples/chicago_taxi_pipeline. Feel free to customize this - # as needed. - taxi_root = os.path.dirname(taxi_pipeline_local.__file__) - self._data_root = os.path.join(taxi_root, 'data', 'simple') - self._module_file = os.path.join(taxi_root, 'taxi_utils.py') - self._serving_model_dir = os.path.join(self._test_dir, 'serving_model') - self._pipeline_root = os.path.join(self._test_dir, 'tfx', 'pipelines', - self._pipeline_name) - # Metadata path for recording successful pipeline run. - self._recorded_mlmd_path = os.path.join(self._test_dir, 'tfx', 'record', - 'metadata.db') - # Metadata path for stub pipeline runs. 
- self._metadata_path = os.path.join(self._test_dir, 'tfx', 'metadata', - self._pipeline_name, 'metadata.db') - self._recorded_output_dir = os.path.join(self._test_dir, 'testdata') - - # Runs the pipeline and record to self._recorded_output_dir - record_taxi_pipeline = taxi_pipeline_local._create_pipeline( # pylint:disable=protected-access - pipeline_name=self._pipeline_name, - data_root=self._data_root, - module_file=self._module_file, - serving_model_dir=self._serving_model_dir, - pipeline_root=self._pipeline_root, - metadata_path=self._recorded_mlmd_path, - beam_pipeline_args=[]) - - BeamDagRunner().run(record_taxi_pipeline) - - pipeline_recorder_utils.record_pipeline( - output_dir=self._recorded_output_dir, - metadata_db_uri=self._recorded_mlmd_path, - pipeline_name=self._pipeline_name) - - self.taxi_pipeline = taxi_pipeline_local._create_pipeline( # pylint:disable=protected-access - pipeline_name=self._pipeline_name, - data_root=self._data_root, - module_file=self._module_file, - serving_model_dir=self._serving_model_dir, - pipeline_root=self._pipeline_root, - metadata_path=self._metadata_path, - beam_pipeline_args=[]) - - def assertDirectoryEqual(self, dir1: str, dir2: str): - self.assertTrue(executor_verifier_utils.compare_dirs(dir1, dir2)) - - def _verify_file_path(self, output_uri: str, artifact_uri: str): - self.assertTrue( - executor_verifier_utils.verify_file_dir(output_uri, artifact_uri)) - - def _veryify_root_dir(self, output_uri: str, unused_artifact_uri: str): - self.assertTrue(fileio.exists(output_uri)) - - def _verify_evaluation(self, output_uri: str, expected_uri: str): - self.assertTrue( - executor_verifier_utils.compare_eval_results(output_uri, expected_uri, - 1.0, ['accuracy'])) - - def _verify_schema(self, output_uri: str, expected_uri: str): - self.assertTrue( - executor_verifier_utils.compare_file_sizes(output_uri, expected_uri, - .5)) - - def _verify_examples(self, output_uri: str, expected_uri: str): - self.assertTrue( - executor_verifier_utils.compare_file_sizes(output_uri, expected_uri, - .5)) - - def _verify_model(self, output_uri: str, expected_uri: str): - self.assertTrue( - executor_verifier_utils.compare_model_file_sizes( - output_uri, expected_uri, .5)) - - def _verify_anomalies(self, output_uri: str, expected_uri: str): - self.assertTrue( - executor_verifier_utils.compare_anomalies(output_uri, expected_uri)) - - def testStubbedTaxiPipelineBeam(self): - pipeline_ir = compiler.Compiler().compile(self.taxi_pipeline) - - logging.info('Replacing with test_data_dir:%s', self._recorded_output_dir) - pipeline_mock.replace_executor_with_stub(pipeline_ir, - self._recorded_output_dir, []) - - BeamDagRunner().run_with_ir(pipeline_ir) - - self.assertTrue(fileio.exists(self._metadata_path)) - - metadata_config = metadata.sqlite_metadata_connection_config( - self._metadata_path) - - # Verify that recorded files are successfully copied to the output uris. - with metadata.Metadata(metadata_config) as m: - artifacts = m.store.get_artifacts() - artifact_count = len(artifacts) - executions = m.store.get_executions() - execution_count = len(executions) - # Artifact count is greater by 7 due to extra artifacts produced by - # Evaluator(blessing and evaluation), Trainer(model and model_run) and - # Transform(example, graph, cache, pre_transform_statistics, - # pre_transform_schema, post_transform_statistics, post_transform_schema, - # post_transform_anomalies) minus Resolver which doesn't generate - # new artifact. 
- self.assertEqual(artifact_count, execution_count + 7) - self.assertLen(self.taxi_pipeline.components, execution_count) - - for execution in executions: - component_id = pipeline_recorder_utils.get_component_id_from_execution( - m, execution) - if component_id.startswith('Resolver'): - continue - eid = [execution.id] - events = m.store.get_events_by_execution_ids(eid) - output_events = [ - x for x in events if x.type == metadata_store_pb2.Event.OUTPUT - ] - for event in output_events: - steps = event.path.steps - self.assertTrue(steps[0].HasField('key')) - name = steps[0].key - artifacts = m.store.get_artifacts_by_id([event.artifact_id]) - for idx, artifact in enumerate(artifacts): - self.assertDirectoryEqual( - artifact.uri, - os.path.join(self._recorded_output_dir, component_id, name, - str(idx))) - - # Calls verifier for pipeline output artifacts, excluding the resolver node. - BeamDagRunner().run(self.taxi_pipeline) - pipeline_outputs = executor_verifier_utils.get_pipeline_outputs( - self.taxi_pipeline.metadata_connection_config, self._pipeline_name) - - verifier_map = { - 'model': self._verify_model, - 'model_run': self._verify_model, - 'examples': self._verify_examples, - 'schema': self._verify_schema, - 'anomalies': self._verify_anomalies, - 'evaluation': self._verify_evaluation, - # A subdirectory of updated_analyzer_cache has changing name. - 'updated_analyzer_cache': self._veryify_root_dir, - } - - # List of components to verify. Resolver is ignored because it - # doesn't have an executor. - verify_component_ids = [ - component.id - for component in self.taxi_pipeline.components - if not component.id.startswith('Resolver') - ] - - for component_id in verify_component_ids: - logging.info('Verifying %s', component_id) - for key, artifact_dict in pipeline_outputs[component_id].items(): - for idx, artifact in artifact_dict.items(): - recorded_uri = os.path.join(self._recorded_output_dir, component_id, - key, str(idx)) - verifier_map.get(key, self._verify_file_path)(artifact.uri, - recorded_uri) diff --git a/tfx/experimental/templates/taxi/models/estimator_model/__init__.py b/tfx/experimental/templates/taxi/models/estimator_model/__init__.py deleted file mode 100644 index b179ecb83a..0000000000 --- a/tfx/experimental/templates/taxi/models/estimator_model/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2020 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tfx/experimental/templates/taxi/models/estimator_model/constants.py b/tfx/experimental/templates/taxi/models/estimator_model/constants.py deleted file mode 100644 index e3b675f189..0000000000 --- a/tfx/experimental/templates/taxi/models/estimator_model/constants.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2020 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Constants of the taxi model. - -These values can be tweaked to affect model training performance. -""" - -HIDDEN_UNITS = [16, 8] - -TRAIN_BATCH_SIZE = 40 -EVAL_BATCH_SIZE = 40 diff --git a/tfx/experimental/templates/taxi/models/estimator_model/model.py b/tfx/experimental/templates/taxi/models/estimator_model/model.py deleted file mode 100644 index 391dde63c0..0000000000 --- a/tfx/experimental/templates/taxi/models/estimator_model/model.py +++ /dev/null @@ -1,277 +0,0 @@ -# Copyright 2020 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""TFX template taxi model. - -A tf.estimator.DNNLinearCombinedClassifier which uses features -defined in features.py and network parameters defined in constants.py. -""" - -from absl import logging -import tensorflow as tf -from tensorflow import estimator as tf_estimator -import tensorflow_model_analysis as tfma -import tensorflow_transform as tft -from tensorflow_transform.tf_metadata import schema_utils - -from tfx import v1 as tfx -from tfx.experimental.templates.taxi.models import features -from tfx.experimental.templates.taxi.models.estimator_model import constants -from tfx_bsl.public import tfxio - -from tensorflow_metadata.proto.v0 import schema_pb2 - - -def _gzip_reader_fn(filenames): - """Small utility returning a record reader that can read gzip'ed files.""" - return tf.data.TFRecordDataset(filenames, compression_type='GZIP') - - -# Tf.Transform considers these features as "raw" -def _get_raw_feature_spec(schema): - return schema_utils.schema_as_feature_spec(schema).feature_spec - - -def _build_estimator(config, hidden_units=None, warm_start_from=None): - """Build an estimator for predicting the tipping behavior of taxi riders. - - Args: - config: tf.estimator.RunConfig defining the runtime environment for the - estimator (including model_dir). - hidden_units: [int], the layer sizes of the DNN (input layer first) - warm_start_from: Optional directory to warm start from. - - Returns: - A dict of the following: - - estimator: The estimator that will be used for training and eval. - - train_spec: Spec for training. - - eval_spec: Spec for eval. - - eval_input_receiver_fn: Input function for eval. 
- """ - real_valued_columns = [ - tf.feature_column.numeric_column(key, shape=()) - for key in features.transformed_names(features.DENSE_FLOAT_FEATURE_KEYS) - ] - - categorical_columns = [] - for key in features.transformed_names(features.VOCAB_FEATURE_KEYS): - categorical_columns.append( - tf.feature_column.categorical_column_with_identity( - key, - num_buckets=features.VOCAB_SIZE + features.OOV_SIZE, - default_value=0)) - - for key, num_buckets in zip( - features.transformed_names(features.BUCKET_FEATURE_KEYS), - features.BUCKET_FEATURE_BUCKET_COUNT): - categorical_columns.append( - tf.feature_column.categorical_column_with_identity( - key, num_buckets=num_buckets, default_value=0)) - - for key, num_buckets in zip( - features.transformed_names(features.CATEGORICAL_FEATURE_KEYS), - features.CATEGORICAL_FEATURE_MAX_VALUES): - categorical_columns.append( - tf.feature_column.categorical_column_with_identity( - key, num_buckets=num_buckets, default_value=0)) - - return tf_estimator.DNNLinearCombinedClassifier( - config=config, - linear_feature_columns=categorical_columns, - dnn_feature_columns=real_valued_columns, - dnn_hidden_units=hidden_units or [100, 70, 50, 25], - warm_start_from=warm_start_from) - - -def _example_serving_receiver_fn(tf_transform_output, schema): - """Build the serving in inputs. - - Args: - tf_transform_output: A TFTransformOutput. - schema: the schema of the input data. - - Returns: - Tensorflow graph which parses examples, applying tf-transform to them. - """ - raw_feature_spec = _get_raw_feature_spec(schema) - raw_feature_spec.pop(features.LABEL_KEY) - - raw_input_fn = tf_estimator.export.build_parsing_serving_input_receiver_fn( - raw_feature_spec, default_batch_size=None) - serving_input_receiver = raw_input_fn() - - transformed_features = tf_transform_output.transform_raw_features( - serving_input_receiver.features) - - return tf_estimator.export.ServingInputReceiver( - transformed_features, serving_input_receiver.receiver_tensors) - - -def _eval_input_receiver_fn(tf_transform_output, schema): - """Build everything needed for the tf-model-analysis to run the model. - - Args: - tf_transform_output: A TFTransformOutput. - schema: the schema of the input data. - - Returns: - EvalInputReceiver function, which contains: - - Tensorflow graph which parses raw untransformed features, applies the - tf-transform preprocessing operators. - - Set of raw, untransformed features. - - Label against which predictions will be compared. - """ - # Notice that the inputs are raw features, not transformed features here. - raw_feature_spec = _get_raw_feature_spec(schema) - - serialized_tf_example = tf.compat.v1.placeholder( - dtype=tf.string, shape=[None], name='input_example_tensor') - - # Add a parse_example operator to the tensorflow graph, which will parse - # raw, untransformed, tf examples. - raw_features = tf.io.parse_example( - serialized=serialized_tf_example, features=raw_feature_spec) - - # Now that we have our raw examples, process them through the tf-transform - # function computed during the preprocessing step. - transformed_features = tf_transform_output.transform_raw_features( - raw_features) - - # The key name MUST be 'examples'. - receiver_tensors = {'examples': serialized_tf_example} - - # NOTE: Model is driven by transformed features (since training works on the - # materialized output of TFT, but slicing will happen on raw features. 
- raw_features.update(transformed_features) - - return tfma.export.EvalInputReceiver( - features=raw_features, - receiver_tensors=receiver_tensors, - labels=transformed_features[features.transformed_name( - features.LABEL_KEY)]) - - -def _input_fn(file_pattern, data_accessor, tf_transform_output, batch_size=200): - """Generates features and label for tuning/training. - - Args: - file_pattern: List of paths or patterns of input tfrecord files. - data_accessor: DataAccessor for converting input to RecordBatch. - tf_transform_output: A TFTransformOutput. - batch_size: representing the number of consecutive elements of returned - dataset to combine in a single batch - - Returns: - A dataset that contains (features, indices) tuple where features is a - dictionary of Tensors, and indices is a single Tensor of label indices. - """ - return data_accessor.tf_dataset_factory( - file_pattern, - tfxio.TensorFlowDatasetOptions( - batch_size=batch_size, - label_key=features.transformed_name(features.LABEL_KEY)), - tf_transform_output.transformed_metadata.schema) - - -def _create_train_and_eval_spec(trainer_fn_args, schema): - """Build the estimator using the high level API. - - Args: - trainer_fn_args: Holds args used to train the model as name/value pairs. - schema: Holds the schema of the training examples. - - Returns: - A dict of the following: - - estimator: The estimator that will be used for training and eval. - - train_spec: Spec for training. - - eval_spec: Spec for eval. - - eval_input_receiver_fn: Input function for eval. - """ - - tf_transform_output = tft.TFTransformOutput(trainer_fn_args.transform_output) - - train_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda - trainer_fn_args.train_files, - trainer_fn_args.data_accessor, - tf_transform_output, - batch_size=constants.TRAIN_BATCH_SIZE) - - eval_input_fn = lambda: _input_fn( # pylint: disable=g-long-lambda - trainer_fn_args.eval_files, - trainer_fn_args.data_accessor, - tf_transform_output, - batch_size=constants.EVAL_BATCH_SIZE) - - train_spec = tf_estimator.TrainSpec( # pylint: disable=g-long-lambda - train_input_fn, - max_steps=trainer_fn_args.train_steps) - - serving_receiver_fn = lambda: _example_serving_receiver_fn( # pylint: disable=g-long-lambda - tf_transform_output, schema) - - exporter = tf_estimator.FinalExporter('chicago-taxi', serving_receiver_fn) - eval_spec = tf_estimator.EvalSpec( - eval_input_fn, - steps=trainer_fn_args.eval_steps, - exporters=[exporter], - name='chicago-taxi-eval') - - run_config = tf_estimator.RunConfig( - save_checkpoints_steps=999, keep_checkpoint_max=1) - - run_config = run_config.replace(model_dir=trainer_fn_args.serving_model_dir) - - estimator = _build_estimator( - hidden_units=constants.HIDDEN_UNITS, config=run_config) - - # Create an input receiver for TFMA processing - receiver_fn = lambda: _eval_input_receiver_fn( # pylint: disable=g-long-lambda - tf_transform_output, schema) - - return { - 'estimator': estimator, - 'train_spec': train_spec, - 'eval_spec': eval_spec, - 'eval_input_receiver_fn': receiver_fn - } - - -# TFX will call this function -def run_fn(fn_args): - """Train the model based on given args. - - Args: - fn_args: Holds args used to train the model as name/value pairs. 
- """ - schema = tfx.utils.parse_pbtxt_file(fn_args.schema_file, schema_pb2.Schema()) - - train_and_eval_spec = _create_train_and_eval_spec(fn_args, schema) - - # Train the model - logging.info('Training model.') - tf_estimator.train_and_evaluate(train_and_eval_spec['estimator'], - train_and_eval_spec['train_spec'], - train_and_eval_spec['eval_spec']) - logging.info('Training complete. Model written to %s', - fn_args.serving_model_dir) - - # Export an eval savedmodel for TFMA - # NOTE: When trained in distributed training cluster, eval_savedmodel must be - # exported only by the chief worker. - logging.info('Exporting eval_savedmodel for TFMA.') - tfma.export.export_eval_savedmodel( - estimator=train_and_eval_spec['estimator'], - export_dir_base=fn_args.eval_model_dir, - eval_input_receiver_fn=train_and_eval_spec['eval_input_receiver_fn']) - - logging.info('Exported eval_savedmodel to %s.', fn_args.eval_model_dir) diff --git a/tfx/experimental/templates/taxi/models/estimator_model/model_test.py b/tfx/experimental/templates/taxi/models/estimator_model/model_test.py deleted file mode 100644 index e5856b84a4..0000000000 --- a/tfx/experimental/templates/taxi/models/estimator_model/model_test.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2020 Google LLC. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import tensorflow as tf -from tensorflow import estimator as tf_estimator -from tfx.components.trainer import executor as trainer_executor -from tfx.experimental.templates.taxi.models.estimator_model import model - -from tensorflow_metadata.proto.v0 import schema_pb2 - - -class ModelTest(tf.test.TestCase): - - def testTrainerFn(self): - trainer_fn_args = trainer_executor.TrainerFnArgs( - train_files='/path/to/train.file', - transform_output='/path/to/transform_output', - serving_model_dir='/path/to/model_dir', - eval_files='/path/to/eval.file', - schema_file='/path/to/schema_file', - train_steps=1000, - eval_steps=100, - ) - schema = schema_pb2.Schema() - result = model._create_train_and_eval_spec(trainer_fn_args, schema) # pylint: disable=protected-access - self.assertIsInstance(result['estimator'], tf_estimator.Estimator) - self.assertIsInstance(result['train_spec'], tf_estimator.TrainSpec) - self.assertIsInstance(result['eval_spec'], tf_estimator.EvalSpec) - self.assertTrue(callable(result['eval_input_receiver_fn'])) diff --git a/tfx/extensions/google_cloud_ai_platform/trainer/component.py b/tfx/extensions/google_cloud_ai_platform/trainer/component.py index 49eab5512e..6c2821df60 100644 --- a/tfx/extensions/google_cloud_ai_platform/trainer/component.py +++ b/tfx/extensions/google_cloud_ai_platform/trainer/component.py @@ -37,8 +37,6 @@ def __init__(self, module_file: Optional[Union[str, data_types.RuntimeParameter]] = None, run_fn: Optional[Union[str, data_types.RuntimeParameter]] = None, - trainer_fn: Optional[Union[str, - data_types.RuntimeParameter]] = None, train_args: Optional[Union[trainer_pb2.TrainArgs, data_types.RuntimeParameter]] = None, eval_args: Optional[Union[trainer_pb2.EvalArgs, @@ -70,30 +68,9 @@ def __init__(self, ```python def run_fn(trainer.fn_args_utils.FnArgs): ... ``` - and the trained model must be - saved to FnArgs.serving_model_dir when this function is executed. For - Estimator based Executor, The module_file must implement a function - named `trainer_fn` at its top level. The function must have the - following signature. - ```python - def trainer_fn( - trainer.fn_args_utils.FnArgs, - tensorflow_metadata.proto.v0.schema_pb2 - ) -> Dict: ... - ``` - where the returned Dict has the following key-values. - - - `estimator`: an instance of tf.estimator.Estimator - - `train_spec`: an instance of tf.estimator.TrainSpec - - `eval_spec`: an instance of tf.estimator.EvalSpec - - `eval_input_receiver_fn`: an instance of tfma EvalInputReceiver. run_fn: A python path to UDF model definition function for generic trainer. See 'module_file' for details. Exactly one of 'module_file' or 'run_fn' must be supplied if Trainer uses GenericExecutor (default). - trainer_fn: A python path to UDF model definition function for estimator - based trainer. See 'module_file' for the required signature of the UDF. - Exactly one of 'module_file' or 'trainer_fn' must be supplied if Trainer - uses Estimator based Executor train_args: A proto.TrainArgs instance, containing args used for training Currently only splits and num_steps are available. Default behavior (when splits is empty) is train on `train` split. 
@@ -114,5 +91,4 @@ def trainer_fn( eval_args=eval_args, module_file=module_file, run_fn=run_fn, - trainer_fn=trainer_fn, custom_config=custom_config) diff --git a/tfx/extensions/google_cloud_ai_platform/trainer/executor.py b/tfx/extensions/google_cloud_ai_platform/trainer/executor.py index 230b599ced..1d152c3ae0 100644 --- a/tfx/extensions/google_cloud_ai_platform/trainer/executor.py +++ b/tfx/extensions/google_cloud_ai_platform/trainer/executor.py @@ -130,4 +130,4 @@ class Executor(GenericExecutor): """Start a trainer job on Google Cloud AI Platform using a default Trainer.""" def _GetExecutorClass(self): - return tfx_trainer_executor.Executor + return tfx_trainer_executor.GenericExecutor diff --git a/tfx/extensions/google_cloud_ai_platform/trainer/executor_test.py b/tfx/extensions/google_cloud_ai_platform/trainer/executor_test.py index 68658cb62e..f5f9d19f9a 100644 --- a/tfx/extensions/google_cloud_ai_platform/trainer/executor_test.py +++ b/tfx/extensions/google_cloud_ai_platform/trainer/executor_test.py @@ -49,7 +49,7 @@ def setUp(self): }, } self._executor_class_path = name_utils.get_full_name( - tfx_trainer_executor.Executor) + tfx_trainer_executor.GenericExecutor) self._generic_executor_class_path = name_utils.get_full_name( tfx_trainer_executor.GenericExecutor) diff --git a/tfx/orchestration/kubeflow/test_utils.py b/tfx/orchestration/kubeflow/test_utils.py index 50f87104ce..71e81f24f3 100644 --- a/tfx/orchestration/kubeflow/test_utils.py +++ b/tfx/orchestration/kubeflow/test_utils.py @@ -239,7 +239,6 @@ def create_primitive_type_components(pipeline_name: str) -> List[BaseComponent]: def create_e2e_components( pipeline_root: str, csv_input_location: str, - transform_module: str, trainer_module: str, ) -> List[BaseComponent]: """Creates components for a simple Chicago Taxi TFX pipeline for testing. @@ -247,7 +246,6 @@ def create_e2e_components( Args: pipeline_root: The root of the pipeline output. csv_input_location: The location of the input data directory. - transform_module: The location of the transform module file. trainer_module: The location of the trainer module file. 
Returns: @@ -262,7 +260,7 @@ def create_e2e_components( transform = Transform( examples=example_gen.outputs['examples'], schema=schema_gen.outputs['schema'], - module_file=transform_module) + module_file=trainer_module) latest_model_resolver = resolver.Resolver( strategy_class=latest_artifact_strategy.LatestArtifactStrategy, latest_model=Channel(type=Model)).with_id('latest_model_resolver') diff --git a/tfx/orchestration/kubeflow/v2/test_utils.py b/tfx/orchestration/kubeflow/v2/test_utils.py index 05b2f1076b..6491e73317 100644 --- a/tfx/orchestration/kubeflow/v2/test_utils.py +++ b/tfx/orchestration/kubeflow/v2/test_utils.py @@ -21,7 +21,6 @@ import tensorflow_model_analysis as tfma from tfx import v1 as tfx from tfx.components.example_gen import utils -from tfx.components.trainer.executor import Executor from tfx.dsl.component.experimental import executor_specs from tfx.dsl.component.experimental import placeholders from tfx.dsl.components.base import base_component @@ -220,7 +219,6 @@ def create_pipeline_components( model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model)).with_id( 'Resolver.latest_model_resolver') trainer = tfx.components.Trainer( - custom_executor_spec=executor_spec.ExecutorClassSpec(Executor), examples=transform.outputs['transformed_examples'], schema=schema_gen.outputs['schema'], base_model=latest_model_resolver.outputs['model'], diff --git a/tfx/orchestration/kubeflow/v2/testdata/expected_full_taxi_pipeline_job.json b/tfx/orchestration/kubeflow/v2/testdata/expected_full_taxi_pipeline_job.json index ff631fc40c..92db9633ab 100644 --- a/tfx/orchestration/kubeflow/v2/testdata/expected_full_taxi_pipeline_job.json +++ b/tfx/orchestration/kubeflow/v2/testdata/expected_full_taxi_pipeline_job.json @@ -72,7 +72,7 @@ "container": { "args": [ "--executor_class_path", - "tfx.components.trainer.executor.Executor", + "tfx.components.trainer.executor.GenericExecutor", "--json_serialized_invocation_args", "{{$}}", "--json_serialized_inputs_spec_args", @@ -625,7 +625,7 @@ "force_tf_compat_v1": { "runtimeValue": { "constant": 0.0 - + } } } diff --git a/tfx/orchestration/kubeflow/v2/testdata/legacy/expected_full_taxi_pipeline_job.json b/tfx/orchestration/kubeflow/v2/testdata/legacy/expected_full_taxi_pipeline_job.json index 258d984690..da72f2eb64 100644 --- a/tfx/orchestration/kubeflow/v2/testdata/legacy/expected_full_taxi_pipeline_job.json +++ b/tfx/orchestration/kubeflow/v2/testdata/legacy/expected_full_taxi_pipeline_job.json @@ -66,7 +66,7 @@ "container": { "args": [ "--executor_class_path", - "tfx.components.trainer.executor.Executor", + "tfx.components.trainer.executor.GenericExecutor", "--json_serialized_invocation_args", "{{$}}" ], diff --git a/tfx/types/standard_component_specs.py b/tfx/types/standard_component_specs.py index a833e86e4c..140b1c4c21 100644 --- a/tfx/types/standard_component_specs.py +++ b/tfx/types/standard_component_specs.py @@ -101,7 +101,6 @@ PUSHED_MODEL_KEY = 'pushed_model' # Key for TrainerSpec RUN_FN_KEY = 'run_fn' -TRAINER_FN_KEY = 'trainer_fn' BASE_MODEL_KEY = 'base_model' HYPERPARAMETERS_KEY = 'hyperparameters' MODEL_RUN_KEY = 'model_run' @@ -397,7 +396,6 @@ class TrainerSpec(ComponentSpec): MODULE_FILE_KEY: ExecutionParameter(type=str, optional=True), MODULE_PATH_KEY: ExecutionParameter(type=str, optional=True), RUN_FN_KEY: ExecutionParameter(type=str, optional=True), - TRAINER_FN_KEY: ExecutionParameter(type=str, optional=True), CUSTOM_CONFIG_KEY: ExecutionParameter(type=str, optional=True), } INPUTS = {