TFX Tuner Component

Status: Accepted
Author(s): Jiayi Zhao ([email protected]), Amy Wu ([email protected])
Sponsor: Zhitao Li ([email protected]), Tom O'Malley ([email protected]), Matthieu Monsch ([email protected]), Makoto Uchida ([email protected]), Goutham Bhat ([email protected])
Updated: 2020-04-20

Objective

Goal

  • A new Tuner component in TFX for automated hyperparameter tuning, based on abstractions from the KerasTuner library, in order to reuse its abstractions and tuning algorithms.

Non Goal

  • Natively supporting multi-worker tuning in the system. As TFX doesn't have the ability to manage multi-worker clusters, running multiple trials in parallel (parallel tuning) and running each trial in a distributed environment (distributed training) are not supported natively. Parallel tuning may instead be realized by a particular implementation of the TFX Tuner (a custom Executor), e.g., in a Google Cloud environment.
  • Implementing custom tuners for the KerasTuner library, e.g., built-in EstimatorTuner support, is out of scope of this design discussion. However, a user project can still implement a tuner that inherits from kerastuner.BaseTuner and provide it to the proposed TFX Tuner component.

Background and Motivation

A hyperparameter is a parameter whose value is used to control the learning process of a model or the model architecture itself (e.g., the number of layers and nodes). By contrast, the values of other parameters (typically node weights) are learned.

Hyperparameter optimization is a critical part of many machine learning pipelines. We therefore propose a new TFX component: given a search space that specifies the hyperparameter configuration (name, type, range, etc.), TFX will optimize the hyperparameters based on the tuning algorithm.

User Benefit

This document proposes a built-in TFX Tuner component, which works seamlessly with the Trainer and other TFX components. As the Tuner component utilizes the KerasTuner library, all tuning methods supported by that library will be available in TFX, including custom implementations of the KerasTuner abstractions.

Design Proposal

The TFX Tuner component will be built on top of the KerasTuner library. In the following sections, we first briefly go over the KerasTuner library and several concepts in hyperparameter optimization. We then focus on the Tuner component interface and how it utilizes the KerasTuner library. After that, we discuss parallel tuning and our plan for Google Cloud integration.

KerasTuner Library

The following graph shows a typical workflow of hyperparameter tuning under the KerasTuner framework:

Given a user-provided model that accepts a hyperparameter container, the tuner searches for optimal hyperparameters through trials created by the tuning algorithm. For each trial, values from the search space are assigned to the hyperparameter container, and the user model is trained with these hyperparameter values and evaluated against the objective provided to the tuner. The evaluation results are reported back to the tuner, and the tuning algorithm decides the hyperparameter values for the next trial. After certain conditions are reached, e.g., max trials, the tuner stops the iteration and returns the optimal hyperparameters.

The KerasTuner library provides the above tuning functionality. Here are some of its abstractions:

  • HyperParameters: Hyperparameter container for both the search space and current values.
  • Oracle: Implementation of a hyperparameter tuning algorithm, e.g., random search, including state management of the algorithm's progress.
  • Trial: Provided by the Oracle; contains information about the hyperparameter values for the current iteration.
  • BaseTuner: A base tuner interface for the above tuning workflow, responsible for the iteration of trial execution:
    • Generates a Trial using the Oracle.
    • Trains the user model with the HyperParameters in the current Trial.
    • Evaluates metrics and reports them back to the Oracle for the next Trial.
  • Tuner: An implementation of BaseTuner, for Keras model tuning.

Note: Other than Tuner, the abstractions HyperParameters, Oracle, Trial and BaseTuner are not restricted to Keras models, even though the library is called KerasTuner.

For more details and code examples, please refer to the KerasTuner documentation.
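
To make the workflow concrete, here is a minimal, self-contained sketch of using the KerasTuner library directly (not part of the proposed component); x_train, y_train, x_val and y_val are placeholder datasets:

import kerastuner
import tensorflow as tf

# Hypermodel: each `hp` call both declares a hyperparameter in the search
# space and returns its value for the current trial.
def build_model(hp: kerastuner.HyperParameters) -> tf.keras.Model:
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(hp.Int('units', 32, 256, step=32),
                            activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax'),
  ])
  model.compile(
      optimizer=tf.keras.optimizers.Adam(
          hp.Choice('learning_rate', [1e-2, 1e-3])),
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy'])
  return model

# RandomSearch bundles an Oracle (the random search algorithm) with the
# trial-execution loop described above.
tuner = kerastuner.RandomSearch(
    build_model,
    objective='val_accuracy',
    max_trials=10,
    directory='/tmp/tuning',
    project_name='example')

# Each trial trains the model with hyperparameter values proposed by the
# Oracle and reports the objective back to it.
tuner.search(x_train, y_train, epochs=5, validation_data=(x_val, y_val))
best_hp = tuner.get_best_hyperparameters()[0]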

Component Interface

The Tuner component takes raw or transformed examples as input, along with a schema or transform_graph for the feature specification, and outputs the hyperparameter tuning results. Below is the specification of the Tuner component:

class TunerSpec(ComponentSpec):
  """ComponentSpec for TFX Tuner Component."""

  PARAMETERS = {
    # Specify a python module file which contains a UDF `tuner_fn`.
    'module_file': ExecutionParameter(type=(str, Text), optional=True),
    # Specify the steps for the training stage of each trial’s execution.
    'train_args': ExecutionParameter(type=trainer_pb2.TrainArgs),
    'eval_args': ExecutionParameter(type=trainer_pb2.EvalArgs),
  }

  INPUTS = {
    'examples': ChannelParameter(type=standard_artifacts.Examples),
    'schema': ChannelParameter(
            type=standard_artifacts.Schema, optional=True),
    'transform_graph':
        ChannelParameter(
            type=standard_artifacts.TransformGraph, optional=True),
  }

  OUTPUTS = {
    'best_hyperparameters':
        ChannelParameter(type=standard_artifacts.HyperParameters),
  }

The Trainer has an optional hyperparameters input; the tuning result can be fed into it so that the Trainer can use the best hyperparameters to construct the model. Below is an example of how the Tuner and Trainer are chained in a pipeline:

# TrainerSpec:
  INPUTS = {
    ...
    'hyperparameters':
        ChannelParameter(
            type=standard_artifacts.HyperParameters, optional=True),
  }

# Pipeline DSL Example:
  tuner = Tuner(
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      module_file=module_file,
      train_args=trainer_pb2.TrainArgs(num_steps=1000),
      eval_args=trainer_pb2.EvalArgs(num_steps=500))

  trainer = Trainer(
      module_file=module_file,
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      hyperparameters=tuner.outputs['best_hyperparameters'],
      train_args=trainer_pb2.TrainArgs(num_steps=10000),
      eval_args=trainer_pb2.EvalArgs(num_steps=5000))

For the Trainer, users need to define the model code and training logic (Generic Trainer) in the module_file. For the Tuner, in addition to the model code, users also need to define the hyperparameters, the search space, and a tuning algorithm in the module_file. A tuner_fn with the following signature is required for the Tuner:

from typing import Any, Dict, NamedTuple, Text

from kerastuner.engine import base_tuner
import tensorflow as tf
from tfx.components.trainer.executor import TrainerFnArgs

# Current TrainerFnArgs will be renamed to FnArgs as a util class.
FnArgs = TrainerFnArgs
TunerFnResult = NamedTuple('TunerFnResult',
                           [('tuner', base_tuner.BaseTuner),
                            ('fit_kwargs', Dict[Text, Any])])

def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  """Build the tuner using the KerasTuner API.

  Args:
    fn_args: Holds args as name/value pairs.
      working_dir: working dir for tuning. Automatically set by Executor.
      train_files: List of file paths containing training tf.Example data.
      eval_files: List of file paths containing eval tf.Example data.
      train_steps: number of train steps.
      eval_steps: number of eval steps.
      schema: optional schema file of the input data.
      transform_graph: optional transform graph produced by TFT.

  Returns:
    A namedtuple contains the following:
      - tuner: A BaseTuner that will be used for tuning.
      - fit_kwargs: Args to pass to the tuner's run_trial function for fitting the
                    model, e.g., the training and validation dataset. Required
                    args depend on the above tuner's implementation.
  """

The TunerFnResult returned by the above tuner_fn contains an instance that implements the BaseTuner interface; that is the contract required by the Tuner for tuning. The model code, hyperparameters, search space and tuning algorithm are hidden under the BaseTuner abstraction, so the Tuner component itself is generic and agnostic to the model framework and tuning logic. Below is an example module file with a Keras model:

import kerastuner
import tensorflow as tf
...

def _input_fn(file_pattern: Text, ...) -> tf.data.Dataset:
  ...

# Model code for Trainer and Tuner.
def _build_keras_model(hp: kerastuner.HyperParameters) -> tf.keras.Model:
  ...
  for _ in range(hp.get('num_layers')):
    ...
  ...
  model = tf.keras.Model(...)
  model.compile(
     optimizer=tf.keras.optimizers.Adam(hp.get('learning_rate')),
     loss='sparse_categorical_crossentropy',
     metrics=[tf.keras.metrics.Accuracy()])
  return model

# This will be called by TFX Tuner.
def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
  hp = kerastuner.HyperParameters()
  # Defines search space.
  hp.Choice('learning_rate', [1e-1, 1e-3])
  hp.Int('num_layers', 1, 5)

  # RandomSearch is a subclass of kerastuner.Tuner, the Keras model tuner.
  tuner = kerastuner.RandomSearch(
      _build_keras_model,
      max_trials=5,
      hyperparameters=hp,
      allow_new_entries=False,
      objective='val_accuracy',
      directory=fn_args.working_dir,
      project_name='test')

  train_dataset = _input_fn(fn_args.train_files, ...)
  eval_dataset = _input_fn(fn_args.eval_files, ...)

  return TunerFnResult(
      tuner=tuner,
      fit_kwargs={'x': train_dataset,
                  'validation_data': eval_dataset,
                  'steps_per_epoch': fn_args.train_steps,
                  'validation_steps': fn_args.eval_steps})

# This will be called by TFX Generic Trainer.
def run_fn(fn_args: FnArgs) -> None:
  hp = kerastuner.HyperParameters.from_config(fn_args.hyperparameters)
  model = _build_keras_model(hp)
  model.fit(...)
  model.save(...)

In the Tuner's executor, tuner_fn is called with information resolved from the component inputs; then the search function of the returned tuner is called with fit_kwargs to launch trials for tuning, and finally the best trial's hyperparameters are emitted:

# Executor of Tuner Component:
class Executor(base_executor.BaseExecutor):

  def Do(self,
         input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    ...
    tuner_spec = tuner_fn(self._create_fn_args(input_dict, exec_properties))
    tuner_spec.tuner.search(**tuner_spec.fit_kwargs)
    # Output file contains a JSON-format string of hyperparameters.get_config().
    self._emit_best_hyperparameters(
        output_dict, tuner_spec.tuner.get_best_hyperparameters()[0])
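
For illustration, one possible implementation of the _emit_best_hyperparameters helper referenced above could serialize the hyperparameters' config as JSON into the output artifact's URI; the helper name matches the snippet above, but the output file name and the use of artifact_utils are assumptions:

import json
import os

import tensorflow as tf
from tfx.types import artifact_utils

def _emit_best_hyperparameters(output_dict, best_hparams) -> None:
  # Resolve the URI of the 'best_hyperparameters' output artifact.
  output_uri = artifact_utils.get_single_uri(
      output_dict['best_hyperparameters'])
  # File name is illustrative only.
  output_path = os.path.join(output_uri, 'best_hyperparameters.txt')
  # Write the JSON-serialized config of the best HyperParameters.
  with tf.io.gfile.GFile(output_path, 'w') as f:
    f.write(json.dumps(best_hparams.get_config()))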

Parallel Tuning

In parallel tuning, multiple trials are executed in parallel. In this section, we discuss how distribution works for the KerasTuner library and the current status of TFX.

In the search function of a tuner, trials run in sequence rather than in parallel. To support parallel tuning, we need to launch multiple tuners (the tuner here refers to the one in the KerasTuner library, not the TFX Tuner component) and have an optimization service that manages the state of the tuning algorithm; each tuner's oracle communicates with this service to retrieve the trials for its tuner.

The above graph shows parallel tuning with three tuners. Each tuner runs as a different worker and retrieves trials from its own oracle, which talks to the optimization service. Trials of different tuners can run in parallel, but trials within the same tuner still execute in sequence. When launching the tuners, the same identifier is assigned to each oracle, so the optimization service knows they belong to the same tuning job group and will assign hyperparameter values for their trials based on the algorithm.

The number of parallel tuners can be passed to the component via TuneArgs, as shown below:

// Args specific to tuning.
message TuneArgs {
  // Number of trials to run in parallel.
  // Each trial will be trained and evaluated by separate worker jobs.
  int32 num_parallel_trials = 1;
}

class TunerSpec(ComponentSpec):

  PARAMETERS = {
    ...
    'tune_args': ExecutionParameter(type=tuner_pb2.TuneArgs),
  }
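
In the pipeline DSL, the number of parallel tuners could then be configured as follows (a sketch; it assumes the TuneArgs proto above is available as tuner_pb2.TuneArgs):

  tuner = Tuner(
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      module_file=module_file,
      train_args=trainer_pb2.TrainArgs(num_steps=1000),
      eval_args=trainer_pb2.EvalArgs(num_steps=500),
      # Run up to three tuners (and hence three trials) in parallel.
      tune_args=tuner_pb2.TuneArgs(num_parallel_trials=3))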

The KerasTuner library allows users to configure a tf.distribute.Strategy if they are using the kerastuner.Tuner class (or a subclass of it). In the above parallel tuning setup, each trial (each model training) is executed on a single worker, so only single-machine strategies are allowed. To support multi-worker distributed training, we need to be able to execute a trial (training) across different workers.
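
For instance, a single-machine strategy could be configured inside tuner_fn via the distribution_strategy argument of a kerastuner.Tuner subclass (a sketch reusing the symbols from the earlier module-file example):

  # Inside tuner_fn: each trial trains on all local GPUs of the single
  # worker that runs this tuner.
  tuner = kerastuner.RandomSearch(
      _build_keras_model,
      max_trials=5,
      hyperparameters=hp,
      objective='val_accuracy',
      distribution_strategy=tf.distribute.MirroredStrategy(),
      directory=fn_args.working_dir,
      project_name='test')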

At the time of writing, the KerasTuner library can be used for parallel execution of trials with a single-machine tf.distribute.Strategy, e.g., MirroredStrategy; multi-worker strategy support (distributed training for a trial) is on the KerasTuner roadmap. Note that cluster management of multiple worker machines is not part of the KerasTuner library.

At the time of writing, TFX does not have the ability to manage a multi-worker cluster or the centralized optimization service. As a result, parallel tuning and distributed training are not supported natively in TFX (local or on-prem). In the next section, we discuss the integration with Google Cloud. Similar parallel tuning support can be built for other execution environments.

Google Cloud Integration

In this section, we discuss the Tuner component on Google Cloud AI Platform (CAIP): specifically, an implementation of a KerasTuner Oracle that talks to the AI Platform Optimizer as the centralized optimization service, and a custom Tuner executor implementation that makes use of the Cloud Optimizer-based Oracle (symbol names are subject to change).

As mentioned in the parallel tuning section above, KerasTuner uses a centralized optimization service that manages the states of a tuning study and its trials. In addition to that, we will create a CloudOracle as a client of the AI Platform Optimizer service, and a CloudTuner which inherits from the Keras Tuner. In the module file, users create the tuner_fn with CloudTuner, and then configure the TFX Tuner component to use a custom Tuner executor (CloudExecutor), which launches multiple CloudTuners in a Google Cloud AI Platform Training job, possibly with multiple worker machines running trials concurrently. Below shows the workflow for in-process tuning and Cloud tuning.
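
As a rough sketch of the intended usage (symbol names, import paths and signatures are hypothetical, per the note above; custom_executor_spec is the existing TFX mechanism for overriding a component's executor):

  # Module file: tuner_fn built around the Cloud-backed tuner.
  def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
    # Hypothetical CloudTuner: a kerastuner.Tuner subclass whose CloudOracle
    # talks to the AI Platform Optimizer service.
    tuner = CloudTuner(
        _build_keras_model,
        objective='val_accuracy',
        max_trials=10,
        directory=fn_args.working_dir,
        project_name='test')
    return TunerFnResult(
        tuner=tuner,
        fit_kwargs={'x': _input_fn(fn_args.train_files),
                    'validation_data': _input_fn(fn_args.eval_files)})

  # Pipeline DSL: point the Tuner component at the Cloud-based executor.
  tuner = Tuner(
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      module_file=module_file,
      train_args=trainer_pb2.TrainArgs(num_steps=1000),
      eval_args=trainer_pb2.EvalArgs(num_steps=500),
      custom_executor_spec=executor_spec.ExecutorClassSpec(CloudExecutor))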

Future work

  • Native support for multi-worker parallel tuning.
  • Custom Tuner (inherits from BaseTuner) examples, e.g., for Estimator support or Keras custom training loop support (a rough sketch is given below).
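
As a rough illustration of such a custom tuner (a sketch only; it assumes the kerastuner.engine.base_tuner.BaseTuner API at the time of writing, and _build_and_train / _evaluate are hypothetical user helpers):

from kerastuner.engine import base_tuner

class CustomLoopTuner(base_tuner.BaseTuner):
  """Sketch of a tuner driving a custom (non-Keras-fit) training loop."""

  def run_trial(self, trial, train_dataset, eval_dataset):
    hp = trial.hyperparameters
    # Train with the hyperparameter values proposed for this trial, using
    # any framework or custom loop (hypothetical user helpers).
    model = self._build_and_train(hp, train_dataset)
    val_accuracy = self._evaluate(model, eval_dataset)
    # Report the objective back to the Oracle so it can plan the next trial.
    self.oracle.update_trial(trial.trial_id, {'val_accuracy': val_accuracy})
    # save_model / load_model would also need to be overridden; omitted here.
    self.save_model(trial.trial_id, model)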