Avoid placing iterator ops on GPUs.
We've seen failures in multi-GPU settings where iterator ops placed on GPUs
attempt to copy dataset/iterator variants to the GPU, which is not supported
and raises errors. TFF's explicit device placement logic is likely interfering
with TF's own placement logic at runtime. However, it was not sufficient to
simply skip setting the `device` attribute on these ops; they had to be
explicitly pinned to the CPU.

This was found using FilePerUserClientData-based datasets, for example the
FLAIR implementation.

PiperOrigin-RevId: 522655638
ZacharyGarrett authored and tensorflow-copybara committed Apr 7, 2023
1 parent d63d36b commit 08c5fd0
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions tensorflow_federated/cc/core/impl/executors/session_provider.cc
@@ -20,6 +20,7 @@ limitations under the License
 #include <vector>
 
 #include "absl/container/flat_hash_set.h"
+#include "absl/strings/match.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
 #include "absl/synchronization/mutex.h"
@@ -137,13 +138,22 @@ void SetDevice(std::string_view device, tensorflow::GraphDef* graph_def,
               << ") on device [" << device << "]"
               << "and marking for compilation on device type [" << device_type
               << "]";
-    }
-    if (!node_pb.device().empty()) {
+    } else if (absl::StartsWith(node_pb.op(), "IteratorGetNext") ||
+               node_pb.op() == "MakeIterator" ||
+               absl::StartsWith(node_pb.op(), "AnonymousIteratorV")) {
+      // TODO(b/276782974): We must avoid forcing the Iterator ops on the GPU,
+      // which will happen below because GPU kernels exist. TF will determine
+      // that the iterator is on the host and correctly place the node for us,
+      // but this will cause issues if we eagerly put the GetNext on the
+      // accelerator device here.
+      VLOG(5) << "Forcing iterator op to CPU [" << node_pb.name() << "]";
+      node_pb.set_device(
+          absl::StrCat("/device:", tensorflow::DEVICE_CPU, ":0"));
+    } else if (!node_pb.device().empty()) {
       VLOG(5) << "Skipping already placed node [" << node_pb.name() << "] ("
               << node_pb.op() << ") on " << node_pb.device();
-      continue;
-      // Note: Don't place general ops directly on TPU.
     } else if (tensorflow::KernelDefAvailable(device_type, node_pb) &&
+               // Note: Don't place general ops directly on TPU.
                strcmp(device_type, tensorflow::DEVICE_TPU) != 0) {
       VLOG(5) << "Placing node [" << node_pb.name() << "] (" << node_pb.op()
               << ") on device [" << device << "]";
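For readers outside the TFF codebase, here is a minimal standalone sketch of the pinning technique used in the diff above. The helper name PinIteratorOpsToCpu is hypothetical; in TFF the logic lives inside the SetDevice() loop in session_provider.cc, but the node matching and the explicit CPU device string are the same.

// A minimal sketch of explicitly pinning dataset-iterator ops to the CPU
// before graph placement. PinIteratorOpsToCpu is a hypothetical helper; the
// real change is part of SetDevice() in session_provider.cc.
#include <string>

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/types.h"

void PinIteratorOpsToCpu(tensorflow::GraphDef* graph_def) {
  const std::string cpu_device =
      absl::StrCat("/device:", tensorflow::DEVICE_CPU, ":0");
  for (tensorflow::NodeDef& node_pb : *graph_def->mutable_node()) {
    if (absl::StartsWith(node_pb.op(), "IteratorGetNext") ||
        node_pb.op() == "MakeIterator" ||
        absl::StartsWith(node_pb.op(), "AnonymousIteratorV")) {
      // Overwrite any existing assignment: leaving the device attribute empty
      // is not enough, because GPU kernels exist for these ops and placement
      // would otherwise put them on the accelerator.
      node_pb.set_device(cpu_device);
    }
  }
}

Running this over the GraphDef before creating the session mirrors the effect of the diff: the variant-typed iterator state stays on the host instead of being copied to the GPU.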
