Skip to content

Commit

Permalink
Add sycl name in se_gpu_pjrt_client
Browse files Browse the repository at this point in the history
  • Loading branch information
ShengYang1 committed Mar 5, 2024
1 parent d6cfa24 commit c2f60c3
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
2 changes: 2 additions & 0 deletions xla/pjrt/gpu/se_gpu_pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,8 @@ StatusOr<std::unique_ptr<PjRtClient>> GetStreamExecutorGpuClient(
const GpuClientOptions& options) {
#if TENSORFLOW_USE_ROCM
auto pjrt_platform_name = xla::RocmName();
#elif TENSORFLOW_USE_SYCL
auto pjrt_platform_name = xla::SyclName();
#else // TENSORFLOW_USE_ROCM
auto pjrt_platform_name = xla::CudaName();
#endif // TENSORFLOW_USE_ROCM
Expand Down
3 changes: 2 additions & 1 deletion xla/pjrt/gpu/se_gpu_pjrt_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ namespace xla {
namespace {

bool IsGpuClient(const PjRtClient& client) {
return client.platform_id() == CudaId() || client.platform_id() == RocmId();
return client.platform_id() == CudaId() || client.platform_id() == RocmId() ||
client.platform_id() == SyclId();
}

bool IsSameTopology(const PjRtTopologyDescription& topology1,
Expand Down
8 changes: 8 additions & 0 deletions xla/pjrt/pjrt_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ inline const char* RocmName() {
static constexpr char kRocmName[] = "rocm";
return kRocmName;
}
inline const char* SyclName() {
static constexpr char kSyclName[] = "sycl";
return kSyclName;
}
inline const char* TpuName() {
static constexpr char kTpuName[] = "tpu";
return kTpuName;
Expand All @@ -60,6 +64,10 @@ inline PjRtPlatformId RocmId() {
static const PjRtPlatformId kRocmId = tsl::Fingerprint64(RocmName());
return kRocmId;
}
inline PjRtPlatformId SyclId() {
static const PjRtPlatformId kSyclId = tsl::Fingerprint64(SyclName());
return kSyclId;
}
inline PjRtPlatformId TpuId() {
static const PjRtPlatformId kTpuId = tsl::Fingerprint64(TpuName());
return kTpuId;
Expand Down

0 comments on commit c2f60c3

Please sign in to comment.