Skip to content

Commit

Permalink
clone(DistArray) supports device-based arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
evaleev committed Sep 28, 2023
1 parent 089dcc6 commit 4b79b5a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 10 deletions.
32 changes: 26 additions & 6 deletions src/TiledArray/conversions/clone.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
#ifndef TILEDARRAY_CONVERSIONS_CLONE_H__INCLUDED
#define TILEDARRAY_CONVERSIONS_CLONE_H__INCLUDED

#ifdef TILEDARRAY_HAS_DEVICE
#include "TiledArray/device/device_task_fn.h"
#endif

namespace TiledArray {

/// Forward declarations
Expand Down Expand Up @@ -53,12 +57,28 @@ inline DistArray<Tile, Policy> clone(const DistArray<Tile, Policy>& arg) {
if (arg.is_zero(index)) continue;

// Spawn a task to clone the tiles
Future<value_type> tile = world.taskq.add(
[](const value_type& tile) -> value_type {
using TiledArray::clone;
return clone(tile);
},
arg.find(index));

Future<value_type> tile;
if constexpr (!detail::is_device_tile_v<value_type>) {
tile = world.taskq.add(
[](const value_type& tile) -> value_type {
using TiledArray::clone;
return clone(tile);
},
arg.find(index));
} else {
#ifdef TILEDARRAY_HAS_DEVICE
tile = madness::add_device_task(
world,
[](const value_type& tile) -> value_type {
using TiledArray::clone;
return clone(tile);
},
arg.find(index));
#else
abort(); // unreachable
#endif
}

// Store result tile
result.set(index, tile);
Expand Down
1 change: 1 addition & 0 deletions src/TiledArray/external/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,7 @@ inline std::optional<Stream>*& madness_task_stream_opt_ptr_accessor() {

/// Accessor for the calling thread's optional stream slot.
///
/// The slot is a function-local `thread_local` object, so each thread that
/// calls this function gets its own instance. That instance is initialized the
/// first time the thread reaches the declaration, from the value returned by
/// `device::Env::instance()->stream(0)`.
///
/// Returns a mutable reference to the calling thread's slot.
inline std::optional<Stream>& tls_stream_opt_accessor() {
  static thread_local std::optional<Stream> per_thread_stream =
      device::Env::instance()->stream(0);
  return per_thread_stream;
}
Expand Down
5 changes: 1 addition & 4 deletions src/TiledArray/tensor/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,7 @@ template <typename ReduceOp, typename Result, typename... Args>
constexpr const bool is_reduce_op_v =
is_reduce_op_<void, ReduceOp, Result, Args...>::value;

/// detect cuda tile
#ifdef TILEDARRAY_HAS_DEVICE
/// detect device tile types
/// Primary template of the device-tile trait: unless a specialization of
/// `is_device_tile` says otherwise, `T` is reported as not being a device tile
/// (`value == false`).
template <typename T>
struct is_device_tile : std::false_type {};

Expand All @@ -329,8 +328,6 @@ struct is_device_tile<LazyArrayTile<T, Op>>
/// Variable-template shorthand for `is_device_tile<T>::value`.
/// (`constexpr` already makes the variable `const`, so the qualifier is not
/// spelled out a second time.)
template <typename T>
static constexpr auto is_device_tile_v = is_device_tile<T>::value;

#endif

/// Forward declaration of the `default_permutation` class template; its
/// definition (and any specializations) are not part of this excerpt.
/// NOTE(review): the second parameter defaults to `void` -- presumably an
/// enabler hook for SFINAE-constrained partial specializations; confirm
/// against the definition.
template <typename Tensor, typename Enabler = void>
struct default_permutation;

Expand Down

0 comments on commit 4b79b5a

Please sign in to comment.