
[TPU, keras preprocessing layer] Some Op must be a compile-time constant. #15655

Closed

ProtossDragoon opened this issue Nov 16, 2021 · 9 comments

@ProtossDragoon
ProtossDragoon commented Nov 16, 2021

Please go to TF Forum for help and support:

https://discuss.tensorflow.org/tag/keras

If you open a GitHub issue, here is our policy:

It must be a bug, a feature request, or a significant problem with the documentation (for small docs fixes please send a PR instead).
The form below must be filled out.

Here's why we have that policy:

Keras developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow.

System information.

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): google colab
  • TensorFlow installed from (source or binary): google colab
  • TensorFlow version (use command below): 2.7
  • Python version: google colab
  • Bazel version (if compiling from source):
  • GPU model and memory: TPU issue
  • Exact command to reproduce: Here is a Colab notebook! You can reproduce this issue without writing any other code. Just change the Colab runtime to a TPU device and run all cells.

You can collect some of this information using our environment capture script:

https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh

You can obtain the TensorFlow version with:
python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"

Describe the problem.

Hi!
A TPU error is raised specifically with Keras preprocessing layers.
I've tried to connect two models: an augmentation model that contains preprocessing layers, and a segmentation model.

def new_concatenated_model(
    image_input_hw,
    mask_input_hw,
    class_n
):
    seg_model = create_segmentation_model(class_n)
    aug_model = create_augmentation_model(
        image_input_hw, mask_input_hw, class_n)
    
    image_input_shape = list(image_input_hw) + [3]

    @auto_tpu(device=CURRENT_DEVICE)  # the `auto_tpu` decorator is just a context manager
    def create():
        im = seg_model.input
        model = AugConcatedSegModel(
            inputs=im,
            outputs=seg_model(im),
            augmentation_model=aug_model,
            name='seg_model_train_with_aug'
        )
        return model
    
    model = create()
    return model

The train_step() function code came mainly from the official TensorFlow tutorial documentation.

class AugConcatedSegModel(tf.keras.Model):
    def __init__(
        self,
        inputs=None,
        outputs=None,
        augmentation_model=None, 
        **kwargs
    ):
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)
        self.augmentation_model = augmentation_model

    def train_step(self, data):
        im, ma = data
        im, ma = self.augmentation_model((im, ma))

        with tf.GradientTape() as tape:
            ma_pred = self(im, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(ma, ma_pred, regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(ma, ma_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

Describe the current behavior.

Training was expected to succeed without error.
The same code was tested on:

  • CPU : No errors
  • GPU : No errors
  • TPU : Error

You can reproduce this error quickly:
https://colab.research.google.com/drive/1LhHj1FrkZE9QnFhY-NOO8mn7aiXhZgNh?usp=sharing
Runtime - Run all.

  • When I changed the augmentation model to plain Conv2D layers, the error disappeared (see the sketch below).
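
For reference, here is a minimal sketch of that check; `create_conv_only_model` and its argument names are hypothetical stand-ins, not code from the notebook:

import tensorflow as tf

def create_conv_only_model(image_input_hw, mask_input_hw, class_n):
    # Hypothetical conv-only stand-in with the same
    # (image, mask) -> (image, mask) signature as the augmentation model.
    # It contains no stateless random ops, so TPU/XLA compilation succeeds.
    im = tf.keras.Input(shape=list(image_input_hw) + [3])
    ma = tf.keras.Input(shape=list(mask_input_hw) + [class_n])
    im_out = tf.keras.layers.Conv2D(3, 3, padding='same')(im)
    ma_out = tf.keras.layers.Conv2D(class_n, 3, padding='same')(ma)
    return tf.keras.Model(inputs=[im, ma], outputs=[im_out, ma_out])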

Describe the expected behavior.

The Colab notebook runs without any error.


Contributing.

  • Do you want to contribute a PR? (yes/no):
  • If yes, please read this page for instructions
  • Briefly describe your candidate solution (if contributing):

Standalone code to reproduce the issue.

https://colab.research.google.com/drive/1LhHj1FrkZE9QnFhY-NOO8mn7aiXhZgNh?usp=sharing

Source code / logs.

Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. Try to provide a reproducible test case that is the bare minimum necessary to generate the problem.

InvalidArgumentError: 9 root error(s) found.
  (0) INVALID_ARGUMENT: {{function_node __inference_train_function_692915}} Input 0 to node `sequential_augmentation_model/sequential_augmentation_layers/random_flip/stateless_random_flip_left_right/stateless_random_uniform/StatelessRandomUniformV2` with op StatelessRandomUniformV2 must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.

	 [[{{node sequential_augmentation_model/sequential_augmentation_layers/random_flip/stateless_random_flip_left_right/stateless_random_uniform/StatelessRandomUniformV2}}]]
	 [[TPUReplicate/_compile/_1646634736830564460/_4]]
  (1) INVALID_ARGUMENT: {{function_node __inference_train_function_692915}} Input 0 to node `sequential_augmentation_model/sequential_augmentation_layers/random_flip/stateless_random_flip_left_right/stateless_random_uniform/StatelessRandomUniformV2` with op StatelessRandomUniformV2 must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.

	 [[{{node sequential_augmentation_model/sequential_augmentation_layers/random_flip/stateless_random_flip_left_right/stateless_random_uniform/StatelessRandomUniformV2}}]]
	 [[TPUReplicate/_compile/_1646634736830564460/_4]]
	 [[tpu_compile_succeeded_assert/_5094882425795608634/_5/_47]]
  (2) INVALID_ARGUMENT: {{function_node __inference_train_function_692915}} Input 0 to node `sequential_augmentation_model/sequential_augmentation_layers/random_flip/stateless_random_flip_left_right/stateless_random_uniform/StatelessRandomUniformV2` with op StatelessRandomUniformV2 must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.

	 [[{{node sequential_augmentation_model/sequential_augmentation_layers/random_flip/stateless_random_flip_left_right/stateless_random_uniform/StatelessRandomUniformV2}}]]
	 [[TPUReplicate/_compile/_1646634736830564460/_4]]
	 [[tpu_compile_succeeded_assert/_5094882425795608634/_5/_159]]
  (3) INVALID_ARGUMENT: {{function_node __inference_train_function_692915}} Input 0 to node `sequential_augmentation_model/sequential_a ... [truncated]

NOTE: This issue was moved here from tensorflow/tensorflow#53051.

@sanatmpa1
Contributor

I am able to reproduce the issue with TF 2.7.0. Please find the gist here.

@sachinprasadhs added the keras-team-review-pending label (Pending review by a Keras team member) on Dec 1, 2021
@mattdangerw
Member

Took a look. There may be something we need to change in how we set up randomness for the RandomFlip layer; I will check on this.

Overall, I think the preferred approach here is to apply the preprocessing layers inside a tf.data.Dataset.map before the training step. This keeps all preprocessing running asynchronously on the CPU, which should be more efficient in this case (see the sketch at the end of this comment).

See this section of our guide which explains the choice:
https://keras.io/guides/preprocessing_layers/#preprocessing-data-before-the-model-or-inside-the-model

Note that the op underlying RandomRotation does not have TPU support (which is why you need tf.config.set_soft_device_placement(True)), so you will be running partially on the CPU anyway.

This blog post also shows an example of running preprocessing separately with tf.data and prefetching:
https://blog.tensorflow.org/2021/11/an-introduction-to-keras-preprocessing.html
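
A minimal sketch of that suggestion, reusing the issue's aug_model; `raw_train_ds` and `batch_size` are hypothetical stand-ins for the notebook's dataset and batch size:

import tensorflow as tf

def augment(image, mask):
    # `aug_model` is the (image, mask) -> (image, mask) augmentation model
    # from the issue. Inside Dataset.map it runs on the host CPU,
    # outside the TPU-compiled train step.
    return aug_model((image, mask))

train_ds = (
    raw_train_ds  # a tf.data.Dataset of (image, mask) pairs
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# With augmentation in the input pipeline, the plain segmentation model
# can be trained without the custom train_step:
# seg_model.fit(train_ds, epochs=...)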

@ProtossDragoon
Author

Thank you for your insightful guidance, @mattdangerw! I'll read the blog post and revise my code.

@qlzh727
Member

qlzh727 commented Dec 2, 2021

Adding @wangpengmit, who works on tf.random.Generator.

@qlzh727 removed the keras-team-review-pending label (Pending review by a Keras team member) on Dec 2, 2021
@wangpengmit
Contributor

StatelessRandomUniformV2 requires its `shape` and `alg` inputs to be compile-time constants on XLA (code). I haven't looked at this bug closely enough to see why `shape` or `alg` is a dynamic tensor here (as opposed to a constant). Will look closer later.
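
A minimal illustration of that constraint (not the issue's code), assuming a TF version with XLA jit_compile available:

import tensorflow as tf

@tf.function(jit_compile=True)
def constant_shape(seed):
    # `shape` is a Python constant, so XLA can fold it at compile time.
    return tf.random.stateless_uniform([2, 2], seed=seed)

@tf.function(jit_compile=True)
def dynamic_shape(shape, seed):
    # `shape` arrives as a runtime tensor, so XLA cannot treat it as a
    # compile-time constant; compiling this raises the same
    # "must be a compile-time constant" InvalidArgumentError.
    return tf.random.stateless_uniform(shape, seed=seed)

seed = tf.constant([1, 2], dtype=tf.int32)
print(constant_shape(seed))                    # compiles and runs
# dynamic_shape(tf.constant([2, 2]), seed)     # fails under XLA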

@dlfrnaos19

Same issue here.

@bhack
Contributor

bhack commented Jul 14, 2022

@wangpengmit It seems that we have a more general issue with all CompileTimeConstantInput args.

See tensorflow/tensorflow#56769

@qlzh727
Member

qlzh727 commented Jul 25, 2022

This should now be addressed by tensorflow/tensorflow@1949eec.

@qlzh727 closed this as completed on Jul 26, 2022