
Commit f346fd0

hyeontaek authored and Google-ML-Automation committed
[JAX] Use xla::ifrt::Client::MakeArraysFromHostBufferShards() in Array creation when possible
This change uses the new `xla::ifrt::Client::MakeArraysFromHostBufferShards()` API when possible. The API creates a multi-shard IFRT Array (to be wrapped as a JAX `PyArray`) in a single call, which gives the runtime more optimization opportunities than creating single-device IFRT Arrays and then assembling them. Note that the `MakeArraysFromHostBufferShards()` implementation in PjRt-IFRT is not yet optimized, so there is no immediate performance benefit for McJAX.

As an exception, `BatchedDevicePut` takes the previous array-assembly path if any of its shards is not a host buffer but already a single-device array, because `MakeArraysFromHostBufferShards()` requires all sharded inputs to be host buffers.

With batching now possible at the IFRT level, we skip the `DevicePutResultFn` step: `DevicePut` (now split into `DevicePutWithDevice` and `DevicePutWithSharding`) internally calls per-shard functions (with the GIL released) and returns a final IFRT Array.

This change also cleans up `xla::DevicePutResult::owning_pybuffer`, which was originally intended to hold a Python object keeping an IFRT Array valid when the Array is created by `DevicePut()` implementations; that role is now entirely covered by the `on_done_with_host_buffer` callback supplied at IFRT Array creation time.

PiperOrigin-RevId: 749989229
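For orientation, the change at the IFRT API level looks roughly like the sketch below. This is a minimal illustration, not code from this commit: `HostShard`, `host_shards`, `specs`, `shape`, and `sharding` are hypothetical stand-ins, and the parameter list of `MakeArraysFromHostBufferShards()` is paraphrased from its intent rather than copied from ifrt/client.h.

// Old path: one MakeArrayFromHostBuffer() per shard, then an assembly step.
std::vector<tsl::RCReference<xla::ifrt::Array>> shards;
for (const HostShard& s : host_shards) {  // HostShard: hypothetical struct
  TF_ASSIGN_OR_RETURN(
      shards.emplace_back(),
      client->MakeArrayFromHostBuffer(s.data, s.dtype, s.shape,
                                      /*byte_strides=*/std::nullopt,
                                      s.single_device_sharding, semantics,
                                      s.on_done_with_host_buffer));
}
TF_ASSIGN_OR_RETURN(
    auto array,
    client->AssembleArrayFromSingleDeviceArrays(
        shape, sharding, absl::MakeSpan(shards),
        xla::ifrt::ArrayCopySemantics::kReuseInput,
        xla::ifrt::SingleDeviceShardSemantics::kAddressableShards));

// New path: one call that sees every shard at once, so the runtime can batch
// host-to-device transfers or construct the sharded array directly.
TF_ASSIGN_OR_RETURN(auto arrays,
                    client->MakeArraysFromHostBufferShards(specs, semantics));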
1 parent d0b6eb2 commit f346fd0

File tree: 8 files changed (+587, −387 lines)

jaxlib/xla/pjit.cc

+18 −24
@@ -434,13 +434,15 @@ void CallShardArgFallback(
 // Prepares the input PjRtBuffers from the python arguments. This is equivalent
 // to shard_args() in pxla.py but for only a few supported cases.
 absl::StatusOr<std::vector<tsl::RCReference<xla::ifrt::Array>>>
-PrepareIfrtInputs(const xla::PyLoadedExecutable& executable,
-                  absl::Span<nb::object const> flat_dynamic_args,
-                  bool enable_x64, const std::vector<bool>& kept_args,
-                  const std::vector<nb::object>& in_shardings,
-                  const std::vector<nb::object>& in_device_local_layouts,
-                  const nb::callable& shard_arg_fallback,
-                  std::vector<nb::object>& keep_alive_objects) {
+PrepareIfrtInputs(
+    const xla::PyLoadedExecutable& executable,
+    absl::Span<nb::object const> flat_dynamic_args,
+    absl::Span<xla::PyArgSignature const> flat_dynamic_arg_signatures,
+    bool enable_x64, const std::vector<bool>& kept_args,
+    const std::vector<nb::object>& in_shardings,
+    const std::vector<nb::object>& in_device_local_layouts,
+    const nb::callable& shard_arg_fallback,
+    std::vector<nb::object>& keep_alive_objects) {
   const auto& addressable_devices =
       executable.ifrt_loaded_executable()->addressable_devices();
   const auto& num_global_devices =
@@ -484,20 +486,11 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable,
       TF_RETURN_IF_ERROR(
           jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter));
       TF_ASSIGN_OR_RETURN(
-          auto on_device_fn,
-          DevicePut(arg, executable.ifrt_loaded_executable()->client(),
-                    data_device, options, xla::ifrt::MemoryKind()));
-      TF_ASSIGN_OR_RETURN(xla::DevicePutResult on_device, [&]() {
-        // Must release the GIL before calling IFRT because backends may
-        // decide to block/sleep for device buffer allocation.
-        nb::gil_scoped_release gil_release;
-        return std::move(on_device_fn)();
-      }());
-
-      num_args_arrays.push_back(std::move(on_device.ifrt_array));
-      if (on_device.owning_pybuffer) {
-        keep_alive_objects.push_back(std::move(on_device.owning_pybuffer));
-      }
+          auto device_put_result,
+          DevicePutWithDevice(arg,
+                              executable.ifrt_loaded_executable()->client(),
+                              data_device, xla::ifrt::MemoryKind(), options));
+      num_args_arrays.push_back(std::move(device_put_result.ifrt_array));
       continue;
     } else {
       CallShardArgFallback(arg, in_shardings[dce_index],
@@ -750,9 +743,10 @@ absl::StatusOr<nb::object> PjitFunction::Call(nb::handle callable,
   // A vector of [num_inputs].
   auto num_args_arrays = PrepareIfrtInputs(
       *cache_entry->executable, flat_dynamic_args,
-      call_signature.jax_enable_x64, cache_entry->kept_var_bitvec,
-      cache_entry->in_shardings, cache_entry->in_device_local_layouts,
-      shard_arg_fallback_, keep_alive_objects);
+      call_signature.dynamic_arg_signatures, call_signature.jax_enable_x64,
+      cache_entry->kept_var_bitvec, cache_entry->in_shardings,
+      cache_entry->in_device_local_layouts, shard_arg_fallback_,
+      keep_alive_objects);

   if (!num_args_arrays.ok()) {
     VLOG(2) << "Failed to prepare IFRT inputs: " << num_args_arrays.status();

jaxlib/xla/pmap_lib.cc

+16 −56
@@ -56,7 +56,6 @@ limitations under the License.
 #include "jaxlib/xla/pytree.h"
 #include "jaxlib/xla/sharded_device_array.h"
 #include "jaxlib/xla/sharding.h"
-#include "jaxlib/xla/to_ifrt_sharding.h"
 #include "jaxlib/xla/traceback.h"
 #include "xla/pjrt/exceptions.h"
 #include "xla/pjrt/status_casters.h"
@@ -65,7 +64,6 @@ limitations under the License.
 #include "xla/python/ifrt/device_list.h"
 #include "xla/python/ifrt/executable.h"
 #include "xla/python/ifrt/memory.h"
-#include "xla/python/ifrt/shape.h"
 #include "xla/python/ifrt/sharding.h"
 #include "xla/python/nb_helpers.h"
 #include "xla/python/nb_numpy.h"
@@ -186,74 +184,36 @@ absl::StatusOr<ShardArgResult> ShardArg(
                               indices.size(), n_devices);
     }

-    std::vector<tsl::RCReference<xla::ifrt::Array>> per_device_arrays;
-    per_device_arrays.reserve(n_devices);
-    absl::InlinedVector<xla::ifrt::Device*, 1> devices;
-    devices.reserve(n_devices);
-    // TODO(hyeontaek): The created array will never be disassembled. We should
-    // omit collecting shapes and make the OpaqueSharding non-disassemblable?
-    std::vector<xla::ifrt::Shape> shapes;
-    shapes.reserve(n_devices);
-
-    nb::list owning_pylist;
     ShardArgResult result;
-    result.owning_sda = owning_pylist;
     const bool jax_enable_x64 = GetEnableX64();

-    std::vector<xla::DevicePutResultFn> device_put_fns;
-    device_put_fns.reserve(n_devices);
+    std::vector<nb::object> owning_args;
+    std::vector<nb::handle> args;
+    owning_args.reserve(n_devices);
+    args.reserve(n_devices);
     xla::DevicePutOptions options;
     options.squash_64bit_types = !jax_enable_x64;
     options.allow_zero_copy = true;
+    xla::ifrt::Client* ifrt_client = nullptr;
     for (size_t i = 0; i < n_devices; ++i) {
       auto to_device = nb::cast<xla::PyDevice*>(py_devices_list[i]);
       if (to_device->client().get() == nullptr) {
         return xla::InvalidArgument("Cannot copy to unattached devices.");
       }
-
-      TF_ASSIGN_OR_RETURN(
-          device_put_fns.emplace_back(),
-          DevicePut(arg[indices[i]], to_device->client()->ifrt_client(),
-                    to_device->device(), options, xla::ifrt::MemoryKind()));
-    }
-    std::vector<xla::DevicePutResult> device_puts;
-    device_puts.reserve(n_devices);
-    {
-      nb::gil_scoped_release gil_release;
-      for (auto& device_put_fn : device_put_fns) {
-        TF_ASSIGN_OR_RETURN(auto device_put, std::move(device_put_fn)());
-        device_puts.push_back(std::move(device_put));
-      }
-    }
-    for (auto& device_put : device_puts) {
-      per_device_arrays.push_back(std::move(device_put.ifrt_array));
-      devices.push_back(
-          per_device_arrays.back()->sharding().devices()->devices().front());
-      shapes.push_back(per_device_arrays.back()->shape());
-      if (device_put.owning_pybuffer) {
-        owning_pylist.append(device_put.owning_pybuffer);
+      if (i == 0) {
+        ifrt_client = to_device->client()->ifrt_client();
       }
+      owning_args.push_back(arg[indices[i]]);
+      args.push_back(owning_args.back());
     }
-
-    if (per_device_arrays.empty()) {
-      return xla::InvalidArgument("Per-device arrays must not be empty.");
-    }
-    // TODO(hyeontaek): The logical shape here is inaccurate. We
-    // may want to avoid creating a new Array or specialize Array
-    // to disallow access to the logical shape.
-    xla::ifrt::Shape shape = per_device_arrays.front()->shape();
-    TF_ASSIGN_OR_RETURN(
-        auto ifrt_sharding,
-        xla::GetIfrtConcreteSharding(input_spec.array_sharding, shape, shapes));
+    CHECK(ifrt_client != nullptr);
     TF_ASSIGN_OR_RETURN(
-        result.ifrt_array,
-        per_device_arrays.front()
-            ->client()
-            ->AssembleArrayFromSingleDeviceArrays(
-                std::move(shape), std::move(ifrt_sharding),
-                absl::MakeSpan(per_device_arrays),
-                xla::ifrt::ArrayCopySemantics::kReuseInput,
-                xla::ifrt::SingleDeviceShardSemantics::kAddressableShards));
+        xla::DevicePutResult device_put_result,
+        xla::DevicePutWithSharding(
+            args, ifrt_client, ndarray.dtype(),
+            nb::cast<std::vector<int64_t>>(ndarray.attr("shape")),
+            input_spec.array_sharding, options));
+    result.ifrt_array = std::move(device_put_result.ifrt_array);
     return result;
   }
   tsl::profiler::TraceMe traceme("pmap_lib_shard_arg_python_fallback");
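One lifetime detail is easy to miss in the new `ShardArg`: `nb::handle` is a non-owning view, so the `nb::object`s produced by `arg[indices[i]]` must stay alive until `DevicePutWithSharding()` returns, which is the sole purpose of `owning_args`. A minimal sketch of the idiom; `MakeShard()` and `Consume()` are hypothetical stand-ins:

size_t n = 8;                    // example shard count
std::vector<nb::object> owning;  // strong references keep the shards alive
std::vector<nb::handle> views;   // borrowed views, cheap to copy and pass
for (size_t i = 0; i < n; ++i) {
  // Pushing MakeShard(i) straight into `views` would leave a dangling
  // handle once the temporary nb::object is destroyed.
  owning.push_back(MakeShard(i));
  views.push_back(owning.back());  // handle borrows the stored object
}
Consume(views);  // safe: every handle is backed by an entry in `owning`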

jaxlib/xla/py_array.cc

+13 −61
@@ -1257,89 +1257,41 @@ absl::StatusOr<PyArray> PyArray::BatchedDevicePut(
   options.allow_zero_copy =
       (!force_copy && (host_buffer_semantics ==
                        ifrt::Client::HostBufferSemantics::kImmutableZeroCopy));
-  if (!dst_devices.empty()) {
-    options.ifrt_user_context =
-        dst_devices.front()->client()->ifrt_client()->CreateUserContext();
-  }

-  nb::list owning_pylist;
   std::vector<tsl::RCReference<ifrt::Array>> ifrt_arrays;

   absl::InlinedVector<ifrt::Device*, 1> devices;
   devices.reserve(n_devices);
   std::vector<xla::ifrt::Shape> shapes;
   shapes.reserve(n_devices);

-  ifrt::MemoryKind dst_memory_kind = xla::GetMemoryKind(sharding);
-
-  std::vector<DevicePutResultFn> device_put_fns;
-  device_put_fns.reserve(xs.size());
-  size_t i = 0;
-  for (auto& x : xs) {
+  std::vector<nb::handle> args;
+  args.reserve(xs.size());
+  for (const nb::object& x : xs) {
     if (PyArray::IsPyArray(x)) {
       TF_RETURN_IF_ERROR(
           jax::ApplyTransferGuardToDeviceToDevice(transfer_guard_formatter));
     } else {
       TF_RETURN_IF_ERROR(
           jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter));
     }
-    TF_ASSIGN_OR_RETURN(
-        device_put_fns.emplace_back(),
-        DevicePut(x, dst_devices[i]->client()->ifrt_client(),
-                  dst_devices[i]->device(), options, dst_memory_kind));
-    ++i;
-  }
-  std::vector<DevicePutResult> device_puts;
-  device_puts.reserve(device_put_fns.size());
-  {
-    nb::gil_scoped_release gil_release;
-    for (auto& device_put_fn : device_put_fns) {
-      TF_ASSIGN_OR_RETURN(auto device_put, std::move(device_put_fn)());
-      device_puts.push_back(std::move(device_put));
-    }
-  }
-  for (auto& device_put : device_puts) {
-    ifrt_arrays.push_back(std::move(device_put.ifrt_array));
-    devices.push_back(
-        ifrt_arrays.back()->sharding().devices()->devices().front());
-    shapes.push_back(ifrt_arrays.back()->shape());
-    if (device_put.owning_pybuffer) {
-      owning_pylist.append(device_put.owning_pybuffer);
-    }
+    args.push_back(x);
   }
-
-  // TODO(phawkins): it's highly suspicious to me that owning_pylist isn't
-  // consumed here. Look into this.
-
   auto weak_type = nb::cast<bool>(aval.attr("weak_type"));
   auto dtype = aval.attr("dtype");
   auto shape = nb::cast<std::vector<int64_t>>(aval.attr("shape"));
+  TF_ASSIGN_OR_RETURN(nb_class_ptr<jax::PyDeviceList> py_device_list,
+                      jax::GetPyDeviceList(sharding));

   TF_ASSIGN_OR_RETURN(
-      auto ifrt_sharding,
-      sharding.type().is(jax::PmapSharding::type())
-          ? xla::GetIfrtConcreteSharding(sharding, ifrt::Shape(shape),
-                                         std::move(shapes))
-          : xla::GetIfrtHloSharding(sharding, ifrt::Shape(shape)));
-  TF_ASSIGN_OR_RETURN(auto ifrt_dtype, DtypeToIfRtDType(dtype));
-  // TODO(emilyaf): Remove the following and just use ifrt_dtype when tokens are
-  // supported.
-  ifrt::DType array_dtype =
-      ifrt_arrays.empty() ? ifrt_dtype : ifrt_arrays.front()->dtype();
-  TF_ASSIGN_OR_RETURN(auto py_device_list, jax::GetPyDeviceList(sharding));
-  TF_ASSIGN_OR_RETURN(
-      auto ifrt_array,
-      py_device_list->py_client()
-          ->ifrt_client()
-          ->AssembleArrayFromSingleDeviceArrays(
-              array_dtype, ifrt::Shape(shape), std::move(ifrt_sharding),
-              absl::MakeSpan(ifrt_arrays),
-              xla::ifrt::ArrayCopySemantics::kReuseInput,
-              xla::ifrt::SingleDeviceShardSemantics::kAddressableShards));
-
-  return PyArray(aval, weak_type, dtype, std::move(shape), sharding,
+      DevicePutResult device_put_result,
+      DevicePutWithSharding(args, py_device_list->py_client()->ifrt_client(),
+                            dtype, shape, sharding, options));
+
+  return PyArray(aval, weak_type, dtype, std::move(shape), std::move(sharding),
                  py_device_list->py_client(), Traceback::Get(),
-                 std::move(ifrt_array), committed, /*skip_checks=*/true);
+                 std::move(device_put_result.ifrt_array), committed,
+                 /*skip_checks=*/true);
 }

 absl::StatusOr<PyArray> PyArray::ReorderShards(
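Note that `BatchedDevicePut` now hands every input to `DevicePutWithSharding()` unconditionally; the exception described in the commit message, falling back to per-shard puts plus assembly when some shard is already a single-device array rather than a host buffer, happens inside that helper, which this diff does not show. A paraphrased sketch of the rule only, with hypothetical helper names:

// Inside DevicePutWithSharding (sketch; PutViaHostBufferShards and
// PutViaAssembly are hypothetical names for the two strategies).
bool all_host_buffers = true;
for (const nb::handle& x : args) {
  if (xla::PyArray::IsPyArray(x)) {  // shard is already device-resident
    all_host_buffers = false;
    break;
  }
}
if (all_host_buffers) {
  // One ifrt::Client::MakeArraysFromHostBufferShards() call covers all shards.
  return PutViaHostBufferShards(args, ifrt_client, dtype, shape, sharding,
                                options);
}
// Previous path: per-shard device puts, then assembly.
return PutViaAssembly(args, ifrt_client, dtype, shape, sharding, options);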

jaxlib/xla/py_client.cc

+15 −19
@@ -57,6 +57,7 @@ limitations under the License.
 #include "jaxlib/xla/py_memory_space.h"
 #include "jaxlib/xla/py_values.h"
 #include "jaxlib/xla/python_ref_manager.h"
+#include "jaxlib/xla/sharding.h"
 #include "jaxlib/xla/traceback.h"
 #include "xla/literal.h"
 #include "xla/pjrt/exceptions.h"
@@ -66,6 +67,7 @@ limitations under the License.
 #include "xla/pjrt/pjrt_executable.h"
 #include "xla/pjrt/pjrt_layout.h"
 #include "xla/pjrt/status_casters.h"
+#include "xla/python/ifrt/array.h"
 #include "xla/python/ifrt/client.h"
 #include "xla/python/ifrt/compiler.h"
 #include "xla/python/ifrt/device.h"
@@ -339,25 +341,19 @@ absl::Status PyClient::Defragment() {
   options.allow_zero_copy =
       (!force_copy && (host_buffer_semantics ==
                        ifrt::Client::HostBufferSemantics::kImmutableZeroCopy));
-  TF_ASSIGN_OR_RETURN(auto put_fn,
-                      DevicePut(argument, client->ifrt_client_.get(), device,
-                                options, ifrt::MemoryKind()));
-  TF_ASSIGN_OR_RETURN(auto put, [&]() {
-    // Must release the GIL before calling IFRT because backends may
-    // decide to block/sleep for device buffer allocation.
-    nb::gil_scoped_release gil_release;
-    return std::move(put_fn)();
-  }());
-
-  if (put.ifrt_array) {
-    auto traceback = Traceback::Get();
-    return PyArray::MakeFromSingleDeviceArray(
-        std::move(client), std::move(traceback), std::move(put.ifrt_array),
-        /*weak_type=*/false,
-        /*committed=*/false);
-  } else {
-    return put.owning_pybuffer;
-  }
+  TF_ASSIGN_OR_RETURN(DevicePutResult device_put_result,
+                      DevicePutWithDevice(argument, client->ifrt_client_.get(),
+                                          device, ifrt::MemoryKind(), options));
+  auto sharding = make_nb_class<jax::SingleDeviceSharding>(
+      client, client->ifrt_client()->MakeDeviceList({device}),
+      /*memory_kind=*/nb::none());

+  auto traceback = Traceback::Get();
+  return PyArray::MakeFromIfrtArrayAndSharding(
+      std::move(client), std::move(traceback),
+      std::move(device_put_result.ifrt_array), std::move(sharding),
+      /*weak_type=*/false, /*committed=*/false,
+      /*skip_checks=*/true);
 }

 namespace {
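With `owning_pybuffer` gone, this function no longer has a branch that returns the original Python buffer: it always produces an IFRT array and wraps it in a `PyArray`, building the `SingleDeviceSharding` up front and passing `/*skip_checks=*/true` (presumably safe because the array was just created against that exact sharding). A condensed before/after of the control flow, paraphrasing the hunk above:

// Before: the result type depended on what DevicePut() produced.
if (put.ifrt_array) {
  return PyArray::MakeFromSingleDeviceArray(
      std::move(client), Traceback::Get(), std::move(put.ifrt_array),
      /*weak_type=*/false, /*committed=*/false);
} else {
  return put.owning_pybuffer;  // hand back the owning Python buffer as-is
}

// After: exactly one result shape, with the sharding known at wrap time.
return PyArray::MakeFromIfrtArrayAndSharding(
    std::move(client), Traceback::Get(),
    std::move(device_put_result.ifrt_array), std::move(sharding),
    /*weak_type=*/false, /*committed=*/false, /*skip_checks=*/true);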
