
Will OpenXLA support DLPack in the PJRT C API? #6151

Closed
Zantares opened this issue Oct 7, 2023 · 4 comments
Labels: question (Further information is requested)

Comments

@Zantares

Zantares commented Oct 7, 2023

After integrating our extension with OpenXLA through the PJRT C API, we found that DLPack is not yet supported by PJRT:

```cpp
StatusOr<std::unique_ptr<ExternalReference>> ReleaseDeviceMemoryOwnership(
    bool wait_for_operations_to_complete) override {
  return Unimplemented(
      "PJRT C API does not support ReleaseDeviceMemoryOwnership");
}
```

Does the community have any plans to support it? Some users are interested in this feature.

@hawkinsp
Member

hawkinsp commented Oct 9, 2023

I'm guessing this relates to wanting DLPack support in JAX, not necessarily to XLA itself?

@skye and @jyingl3 are actually looking at this. I believe their plan is to simplify JAX so that it does not call this API; see jax-ml/jax#17941.

If this question doesn't pertain to JAX, can you clarify which DLPack support you mean?

@skye
Contributor

skye commented Oct 9, 2023

Like Peter says, we don't believe this API is necessary to implement DLPack. We think you should be able to use AcquireExternalReference instead:

From xla/xla/pjrt/pjrt_client.h, lines 967 to 984 at commit 59ca120:

```cpp
class ExternalReference {
 public:
  virtual ~ExternalReference() = 0;
  // Return opaque device memory pointer to root buffer.
  void* OpaqueDeviceMemoryDataPointer() const { return data_ptr_; }

  // Stream is platform-specific. This is intended to support dlpack on GPU
  // and is not expected to be implemented for all hardware platforms.
  virtual Status WaitUntilBufferReadyOnStream(std::intptr_t stream) {
    return Unimplemented(
        "WaitUntilBufferReadyOnStream is only implemented for GPU.");
  }

 protected:
  void* data_ptr_;
};
virtual StatusOr<std::unique_ptr<ExternalReference>>
AcquireExternalReference() = 0;
```

which is implemented in the PJRT C API.
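A minimal, self-contained sketch of how a DLPack-style export could sit on top of an acquire-style reference. Every type here is a hypothetical stand-in, not the real XLA or DLPack headers: `ToyExternalReference` mirrors the quoted `ExternalReference` (it exposes a data pointer and keeps the memory alive while it exists), and `ToyManagedTensor` imitates the `manager_ctx`/`deleter` ownership fields of DLPack's `DLManagedTensor`. The exported tensor owns the reference, and the consumer's call to `deleter` drops it.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for device memory owned by a PJRT buffer.
struct ToyDeviceMemory {
  std::vector<float> data;
};

// Hypothetical mirror of ExternalReference: while it exists, the memory
// it points at stays alive (shared ownership).
class ToyExternalReference {
 public:
  explicit ToyExternalReference(std::shared_ptr<ToyDeviceMemory> mem)
      : mem_(std::move(mem)) {}
  void* OpaqueDeviceMemoryDataPointer() const { return mem_->data.data(); }

 private:
  std::shared_ptr<ToyDeviceMemory> mem_;
};

// Hypothetical buffer with an AcquireExternalReference-like method.
class ToyBuffer {
 public:
  explicit ToyBuffer(std::size_t n)
      : mem_(std::make_shared<ToyDeviceMemory>(
            ToyDeviceMemory{std::vector<float>(n, 0.0f)})) {}
  std::unique_ptr<ToyExternalReference> AcquireExternalReference() {
    return std::make_unique<ToyExternalReference>(mem_);
  }
  long use_count() const { return mem_.use_count(); }

 private:
  std::shared_ptr<ToyDeviceMemory> mem_;
};

// Simplified imitation of DLManagedTensor's ownership fields.
struct ToyManagedTensor {
  void* data;
  void* manager_ctx;
  void (*deleter)(ToyManagedTensor* self);
};

// Export: the tensor takes ownership of the external reference.
ToyManagedTensor* ExportToDlpackLike(ToyBuffer& buffer) {
  auto ref = buffer.AcquireExternalReference();
  auto* tensor = new ToyManagedTensor;
  tensor->data = ref->OpaqueDeviceMemoryDataPointer();
  tensor->manager_ctx = ref.release();  // ownership moves into the tensor
  tensor->deleter = [](ToyManagedTensor* self) {
    assert(self != nullptr);
    // Dropping the reference lets the buffer free its memory later.
    delete static_cast<ToyExternalReference*>(self->manager_ctx);
    delete self;
  };
  return tensor;
}
```

The original `ToyBuffer` stays usable while the export is alive, which matches the acquire (rather than release) semantics described above.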

The main difference between AcquireExternalReference and ReleaseDeviceMemoryOwnership is that the latter makes it an error to use the released buffer with the original client. We didn't think it was worth having a separate API where the only difference is error checking, but let us know if you disagree.
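The difference can be illustrated with a toy model. All names below are hypothetical and this is not XLA's implementation: "releasing" hands out the same memory as "acquiring", but also marks the buffer so that later use through the original client reports an error.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical toy buffer; not XLA's PjRtBuffer.
class OwnershipToyBuffer {
 public:
  // Acquire-style: hand out the pointer; the buffer stays usable.
  void* Acquire() { return &storage_; }

  // Release-style: hand out the same pointer, then forbid further use.
  void* Release() {
    released_ = true;
    return &storage_;
  }

  // Simulates using the buffer with the original client. Returns an
  // error message if ownership was released, std::nullopt on success.
  std::optional<std::string> Use() const {
    if (released_) return std::string("buffer ownership was released");
    return std::nullopt;
  }

 private:
  int storage_ = 0;
  bool released_ = false;
};
```

Only `Use()` behaves differently between the two paths, which is the "only difference is error checking" point made above.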

I believe @jyingl3 is also working on CreateViewOfDeviceBuffer support:

```cpp
StatusOr<std::unique_ptr<PjRtBuffer>> CreateViewOfDeviceBuffer(
    void* device_ptr, const Shape& shape, PjRtDevice* device,
    std::function<void()> on_delete_callback,
    std::optional<std::intptr_t> stream) override {
  return Unimplemented(
      "PJRT C API does not support CreateViewOfDeviceBuffer");
}
```

This lets you implement the opposite direction of DLPack (accepting an external buffer, vs. the external framework using a buffer from the original client).
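For the import direction, a view over externally owned memory can carry a callback that runs when the view is destroyed; a DLPack importer would plausibly use that callback to call the producer's deleter. The sketch below is a hypothetical stand-in that only mirrors the `device_ptr` and `on_delete_callback` parameters of the quoted `CreateViewOfDeviceBuffer`; it does not model shapes, devices, or streams.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <utility>

// Hypothetical non-owning view over memory owned by another framework.
class ToyBufferView {
 public:
  ToyBufferView(void* device_ptr, std::function<void()> on_delete_callback)
      : ptr_(device_ptr), on_delete_(std::move(on_delete_callback)) {}
  ToyBufferView(const ToyBufferView&) = delete;
  ToyBufferView& operator=(const ToyBufferView&) = delete;
  ~ToyBufferView() {
    // Tell the external owner it may now free the memory.
    if (on_delete_) on_delete_();
  }
  void* data() const { return ptr_; }

 private:
  void* ptr_;
  std::function<void()> on_delete_;
};

// Hypothetical analogue of CreateViewOfDeviceBuffer: wrap, do not copy.
std::unique_ptr<ToyBufferView> CreateToyView(void* device_ptr,
                                             std::function<void()> on_delete) {
  return std::make_unique<ToyBufferView>(device_ptr, std::move(on_delete));
}
```

No data is copied: the view aliases the external pointer, and ownership is handed back only through the callback.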

Like Peter says, we'll make sure JAX uses just these APIs, so they should be all you need to interface with JAX in particular.

@Zantares
Author

Thanks for the answers; they are really useful.

> I'm guessing this relates to wanting DLPack support in JAX, not necessarily to XLA itself?
>
> @skye and @jyingl3 are actually looking at this. I believe their plan is to simplify JAX so it does not call this API, see google/jax#17941
>
> If this question doesn't pertain to JAX, can you clarify which DLPack support you mean?

Yes, we want DLPack support in JAX. I asked here because I have seen that most of the related code changes are happening in XLA.

I mentioned ReleaseDeviceMemoryOwnership() because we ran a simple experiment in JAX to test DLPack and hit the Unimplemented error from ReleaseDeviceMemoryOwnership(). I see that new DLPack-related code has been merged into the main branch; we will try it and work out how to use AcquireExternalReference(). I will give feedback in a few days, so please keep this issue open for a while. Thanks!

@Zantares
Author

The missing APIs are now all implemented, and after rebasing onto the newest XLA our extension works well with an internal patch. I still have a question related to JAX, though: although a new device can be added to JAX/XLA through PJRT, the DLPack code path does not yet handle extended devices. That's why the internal patch is needed; it looks like this:

```diff
@@ -254,6 +256,12 @@ StatusOr<PjRtDevice*> DeviceForDLDevice(const PjRtClient* cpu_client,
       }
       TF_RET_CHECK(gpu_client->platform_id() == RocmId());
       return gpu_client->LookupAddressableDevice(context.device_id);
+    case kDLOneAPI:
+      if (gpu_client == nullptr)
+        return InvalidArgument(
+            "DLPack tensor is on extended device, but no backend was provided.");
+      return gpu_client->LookupAddressableDevice(context.device_id);
     default:
       return InvalidArgument("Unknown/unsupported DLPack device type %d",
                              context.device_type);
```
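The dispatch this patch extends boils down to "map the DLPack device type to a client, and fail cleanly when no client was provided for it". Below is a self-contained sketch of that logic only; the enum, `ToyClient`, and `ClientForDeviceType` are hypothetical, do not use DLPack's real `DLDeviceType` values, and omit the platform-ID check the real code performs for ROCm.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Hypothetical device kinds; these are NOT DLPack's DLDeviceType values.
enum class ToyDeviceType { kCpu, kCuda, kRocm, kOneApi, kOther };

// Hypothetical client handle; only carries a platform name here.
struct ToyClient {
  std::string platform;
};

// Either the client that owns the device, or an error message.
using ClientOrError = std::variant<const ToyClient*, std::string>;

// Like the patched switch, the oneAPI case reuses the accelerator
// ("gpu") client slot and fails cleanly when no such client was given.
ClientOrError ClientForDeviceType(ToyDeviceType type,
                                  const ToyClient* cpu_client,
                                  const ToyClient* gpu_client) {
  switch (type) {
    case ToyDeviceType::kCpu:
      if (cpu_client == nullptr) {
        return std::string("no CPU backend was provided");
      }
      return cpu_client;
    case ToyDeviceType::kCuda:
    case ToyDeviceType::kRocm:
    case ToyDeviceType::kOneApi:
      if (gpu_client == nullptr) {
        return std::string(
            "DLPack tensor is on an accelerator, but no backend was provided");
      }
      return gpu_client;
    default:
      return std::string("unknown/unsupported DLPack device type");
  }
}
```

Routing the extended device through the existing accelerator slot keeps the change small, at the cost of a JAX process only being able to pass one accelerator client into this lookup at a time.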

In addition, XLA needs to recognize a new device named XPU (or another permitted name):

```diff
--- a/xla/pjrt/pjrt_compiler.h
+++ b/xla/pjrt/pjrt_compiler.h
@@ -48,6 +48,10 @@ inline const char* TpuName() {
   static constexpr char kTpuName[] = "tpu";
   return kTpuName;
 }
+inline const char* XpuName() {
+  static constexpr char kXpuName[] = "xpu";
+  return kXpuName;
+}
 inline PjRtPlatformId CpuId() {
   static const PjRtPlatformId kCpuId = tsl::Fingerprint64(CpuName());
   return kCpuId;
@@ -64,6 +68,10 @@ inline PjRtPlatformId TpuId() {
   static const PjRtPlatformId kTpuId = tsl::Fingerprint64(TpuName());
   return kTpuId;
 }
+inline PjRtPlatformId XpuId() {
+  static const PjRtPlatformId kXpuId = tsl::Fingerprint64(XpuName());
+  return kXpuId;
+}
```
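The second diff derives a platform ID by fingerprinting the platform name, so `XpuId()` is stable across calls and distinct from the built-in IDs as long as the names differ (barring a hash collision). The sketch below reproduces that pattern in isolation. It substitutes a 64-bit FNV-1a hash for `tsl::Fingerprint64`, so the numeric values will not match XLA's, and the `Toy*` names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <string_view>

using ToyPlatformId = std::uint64_t;

// 64-bit FNV-1a: a stand-in for tsl::Fingerprint64, NOT the same function.
constexpr std::uint64_t Fnv1a64(std::string_view s) {
  std::uint64_t h = 14695981039346656037ULL;  // FNV offset basis
  for (char c : s) {
    h ^= static_cast<unsigned char>(c);
    h *= 1099511628211ULL;  // FNV prime
  }
  return h;
}

inline const char* ToyCpuName() {
  static constexpr char kCpuName[] = "cpu";
  return kCpuName;
}
inline const char* ToyXpuName() {
  static constexpr char kXpuName[] = "xpu";
  return kXpuName;
}

// Function-local statics: each ID is computed once, on first call.
inline ToyPlatformId ToyCpuId() {
  static const ToyPlatformId kCpuId = Fnv1a64(ToyCpuName());
  return kCpuId;
}
inline ToyPlatformId ToyXpuId() {
  static const ToyPlatformId kXpuId = Fnv1a64(ToyXpuName());
  return kXpuId;
}
```

Because the ID depends only on the name string, a new platform needs nothing beyond a unique name to get its own ID; the function-local static just caches the hash.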

Can I submit this directly as a PR to XLA, or are there any concerns?

@penpornk added the question (Further information is requested) label on Feb 29, 2024
4 participants