Support cpu tensors without direct device invocation #3842

Merged
merged 56 commits into deepspeedai:master Jan 5, 2024

Conversation

abhilash1910
Contributor

@abhilash1910 abhilash1910 commented Jun 29, 2023

Motivation:
Fix for reproducible issue #3837 on CPU. On CPUs, direct invocation of torch.cpu.tensor leads to a dtype mismatch.
Another way would be to have something like ["torch.DoubleTensor" if device_type == 'cpu' else "torch.{}.DoubleTensor".format(device_type)] for all elements in the supported list, but that would eliminate "torch.cpu.DoubleTensor", etc. from the scope.
@jeffra requesting review.

CLA is signed

@tjruwase
Contributor

@abhilash1910, thanks for this PR. I think it needs some work that leverages PR #3633, for the following reasons.

  1. As you observed, strings like torch.cpu.DoubleTensor are invalid; the underlying issue is that the existing code is broken now that accelerators other than cuda are supported. So it would be better to remove the code that generates these invalid strings.
  2. [CPU] Skip CPU support unimplemented error #3633 adds an API for querying the dtypes supported by the accelerator, and it seems to me that this entire logic could be rewritten to use that API. That API is less brittle since it does not rely on string formatting, and it is easier to maintain since the respective accelerator owners can update their list of supported dtypes (see the sketch after this comment).

Please share your thoughts. Thanks!
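A minimal sketch of what item 2 suggests, assuming the supported_dtypes() accelerator API added in #3633; the bucketing function name and structure here are illustrative, not the actual DeepSpeed implementation.

```python
import torch
from deepspeed.accelerator import get_accelerator

def bucket_by_dtype(tensors):
    # Group tensors by canonical torch dtype instead of by type() strings.
    supported = get_accelerator().supported_dtypes()  # e.g. [torch.float32, torch.float16, ...]
    buckets = {}
    for t in tensors:
        if t.dtype not in supported:
            raise ValueError(f"dtype {t.dtype} is not supported by this accelerator")
        buckets.setdefault(t.dtype, []).append(t)
    return buckets
```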

@abhilash1910
Contributor Author

abhilash1910 commented Jun 30, 2023

@tjruwase yes, I think that would be a proper fix: instead of having separate dtypes, we can directly leverage the abstract accelerator interface. Let me go through the changes for this. Thanks. (Making this a draft for now.)
I see the changes from my colleague Yejing; I will discuss and work on this.

@abhilash1910 abhilash1910 marked this pull request as draft June 30, 2023 16:31
@tjruwase
Contributor

@abhilash1910, thanks for your alignment. I will push to get #3633 merged ASAP. I left some comments there.

@abhilash1910
Contributor Author

Yes sure, I will work with my colleague Yejing to make this work.

@delock
Collaborator

delock commented Jul 17, 2023

> @abhilash1910, thanks for this PR. I think it needs some work that leverages PR #3633, for the following reasons.
>
>   1. As you observed, strings like torch.cpu.DoubleTensor are invalid; the underlying issue is that the existing code is broken now that accelerators other than cuda are supported. So it would be better to remove the code that generates these invalid strings.
>   2. [CPU] Skip CPU support unimplemented error #3633 adds an API for querying the dtypes supported by the accelerator, and it seems to me that this entire logic could be rewritten to use that API. That API is less brittle since it does not rely on string formatting, and it is easier to maintain since the respective accelerator owners can update their list of supported dtypes.
>
> Please share your thoughts. Thanks!

@tjruwase I have a question. For the following check, is the check focused only on the data type, or is data type + device type needed? If only the data type is important, then the proper way is to strip the device type from t.type() and compare with the data type list.

@tjruwase
Contributor

> @tjruwase I have a question. For the following check, is the check focused only on the data type, or is data type + device type needed? If only the data type is important, then the proper way is to strip the device type from t.type() and compare with the data type list.

Yes, the focus is on checking the data types supported by the device. My feedback is about avoiding any assumptions about how device and data type are combined in a string format. For example, as shown in this list, the same dtype is formatted differently for cpu and cuda tensors. So I suggested relying on the torch dtype, which is canonical and which has an accelerator API to retrieve it.

Please let me know your thoughts.
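For illustration (not DeepSpeed code), the device-dependent type() string versus the canonical dtype looks like this:

```python
import torch

# type() strings encode the device, while dtype is canonical across devices.
cpu_t = torch.zeros(1, dtype=torch.half)
print(cpu_t.type())   # 'torch.HalfTensor'
print(cpu_t.dtype)    # torch.float16
# The same dtype on a cuda tensor would print 'torch.cuda.HalfTensor' from
# type(), but .dtype would still be torch.float16.
```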

@delock
Collaborator

delock commented Jul 18, 2023

Yes, dtype is better. Some additional changes in _reduce_non_expert_gradients and _reduce_expert_gradients will be needed accordingly.

@tjruwase
Contributor

> Yes, dtype is better. Some additional changes in _reduce_non_expert_gradients and _reduce_expert_gradients will be needed accordingly.

@delock, thanks for the pointer. @abhilash1910, could you please help handle those changes in your PR?

@abhilash1910
Contributor Author

@tjruwase @delock _reduce_non_expert_gradients and _reduce_expert_gradients use the SparseTensor class; would it make sense to replace type with dtype there?

@tjruwase
Contributor

@abhilash1910, you raise an important issue with my proposal that I had overlooked: SparseTensor.type() is a string while torch.dtype is not. My proposal would result in populating supported_types with objects of different types, which creates all kinds of complications.

One idea is to populate the list with [f'{t}' for t in accelerator.supported_dtypes()] to harmonize the object types. But the problem is that while SparseTensor.type() will work for lookups, neither torch.Tensor.type() nor torch.Tensor.dtype will work. In other words, we will always need different lookup logic for torch tensors on one hand, and deepspeed sparse_tensors on the other hand.

Another potential issue with existing code is that we don't check whether the dtype of the underlying tensor of a SparseTensor is supported by the accelerator. I need to double check this concern with my teammates.

Therefore, I am now wondering whether it would be better to take a new approach that makes this difference explicit and uses isinstance(t, torch.Tensor) and isinstance(t, SparseTensor) appropriately to guard the lookups. I think this will make the code easier to maintain since the logic is more explicit. Also, I think it is more intuitive for the split_half_float_double_sparse() function to return separate buckets for sparse and dense tensors for the callers to handle appropriately.

These are just some thoughts. Please let me know what you think. Thanks!
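A minimal sketch of the explicit isinstance-guard idea described above; it assumes the SparseTensor class lives in deepspeed.runtime.sparse_tensor, and the function name and return shape are illustrative rather than the actual implementation.

```python
import torch
from deepspeed.runtime.sparse_tensor import SparseTensor  # assumed import path

def split_dense_and_sparse(tensors):
    # Return separate buckets for dense torch tensors and deepspeed sparse
    # tensors so callers can handle each kind with its own lookup logic.
    dense, sparse = [], []
    for t in tensors:
        if isinstance(t, SparseTensor):
            sparse.append(t)
        elif isinstance(t, torch.Tensor):
            dense.append(t)
        else:
            raise TypeError(f"unexpected tensor type: {type(t)}")
    return dense, sparse
```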

@jomayeri jomayeri linked an issue Jul 20, 2023 that may be closed by this pull request
Review comment on deepspeed/runtime/engine.py (outdated, resolved)
@abhilash1910
Contributor Author

abhilash1910 commented Jul 21, 2023

> @abhilash1910, you raise an important issue with my proposal that I had overlooked: SparseTensor.type() is a string while torch.dtype is not. My proposal would result in populating supported_types with objects of different types, which creates all kinds of complications.
>
> One idea is to populate the list with [f'{t}' for t in accelerator.supported_dtypes()] to harmonize the object types. But the problem is that while SparseTensor.type() will work for lookups, neither torch.Tensor.type() nor torch.Tensor.dtype will work. In other words, we will always need different lookup logic for torch tensors on one hand, and deepspeed sparse_tensors on the other hand.
>
> Another potential issue with existing code is that we don't check whether the dtype of the underlying tensor of a SparseTensor is supported by the accelerator. I need to double check this concern with my teammates.
>
> Therefore, I am now wondering whether it would be better to take a new approach that makes this difference explicit and uses isinstance(t, torch.Tensor) and isinstance(t, SparseTensor) appropriately to guard the lookups. I think this will make the code easier to maintain since the logic is more explicit. Also, I think it is more intuitive for the split_half_float_double_sparse() function to return separate buckets for sparse and dense tensors for the callers to handle appropriately.
>
> These are just some thoughts. Please let me know what you think. Thanks!

Yes, my thoughts exactly. I was thinking of adding a dtype getter inside the SparseTensor to make it consistent.
But a check on the accelerator device type when getting gradients for reduction should still be present. @tjruwase @jeffra, let me make the additions.
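A hypothetical sketch of such a dtype getter; this stand-in class is not the real deepspeed SparseTensor, it only shows the idea of forwarding the underlying dense tensor's dtype.

```python
import torch

class SparseTensorSketch:
    """Illustrative stand-in for deepspeed's SparseTensor, not the real class."""

    def __init__(self, dense_tensor: torch.Tensor):
        self.orig_dense_tensor = dense_tensor

    @property
    def dtype(self):
        # Forward the dtype of the underlying dense tensor so callers can check
        # accelerator support using canonical torch dtypes.
        return self.orig_dense_tensor.dtype
```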

@abhilash1910 abhilash1910 marked this pull request as ready for review July 21, 2023 11:40
@abhilash1910
Contributor Author

@tjruwase @jeffra do _reduce_non_expert_gradients and _reduce_expert_gradients also need dtype changes, since we are adding the attribute? Also, not sure why the CLA check is appearing.
@microsoft-github-policy-service agree company="Intel"

@abhilash1910
Contributor Author

@tjruwase could you retrigger the CI (the issue seems to be fixed now)? Thanks.

@abhilash1910
Contributor Author

@tjruwase could you help re-trigger the CI and re-review? Much appreciated.

@delock
Collaborator

delock commented Dec 8, 2023

Hi @abhilash1910, can you clarify whether the current CI failures are related to your PR or are just a test issue? Thanks!

@abhilash1910
Contributor Author

@delock I think it might be a test issue, as I am able to run the CI for the sparse test locally. I changed the code path and I still see the same allclose issue. @tjruwase could you suggest any modifications? This is strange, as I tested in an isolated env and did not get the issue.

@delock
Collaborator

delock commented Dec 8, 2023

Hi @abhilash1910, some suggestions:

  1. Provide more details (hw, sw, log ...) of your local run so there might be a hint of the difference.
  2. Try to turn the test into a standalone workload (not using pytest) so debugging is possible.
  3. Separate the asserts for dense and sparse grads so you will know which type of tensor has a difference; this might narrow down the location of a possible bug (see the sketch below).
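A small sketch of suggestion 3, with illustrative names (dense_ref, sparse_ref, etc. are not taken from the actual test):

```python
import torch

def check_grads(dense_ref, dense_out, sparse_ref, sparse_out, rtol=1e-5, atol=1e-8):
    # Assert dense and sparse gradients separately so a failure pinpoints
    # which kind of tensor differs.
    assert torch.allclose(dense_ref, dense_out, rtol=rtol, atol=atol), "dense grads differ"
    assert torch.allclose(sparse_ref, sparse_out, rtol=rtol, atol=atol), "sparse grads differ"
```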

@abhilash1910
Contributor Author

Thanks @inkcherry for highlighting the boundary issue; it seems it will pass the CI now.
@tjruwase could you help retrigger the CI? Thanks.

@inkcherry
Contributor

> Thanks @inkcherry for highlighting the boundary issue; it seems it will pass the CI now. @tjruwase could you help retrigger the CI? Thanks.

I could reproduce the CI issue in my local env, and it passes now.
FYI: @delock

@delock
Collaborator

delock commented Dec 19, 2023

Hi @tjruwase, the previous error in the CI workflow unit/runtime/sparse_tensor/test_averaging_sparse_gradients.py should be fixed by the commit from @inkcherry. Can you help restart the workflow? Thanks!

@delock
Collaborator

delock commented Dec 22, 2023

Hi @abhilash1910 are the following two errors related to your change?

FAILED unit/inference/quantization/test_intX_quantization.py::TestQuantizedInt::test_zero3_int4_post_init_quant_nvme_offload
FAILED unit/inference/quantization/test_intX_quantization.py::TestQuantizedInt::test_zero3_int4_quantized_initialization_nvme_offload

@abhilash1910
Contributor Author

@delock I think this failure is related to this PR, but it seems to have arisen after the previous fix. I will take a look at it.

@tjruwase tjruwase added this pull request to the merge queue Jan 5, 2024
Merged via the queue into deepspeedai:master with commit c84c28d Jan 5, 2024
14 checks passed
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: inkcherry <[email protected]>

Successfully merging this pull request may close these issues.

[BUG] Incorrect type check in engine.py for CPU training