-
Notifications
You must be signed in to change notification settings - Fork 330
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement check to skip value range transform #198
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably add a simple test case, just for coverage.
Yeah, I was wondering about that. Any suggestion @qlzh727? We’d have to inspect the graph, right? |
@qlzh727 How is this working between the preprocessing layers chain? Are we going to have some important casting overhead as we have profiled at #165 (comment)? |
if the value range for all the image are (0, 255) for the input image, then it will benefit from this change (since it skip the conversion). When the input range are (0, 1), then any chain of the layer will keep converting them between (0, 1) and (0, 255). We might want to add some warning to user if this method has been hit several times in the pipeline, and suggest user to pre-convert their value to (0, 255). Anyway, this method is always good to have since it always skip the unnecessary conversion. |
Can we find a way to communicate/propagate between KLP layers that the range is converted? |
As I suppose that a chain of transformations is a quite common use case and it could be the O() of the casting overhead. |
* Implement check to skip value range transform * add simple test case * Run format
* Implement check to skip value range transform * add simple test case * Run format
int64 is not a supported type in jax bye default. Trying to use it gets the following error. UserWarning: Explicitly requested dtype int64 requested in array is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more. target = jnp.array(target, dtype="int64") We can stick to int32 (which is what will be used anyway).
No description provided.