From 0ed30902e3879b98c7a985915ee37ccb6048a2cb Mon Sep 17 00:00:00 2001 From: ebsmothers Date: Thu, 22 Aug 2024 22:51:25 -0700 Subject: [PATCH] Move non-NF4 tensor to device prior to quantization on copy (#737) --- torchao/dtypes/nf4tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 90516ea199..b386f85ae0 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -339,7 +339,7 @@ def copy_(func, *args, **kwargs): # Convert Non NF4Tensor into NF4 for copy in if not isinstance(copy_in, NF4Tensor): copy_in_nf4 = NF4Tensor.from_tensor( - copy_in, original.block_size, original.scaler_block_size + copy_in.to(original.device), original.block_size, original.scaler_block_size ) return original.copy_(copy_in_nf4)