Skip to content

Commit

Permalink
Expose native TF 2 path for TFX Transform. It is disabled by default.
Browse files Browse the repository at this point in the history
Implements RFC: tensorflow/community#308

PiperOrigin-RevId: 347120003
  • Loading branch information
tfx-copybara committed Dec 12, 2020
1 parent 73e0777 commit 6e607ba
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 17 deletions.
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
TFX, and the TFX source code has been moved to `/tfx/src`.
* TFX Transform switched to a (notably) faster implementation of
`tft.quantiles` analyzer.
* Added native TF 2 implementation of Transform. The default
behavior will continue to use TensorFlow's compat.v1 APIs. This can be
overridden by passing `force_tf_compat_v1=False` and enabling TF 2 behaviors.
The default behavior for TF 2 will be switched to the new native
implementation in a future release.

## Breaking changes
* Wheel package building for TFX has changed, and users need to follow the
Expand Down
6 changes: 6 additions & 0 deletions tfx/components/transform/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
instance_name: Optional[Text] = None,
materialize: bool = True,
disable_analyzer_cache: bool = False,
force_tf_compat_v1: bool = True,
custom_config: Optional[Dict[Text, Any]] = None):
"""Construct a Transform component.
Expand Down Expand Up @@ -136,6 +137,10 @@ def preprocessing_fn(inputs: Dict[Text, Any], custom_config:
disable_analyzer_cache: If False, Transform will use input cache if
provided and write cache output. If True, `analyzer_cache` must not be
provided.
force_tf_compat_v1: (Optional) If True, Transform will use TensorFlow in
compat.v1 mode irrespective of the installed version of TensorFlow.
Defaults to `True`. Note: The default value will be switched to `False` in
a future release.
custom_config: A dict which contains additional parameters that will be
passed to preprocessing_fn.
Expand Down Expand Up @@ -179,6 +184,7 @@ def preprocessing_fn(inputs: Dict[Text, Any], custom_config:
schema=schema,
module_file=module_file,
preprocessing_fn=preprocessing_fn,
force_tf_compat_v1=int(force_tf_compat_v1),
splits_config=splits_config,
transform_graph=transform_graph,
transformed_examples=transformed_examples,
Expand Down
11 changes: 11 additions & 0 deletions tfx/components/transform/component_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,17 @@ def test_construct_with_cache_disabled_but_input_cache(self):
analyzer_cache=channel_utils.as_channel(
[standard_artifacts.TransformCache()]))

def test_construct_with_force_tf_compat_v1_false(self):
# Builds a Transform component with `force_tf_compat_v1=False`, runs the
# shared output verification helper on it, and checks that the flag ends up
# as a falsy value in the component spec's exec properties.
transform = component.Transform(
examples=self.examples,
schema=self.schema,
preprocessing_fn='my_preprocessing_fn',
force_tf_compat_v1=False,
)
self._verify_outputs(transform)
# The stored value is passed through bool() before comparison, so an integer
# 0 in the exec properties satisfies the check just like a literal False.
self.assertEqual(False,
bool(transform.spec.exec_properties['force_tf_compat_v1']))


if __name__ == '__main__':
tf.test.main()
15 changes: 14 additions & 1 deletion tfx/components/transform/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import tensorflow_data_validation as tfdv
import tensorflow_transform as tft
from tensorflow_transform import impl_helper
from tensorflow_transform import tf2_utils
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.beam import analyzer_cache
from tensorflow_transform.beam import common as tft_beam_common
Expand Down Expand Up @@ -316,6 +317,8 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
analyze and transform splits can have overlap. Default behavior (when
splits_config is not set) is analyze the 'train' split and transform
all splits. If splits_config is set, analyze cannot be empty.
- force_tf_compat_v1: Whether to use TF in compat.v1 mode
irrespective of installed/enabled TF behaviors.
Returns:
None
Expand Down Expand Up @@ -390,6 +393,16 @@ def _GetCachePath(label, params_dict):
else:
return artifact_utils.get_single_uri(params_dict[label])

force_tf_compat_v1 = bool(exec_properties.get('force_tf_compat_v1', 1))
if force_tf_compat_v1 and not tf2_utils.use_tf_compat_v1(False):
absl.logging.warning(
'The default value of `force_tf_compat_v1` will change in a future '
'release from `True` to `False`. Since this pipeline has TF 2 '
'behaviors enabled, Transform will use native TF 2 at that point. You'
' can test this behavior now by passing `force_tf_compat_v1=False` '
'or disable it by explicitly setting `force_tf_compat_v1=True` in '
'the Transform component.')

label_inputs = {
labels.COMPUTE_STATISTICS_LABEL:
False,
Expand All @@ -414,7 +427,7 @@ def _GetCachePath(label, params_dict):
labels.CUSTOM_CONFIG:
exec_properties.get('custom_config', None),
labels.FORCE_TF_COMPAT_V1_LABEL:
True,
force_tf_compat_v1,
}
cache_input = _GetCachePath(ANALYZER_CACHE_KEY, input_dict)
if cache_input is not None:
Expand Down
21 changes: 5 additions & 16 deletions tfx/components/transform/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from tfx import types
from tfx.components.testdata.module_file import transform_module
from tfx.components.transform import executor
from tfx.components.transform import labels
from tfx.dsl.io import fileio
from tfx.proto import transform_pb2
from tfx.types import artifact_utils
Expand All @@ -54,17 +53,6 @@ class _TempPath(types.Artifact):
TYPE_NAME = 'TempPath'


class _ExecutorForTesting(executor.Executor):
# Executor subclass used by the tests in this file: it sets the
# FORCE_TF_COMPAT_V1 input label to a value chosen at construction time.

def __init__(self, force_tf_compat_v1):
# Keeps the flag so Transform() can write it into its inputs later.
super(_ExecutorForTesting, self).__init__()
self._force_tf_compat_v1 = force_tf_compat_v1

def Transform(self, inputs, outputs, status_file):
# Sets the label in the caller's `inputs` mapping (replacing any existing
# value under that key) and then delegates to the base implementation.
inputs[labels.FORCE_TF_COMPAT_V1_LABEL] = self._force_tf_compat_v1
super(_ExecutorForTesting, self).Transform(inputs, outputs, status_file)


# TODO(b/122478841): Add more detailed tests.
class ExecutorTest(tft_unit.TransformTestCase):

Expand Down Expand Up @@ -165,10 +153,11 @@ def setUp(self):
transform_module.preprocessing_fn.__module__,
transform_module.preprocessing_fn.__name__)
self._exec_properties['splits_config'] = None
self._exec_properties['force_tf_compat_v1'] = int(
self._use_force_tf_compat_v1())

# Executor for test.
self._transform_executor = _ExecutorForTesting(
self._use_force_tf_compat_v1())
self._transform_executor = executor.Executor()

def _verify_transform_outputs(self,
materialize=True,
Expand Down Expand Up @@ -238,11 +227,11 @@ def _create_pipeline_wrapper(*_):
return result

with tft_unit.mock.patch.object(
_ExecutorForTesting,
executor.Executor,
'_CreatePipeline',
autospec=True,
side_effect=_create_pipeline_wrapper):
transform_executor = _ExecutorForTesting(self._use_force_tf_compat_v1())
transform_executor = executor.Executor()
transform_executor.Do(self._input_dict, self._output_dict,
self._exec_properties)
assert len(pipelines) == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,13 @@
}
}
},
"force_tf_compat_v1":{
"runtimeValue":{
"constantValue":{
"intValue":"1"
}
}
},
"module_file":{
"runtimeValue":{
"constantValue":{
Expand Down
2 changes: 2 additions & 0 deletions tfx/types/standard_component_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,8 @@ class TransformSpec(ComponentSpec):
ExecutionParameter(type=(str, Text), optional=True),
'preprocessing_fn':
ExecutionParameter(type=(str, Text), optional=True),
'force_tf_compat_v1':
ExecutionParameter(type=int, optional=True),
'custom_config':
ExecutionParameter(type=(str, Text), optional=True),
'splits_config':
Expand Down

0 comments on commit 6e607ba

Please sign in to comment.