Will OpenXLA support DLPack in PJRT C API? #6151
I'm guessing this relates to wanting DLPack support in JAX, not necessarily to XLA itself? @skye and @jyingl3 are actually looking at this. I believe their plan is to simplify JAX so it does not call this API; see jax-ml/jax#17941. If this question doesn't pertain to JAX, can you clarify which DLPack support you mean?
Like Peter says, we don't believe this API is necessary to implement DLPack. We think you should be able to use Lines 967 to 984 in 59ca120
which is implemented in the PJRT C API. The main difference between I believe @jyingl3 is also working on xla/xla/pjrt/pjrt_c_api_client.h Lines 269 to 275 in 59ca120
This lets you implement the opposite direction of DLPack (accepting an external buffer, vs. the external framework using a buffer from the original client). Like Peter says, we'll make sure JAX uses just these APIs, so they should be all you need to interface with JAX in particular.
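To illustrate the two directions mentioned above, here is a minimal sketch of the DLPack ownership handoff: the producer exports a buffer together with a deleter, and the consumer calls that deleter once it is done. The types `SketchTensor`, `SketchManagedTensor`, and `ProducerBuffer` and the functions `ExportBuffer` and `ConsumeAndRelease` are hypothetical, simplified stand-ins written for this example; they are not the real `dlpack.h` structs (which carry more fields) and not the PJRT C API.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-ins for DLPack's DLTensor / DLManagedTensor.
struct SketchTensor {
  void* data;
  int32_t device_type;
  int32_t device_id;
  std::vector<int64_t> shape;
};

struct SketchManagedTensor {
  SketchTensor tensor;
  void* manager_ctx;                           // producer-private state
  void (*deleter)(SketchManagedTensor* self);  // called by the consumer when done
};

// Producer side: a framework that owns the underlying storage.
struct ProducerBuffer {
  std::vector<float> storage;
  int* release_count;  // lets the example observe that the deleter ran
};

// Export direction: hand out a view of the buffer plus a deleter.
SketchManagedTensor* ExportBuffer(ProducerBuffer* buffer) {
  auto* managed = new SketchManagedTensor;
  managed->tensor = SketchTensor{buffer->storage.data(), /*device_type=*/1,
                                 /*device_id=*/0,
                                 {static_cast<int64_t>(buffer->storage.size())}};
  managed->manager_ctx = buffer;
  managed->deleter = [](SketchManagedTensor* self) {
    auto* owner = static_cast<ProducerBuffer*>(self->manager_ctx);
    ++*owner->release_count;  // a real producer would drop its reference here
    delete self;
  };
  return managed;
}

// Import direction: read the external buffer, then notify the producer.
float ConsumeAndRelease(SketchManagedTensor* managed) {
  const float* values = static_cast<const float*>(managed->tensor.data);
  float sum = 0.0f;
  for (int64_t i = 0; i < managed->tensor.shape[0]; ++i) sum += values[i];
  managed->deleter(managed);
  return sum;
}
```

In this sketch the consumer never frees the data itself; it only invokes the producer's deleter, which is the property that lets either framework act as producer or consumer.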
Thanks for the answers, they are really useful.
Yes, we want DLPack support in JAX, and I asked this question because I have seen that most of the code changes are happening in XLA. The reason why I mention function
The missing APIs are now all implemented, and after rebasing onto the newest XLA the extension works well with an internal patch. I still have a JAX-related question, though: a new device can be added to JAX/XLA through PJRT, but the DLPack path currently has no support for such extended devices. That is why the internal patch is needed. It looks like this:

@@ -254,6 +256,12 @@ StatusOr<PjRtDevice*> DeviceForDLDevice(const PjRtClient* cpu_client,
       }
       TF_RET_CHECK(gpu_client->platform_id() == RocmId());
       return gpu_client->LookupAddressableDevice(context.device_id);
+    case kDLOneAPI:
+      if (gpu_client == nullptr)
+        return InvalidArgument(
+            "DLPack tensor is on extended device, but no backend was provided.");
+      return gpu_client->LookupAddressableDevice(context.device_id);
     default:
       return InvalidArgument("Unknown/unsupported DLPack device type %d",
                              context.device_type);

A new device name is also added:

--- a/xla/pjrt/pjrt_compiler.h
+++ b/xla/pjrt/pjrt_compiler.h
@@ -48,6 +48,10 @@ inline const char* TpuName() {
   static constexpr char kTpuName[] = "tpu";
   return kTpuName;
 }
+inline const char* XpuName() {
+  static constexpr char kXpuName[] = "xpu";
+  return kXpuName;
+}
 inline PjRtPlatformId CpuId() {
   static const PjRtPlatformId kCpuId = tsl::Fingerprint64(CpuName());
   return kCpuId;
@@ -64,6 +68,10 @@ inline PjRtPlatformId TpuId() {
   static const PjRtPlatformId kTpuId = tsl::Fingerprint64(TpuName());
   return kTpuId;
 }
+inline PjRtPlatformId XpuId() {
+  static const PjRtPlatformId kXpuId = tsl::Fingerprint64(XpuName());
+  return kXpuId;
+}

Can I submit this PR directly to XLA, or are there any concerns?
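For readers who want to see the dispatch idea from the patch in isolation, here is a minimal, self-contained sketch that maps a DLPack device type to a platform name. The enum `SketchDLDeviceType` is a local copy of a few `DLDeviceType` values as I understand them from `dlpack.h`, and `PlatformNameForDLDevice` is a hypothetical helper written for this example; neither is part of XLA. The mapping of `kDLOneAPI` to `"xpu"` is an assumption taken from the patch above, not an established XLA convention.

```cpp
#include <cstdint>
#include <optional>
#include <string>

// Local copy of selected DLDeviceType values (assumed to match dlpack.h).
enum SketchDLDeviceType : int32_t {
  kSketchDLCPU = 1,
  kSketchDLCUDA = 2,
  kSketchDLROCM = 10,
  kSketchDLOneAPI = 14,
};

// Hypothetical helper: returns the platform name a DLPack device type would
// be routed to, or std::nullopt for unknown/unsupported device types.
std::optional<std::string> PlatformNameForDLDevice(int32_t device_type) {
  switch (device_type) {
    case kSketchDLCPU:
      return std::string("cpu");
    case kSketchDLCUDA:
      return std::string("cuda");
    case kSketchDLROCM:
      return std::string("rocm");
    case kSketchDLOneAPI:
      // Assumption from the patch: oneAPI devices belong to the "xpu" platform.
      return std::string("xpu");
    default:
      return std::nullopt;
  }
}
```

Returning a name rather than a client pointer keeps the example independent of PJRT; in the actual patch the lookup goes through `gpu_client->LookupAddressableDevice`.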
After integrating our extension for OpenXLA through the PJRT C API, we found that DLPack is not yet supported by PJRT:
xla/xla/pjrt/pjrt_c_api_client.h
Lines 416 to 420 in 59ca120
Does the community have any plans to support it? Some users are interested in this feature.