Modify Torch Support to handle tensors on GPU #19
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Thanks for the PR! This was an oversight, I must not have thoroughly tested this with GPU tensors.
That said, I think it's better to avoid copying the entire array to CPU memory at the beginning, because the array may be very large and we may only need to visualize a small portion of it. Would you mind changing this to only copy to CPU right before converting to numpy? (Or, I think `.numpy(force=True)` accomplishes the same thing?)
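The reordering suggested above can be sketched roughly as follows. This is an illustrative standalone function, not the actual treescope helper: slice on whatever device the tensor lives on first, then move only the small result to host memory.

```python
import torch


def visualize_slice(tensor: torch.Tensor, index):
    # Slice first on the tensor's own device (CPU or GPU), so only
    # the small portion being visualized is copied to host memory.
    small = tensor[index]
    # force=True handles detaching and the device transfer in one
    # call; for ordinary tensors it is equivalent to
    # .detach().cpu().numpy().
    return small.numpy(force=True)
```

The key point is that indexing a GPU tensor stays on the GPU, so the expensive host transfer only happens for the truncated view.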
Hi @danieldjohnson, thanks for your response - indeed I only just noticed the slicing! Looking into the … Made the changes - not sure if this is needed in …
Thanks! One more minor thing: would you mind squashing these commits together into a single commit so that the old version doesn't appear in the commit history?

Done :)
treescope/external/torch_support.py (Outdated)
@@ -74,12 +74,13 @@ def _truncate_and_copy(
    ignoring any axes whose slices are already computed in `source_slices`.
    """
    assert torch is not None, "PyTorch is not available."
Actually, sorry, could you also amend to remove this line with extra whitespace? It's tripping up the internal lint system and I haven't set up the external lint checks yet.
Hopefully ok now!
Thanks!
- Conversion to numpy is now done using `some_tensor.numpy(force=True)`, which handles device conversion as well as some special cases such as complex tensors.
It seems the torch support assumed tensors exist on CPU, as this is a prerequisite for conversion to numpy arrays with `some_tensor.numpy()`. Perhaps this was an intentional design choice to avoid occupying non-accelerator memory for users who aren't being careful, but moving non-CPU tensors onto CPU automatically for the purposes of visualization is likely acceptable in most use cases, so I have added this.

Minor note that the `.cpu()` and `.detach()` conversion can be combined into one line. Not sure which is preferable.
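As a quick illustration of the equivalence discussed in this thread (a standalone sketch, not treescope code): `.numpy(force=True)` folds the detach and CPU transfer into a single call, and additionally materializes conjugate views that plain `.numpy()` rejects.

```python
import torch

# A gradient-tracking tensor: plain .numpy() would raise here
# because the tensor requires grad.
t = torch.ones(4, requires_grad=True)
a = t.detach().cpu().numpy()   # explicit multi-step conversion
b = t.numpy(force=True)        # single-call equivalent
assert (a == b).all()

# A conjugate-view complex tensor: .numpy() raises because the
# conjugate bit is set; force=True resolves the view first.
c = torch.tensor([1 + 2j]).conj()
arr = c.numpy(force=True)
assert arr[0] == 1 - 2j
```

On a machine with a GPU, the same `force=True` call would also perform the device-to-host copy, which is the case this PR addresses.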