From ffa453f712fcf07a2a918726fd3cf622e39f9a50 Mon Sep 17 00:00:00 2001 From: Arno Eigenwillig Date: Tue, 6 Feb 2024 05:51:53 -0800 Subject: [PATCH] Let TF-GNN choose between keras or tf_keras consistently with TF 2.15: both provide Keras 2.15, but it matters which one is used, because they have separate class hierarchies and global registries. Along the way, refactor the nested case distinctions of tf_internal.py into a clear list of supported older TF/Keras versions. PiperOrigin-RevId: 604619976 --- tensorflow_gnn/graph/tf_internal.py | 115 +++++++++++++++++----------- 1 file changed, 69 insertions(+), 46 deletions(-) diff --git a/tensorflow_gnn/graph/tf_internal.py b/tensorflow_gnn/graph/tf_internal.py index b35427de..213cca5b 100644 --- a/tensorflow_gnn/graph/tf_internal.py +++ b/tensorflow_gnn/graph/tf_internal.py @@ -17,6 +17,12 @@ TODO(b/188399175): Use the public ExtensionType API instead. """ +import os + +## +## Part 1: TensorFlow symbols +## + # The following imports work in all supported versions of TF. # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top,g-bad-import-order from tensorflow.python.framework import composite_tensor @@ -32,39 +38,6 @@ except ImportError: type_spec_registry = None # Not available before TF 2.12. -# NOTE: See ../__init__.py for an up-front check of supported Keras versions. - -try: - try: - # Get Keras v2 from the separate tf_keras package. - # In OSS, it exists for TF2.14+. It may become required for TF2.16+. - from tf_keras.src.engine import keras_tensor # pytype: disable=import-error - from tf_keras.src.layers import core as core_layers # pytype: disable=import-error - import tf_keras.src.backend as keras_backend # pytype: disable=import-error - except ImportError: - # Get Keras v2 from the keras package. - # In OSS, this is possible for TF2.15 and older. - import keras # pytype: disable=import-error - if not keras.__version__.startswith('2.'): - raise ImportError( - 'tensorflow_gnn requires tf_keras to be installed or keras version <' - f' 3. Instead got keras version {keras.__version__}.' - ) from None # A Keras version mismatch is different to lacking tf_keras. - import keras # pytype: disable=import-error - if hasattr(keras, 'src'): # As of TF/Keras 2.13. - from keras.src.engine import keras_tensor # pytype: disable=import-error - from keras.src.layers import core as core_layers # pytype: disable=import-error - import keras.src.backend as keras_backend # pytype: disable=import-error - else: - from keras.engine import keras_tensor # pytype: disable=import-error - from keras.layers import core as core_layers # pytype: disable=import-error - import keras.backend as keras_backend # pytype: disable=import-error -except ImportError: - # Internal - keras_tensor = tf._keras_internal.engine.keras_tensor # pylint: disable=protected-access - core_layers = tf._keras_internal.layers.core # pylint: disable=protected-access - keras_backend = tf._keras_internal.backend # pylint: disable=protected-access - CompositeTensor = composite_tensor.CompositeTensor BatchableTypeSpec = type_spec.BatchableTypeSpec type_spec_register = ( @@ -79,22 +52,71 @@ type_spec_registry.lookup if type_spec_registry is not None else type_spec.lookup) -try: - # These types are semi-public as of TF/Keras 2.13. - # Whenever possible, get them the official way. +OpDispatcher = tf.__internal__.dispatch.OpDispatcher + + +## +## Part 2: Keras symbols, compatible with `tf.keras.*` +## + +# pytype: disable=import-error + +if tf.__version__.startswith("2.12."): + # tf.keras is keras 2.12, which does not yet have the `src` subdirectory. + from keras import backend as keras_backend + from keras.engine import keras_tensor as kt + from keras.layers import core as core_layers + # In 2.12, these symbols are not exposed yet under tf.keras.__internal__. + KerasTensor = kt.KerasTensor + RaggedKerasTensor = kt.RaggedKerasTensor + +elif tf.__version__.startswith("2.13.") or tf.__version__.startswith("2.14."): KerasTensor = tf.keras.__internal__.KerasTensor RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor -except AttributeError: - KerasTensor = keras_tensor.KerasTensor - RaggedKerasTensor = keras_tensor.RaggedKerasTensor -# These KerasTensor helpers are still private in TF/Keras 2.13. -register_keras_tensor_specialization = ( - keras_tensor.register_keras_tensor_specialization) -delegate_property = core_layers._delegate_property # pylint: disable=protected-access -delegate_method = core_layers._delegate_method # pylint: disable=protected-access + # tf.keras is keras. + # For TF 2.14, there also exists a tf_keras package, but TF does not use it. + from keras.src import backend as keras_backend + from keras.src.engine import keras_tensor as kt + from keras.src.layers import core as core_layers -OpDispatcher = tf.__internal__.dispatch.OpDispatcher +elif tf.__version__.startswith("2.15."): + KerasTensor = tf.keras.__internal__.KerasTensor + RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor + # OSS TensorFlow 2.15 can choose between keras 2.15 and tf_keras 2.15 + # BUT THESE ARE DIFFERENT PACKAGES WITH SEPARATE GLOBAL REGISTRIES + # so it is essential that we pick the right one by replicating the logic from + # https://github.com/tensorflow/tensorflow/blob/r2.15/tensorflow/python/util/lazy_loader.py#L96 + if os.environ.get("TF_USE_LEGACY_KERAS", None) in ("true", "True", "1"): + from tf_keras.src import backend as keras_backend + from tf_keras.src.layers import core as core_layers + from tf_keras.src.engine import keras_tensor as kt + else: + from keras.src import backend as keras_backend + from keras.src.layers import core as core_layers + from keras.src.engine import keras_tensor as kt +elif hasattr(tf, "_keras_internal"): # Special case: internal. + KerasTensor = tf.keras.__internal__.KerasTensor + RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor + kt = tf._keras_internal.engine.keras_tensor # pylint: disable=protected-access + core_layers = tf._keras_internal.layers.core # pylint: disable=protected-access + keras_backend = tf._keras_internal.backend # pylint: disable=protected-access + +else: # TF2.16 and onwards. + # ../__init__.py has already checked that tf.keras has version 2, not 3, + # which implies that tf.keras is tf_keras, and we do not second-guess + # the selection logic. + KerasTensor = tf.keras.__internal__.KerasTensor + RaggedKerasTensor = tf.keras.__internal__.RaggedKerasTensor + from tf_keras.src import backend as keras_backend + from tf_keras.src.layers import core as core_layers + from tf_keras.src.engine import keras_tensor as kt + +# pytype: enable=import-error + +register_keras_tensor_specialization = kt.register_keras_tensor_specialization +delegate_property = core_layers._delegate_property # pylint: disable=protected-access +delegate_method = core_layers._delegate_method # pylint: disable=protected-access unique_keras_object_name = keras_backend.unique_object_name # Delete imports, in their order above. @@ -102,5 +124,6 @@ del type_spec del tf del type_spec_registry -del keras_tensor +del keras_backend del core_layers +del kt