Skip to content

Commit

Permalink
Expose native TF 2 path for TFX Transform.
Browse files Browse the repository at this point in the history
Implements RFC: tensorflow/community#308

PiperOrigin-RevId: 345978930
  • Loading branch information
tfx-copybara committed Dec 6, 2020
1 parent 270bb38 commit 9e892f6
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 18 deletions.
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
support for pipeline operations.
* Added an experimental template to use with Kubeflow V2 runner.
* Added sanitization of user-specified pipeline name in Kubeflow V2 runner.
* Added native TF 2 implementation of Transform. The default
behavior will continue to use Tensorflow's compat.v1 APIs. This can be
overriden by passing `force_tf_compat_v1=False`. The default
behavior for TF 2 users will be switched to the new native implementation in
a future release.

## Breaking changes

Expand Down
5 changes: 5 additions & 0 deletions tfx/components/transform/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
instance_name: Optional[Text] = None,
materialize: bool = True,
disable_analyzer_cache: bool = False,
force_tf_compat_v1: bool = True,
custom_config: Optional[Dict[Text, Any]] = None):
"""Construct a Transform component.
Expand Down Expand Up @@ -136,6 +137,9 @@ def preprocessing_fn(inputs: Dict[Text, Any], custom_config:
disable_analyzer_cache: If False, Transform will use input cache if
provided and write cache output. If True, `analyzer_cache` must not be
provided.
force_tf_compat_v1: (Optional) If True, Transform will use Tensorflow in
compat.v1 mode irrespective of installed version of Tensorflow. Defaults
to `True`.
custom_config: A dict which contains additional parameters that will be
passed to preprocessing_fn.
Expand Down Expand Up @@ -179,6 +183,7 @@ def preprocessing_fn(inputs: Dict[Text, Any], custom_config:
schema=schema,
module_file=module_file,
preprocessing_fn=preprocessing_fn,
force_tf_compat_v1=int(force_tf_compat_v1),
splits_config=splits_config,
transform_graph=transform_graph,
transformed_examples=transformed_examples,
Expand Down
11 changes: 11 additions & 0 deletions tfx/components/transform/component_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,17 @@ def test_construct_with_cache_disabled_but_input_cache(self):
analyzer_cache=channel_utils.as_channel(
[standard_artifacts.TransformCache()]))

def test_construct_with_force_tf_compat_v1_false(self):
transform = component.Transform(
examples=self.examples,
schema=self.schema,
preprocessing_fn='my_preprocessing_fn',
force_tf_compat_v1=False,
)
self._verify_outputs(transform)
self.assertEqual(False,
bool(transform.spec.exec_properties['force_tf_compat_v1']))


if __name__ == '__main__':
tf.test.main()
14 changes: 13 additions & 1 deletion tfx/components/transform/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import tensorflow_data_validation as tfdv
import tensorflow_transform as tft
from tensorflow_transform import impl_helper
from tensorflow_transform import tf2_utils
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.beam import analyzer_cache
from tensorflow_transform.beam import common as tft_beam_common
Expand Down Expand Up @@ -316,6 +317,8 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
analyze and transform splits can have overlap. Default behavior (when
splits_config is not set) is analyze the 'train' split and transform
all splits. If splits_config is set, analyze cannot be empty.
- force_tf_compat_v1: Whether to use TF in compat.v1 mode
irrespective of installed/enabled TF behaviors.
Returns:
None
Expand Down Expand Up @@ -390,6 +393,15 @@ def _GetCachePath(label, params_dict):
else:
return artifact_utils.get_single_uri(params_dict[label])

force_tf_compat_v1 = bool(exec_properties.get('force_tf_compat_v1'))
if force_tf_compat_v1 and not tf2_utils.use_tf_compat_v1(False):
absl.logging.warning(
'The default value of `force_tf_compat_v1` will change in a future '
'release. Since this pipeline has TF 2 behaviors enabled, Transform '
'will use native TF 2 at that point. You can test this behavior now '
'by passing `force_tf_compat_v1=False` or disable it by explicitly '
'setting `force_tf_compat_v1=True` in the Transform component.')

label_inputs = {
labels.COMPUTE_STATISTICS_LABEL:
False,
Expand All @@ -414,7 +426,7 @@ def _GetCachePath(label, params_dict):
labels.CUSTOM_CONFIG:
exec_properties.get('custom_config', None),
labels.FORCE_TF_COMPAT_V1_LABEL:
True,
force_tf_compat_v1,
}
cache_input = _GetCachePath(ANALYZER_CACHE_KEY, input_dict)
if cache_input is not None:
Expand Down
21 changes: 5 additions & 16 deletions tfx/components/transform/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from tfx import types
from tfx.components.testdata.module_file import transform_module
from tfx.components.transform import executor
from tfx.components.transform import labels
from tfx.dsl.io import fileio
from tfx.proto import transform_pb2
from tfx.types import artifact_utils
Expand All @@ -54,17 +53,6 @@ class _TempPath(types.Artifact):
TYPE_NAME = 'TempPath'


class _ExecutorForTesting(executor.Executor):

def __init__(self, force_tf_compat_v1):
super(_ExecutorForTesting, self).__init__()
self._force_tf_compat_v1 = force_tf_compat_v1

def Transform(self, inputs, outputs, status_file):
inputs[labels.FORCE_TF_COMPAT_V1_LABEL] = self._force_tf_compat_v1
super(_ExecutorForTesting, self).Transform(inputs, outputs, status_file)


# TODO(b/122478841): Add more detailed tests.
class ExecutorTest(tft_unit.TransformTestCase):

Expand Down Expand Up @@ -165,10 +153,11 @@ def setUp(self):
transform_module.preprocessing_fn.__module__,
transform_module.preprocessing_fn.__name__)
self._exec_properties['splits_config'] = None
self._exec_properties['force_tf_compat_v1'] = int(
self._use_force_tf_compat_v1())

# Executor for test.
self._transform_executor = _ExecutorForTesting(
self._use_force_tf_compat_v1())
self._transform_executor = executor.Executor()

def _verify_transform_outputs(self,
materialize=True,
Expand Down Expand Up @@ -238,11 +227,11 @@ def _create_pipeline_wrapper(*_):
return result

with tft_unit.mock.patch.object(
_ExecutorForTesting,
executor.Executor,
'_CreatePipeline',
autospec=True,
side_effect=_create_pipeline_wrapper):
transform_executor = _ExecutorForTesting(self._use_force_tf_compat_v1())
transform_executor = executor.Executor()
transform_executor.Do(self._input_dict, self._output_dict,
self._exec_properties)
assert len(pipelines) == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,13 @@
}
}
},
"force_tf_compat_v1":{
"runtimeValue":{
"constantValue":{
"intValue":"1"
}
}
},
"module_file":{
"runtimeValue":{
"constantValue":{
Expand Down
2 changes: 2 additions & 0 deletions tfx/types/standard_component_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ class TransformSpec(ComponentSpec):
ExecutionParameter(type=(str, Text), optional=True),
'preprocessing_fn':
ExecutionParameter(type=(str, Text), optional=True),
'force_tf_compat_v1':
ExecutionParameter(type=int),
'custom_config':
ExecutionParameter(type=(str, Text), optional=True),
'splits_config':
Expand Down
2 changes: 1 addition & 1 deletion tfx/utils/dependency_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def make_beam_dependency_flags(beam_pipeline_args: List[Text]) -> List[Text]:
setuptools.setup(
name='tfx_ephemeral',
version='{version}',
packages=setuptools.find_namespace_packages(),
packages=setuptools.find_packages(),
install_requires=[{install_requires}],
)
"""
Expand Down

0 comments on commit 9e892f6

Please sign in to comment.