diff --git a/src/TiledArray/conversions/clone.h b/src/TiledArray/conversions/clone.h index b8c05df840..910d86e21d 100644 --- a/src/TiledArray/conversions/clone.h +++ b/src/TiledArray/conversions/clone.h @@ -26,6 +26,10 @@ #ifndef TILEDARRAY_CONVERSIONS_CLONE_H__INCLUDED #define TILEDARRAY_CONVERSIONS_CLONE_H__INCLUDED +#ifdef TILEDARRAY_HAS_DEVICE +#include "TiledArray/device/device_task_fn.h" +#endif + namespace TiledArray { /// Forward declarations @@ -53,12 +57,28 @@ inline DistArray clone(const DistArray& arg) { if (arg.is_zero(index)) continue; // Spawn a task to clone the tiles - Future tile = world.taskq.add( - [](const value_type& tile) -> value_type { - using TiledArray::clone; - return clone(tile); - }, - arg.find(index)); + + Future tile; + if constexpr (!detail::is_device_tile_v) { + tile = world.taskq.add( + [](const value_type& tile) -> value_type { + using TiledArray::clone; + return clone(tile); + }, + arg.find(index)); + } else { +#ifdef TILEDARRAY_HAS_DEVICE + tile = madness::add_device_task( + world, + [](const value_type& tile) -> value_type { + using TiledArray::clone; + return clone(tile); + }, + arg.find(index)); +#else + abort(); // unreachable +#endif + } // Store result tile result.set(index, tile); diff --git a/src/TiledArray/external/device.h b/src/TiledArray/external/device.h index dcf286c443..44d9c77a68 100644 --- a/src/TiledArray/external/device.h +++ b/src/TiledArray/external/device.h @@ -807,6 +807,7 @@ inline std::optional*& madness_task_stream_opt_ptr_accessor() { inline std::optional& tls_stream_opt_accessor() { static thread_local std::optional stream_opt = + device::Env::instance()->stream(0); return stream_opt; } diff --git a/src/TiledArray/tensor/type_traits.h b/src/TiledArray/tensor/type_traits.h index e9d1681f71..eed84c6026 100644 --- a/src/TiledArray/tensor/type_traits.h +++ b/src/TiledArray/tensor/type_traits.h @@ -314,8 +314,7 @@ template constexpr const bool is_reduce_op_v = is_reduce_op_::value; -/// detect cuda tile -#ifdef TILEDARRAY_HAS_DEVICE +/// detect device tile types template struct is_device_tile : public std::false_type {}; @@ -329,8 +328,6 @@ struct is_device_tile> template static constexpr const auto is_device_tile_v = is_device_tile::value; -#endif - template struct default_permutation;