diff --git a/treescope/external/torch_support.py b/treescope/external/torch_support.py index 4d7c09d..0b00e70 100644 --- a/treescope/external/torch_support.py +++ b/treescope/external/torch_support.py @@ -79,7 +79,7 @@ def _truncate_and_copy( assert ( len(prefix_slices) == len(array_source.shape) == len(array_dest.shape) ) - array_dest[prefix_slices] = array_source[prefix_slices].numpy() + array_dest[prefix_slices] = array_source[prefix_slices].numpy(force=True) else: # Recursive step. axis = len(prefix_slices) @@ -145,7 +145,7 @@ def get_array_data_with_truncation( if edge_items_per_axis == (None,) * array.ndim: # No truncation. - return array.numpy(), mask.numpy() + return array.numpy(force=True), mask.numpy(force=True) dest_shape = [ size if edge_items is None else 2 * edge_items + 1