diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index 4143f88d8d..dfd65c8b74 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -38,6 +38,11 @@ def decorator(func):
     return decorator
 
 
+@implements([torch.ops.aten.detach])
+def noop_detach(func, *args, **kwargs):
+    return args[0][0]
+
+
 @implements(
     [
         aten.detach.default,
@@ -246,6 +251,44 @@ def _to_copy(func, *args, **kwargs):
     return args[0][0].get_original_weight().to(args[1]["dtype"])
 
 
+@implements([torch.ops.aten.to.dtype])
+def to_dtype(func, *args, **kwargs):
+    if not args[0][0].is_contiguous():
+        assert args[0][0].t().is_contiguous()
+        return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
+    return args[0][0].get_original_weight().to(args[0][1])
+
+
+@implements([torch.ops.aten.t.default])
+def t_default(func, *args, **kwargs):
+    a = args[0][0]
+    tensor_meta = SubclassTensorArgs(
+        a.size(),
+        (a.stride(1), a.stride(0)),
+        a.storage_offset(),
+        a.dtype,
+        a.device,
+        a.requires_grad,
+    )
+    b = NF4Tensor(
+        tensor_meta,
+        a.block_size,
+        a.n_blocks,
+        a.scaler_block_size,
+        a.quantized_scalers,
+        a.quantization_factor,
+        a.scaler_mean,
+        a.quantized_data,
+        a.nf4,
+    )
+    return b
+
+
+@implements([torch.ops.aten.mm.default])
+def mm_default(func, *args, **kwargs):
+    return linear_nf4(args[0][0], args[0][1])
+
+
 @implements(
     [
         aten.copy_.default,
@@ -297,47 +340,6 @@ def nf4_copy_(aten_op, args, kwargs=None):
     return original.copy_(same_meta_nf4)
 
 
-@implements([torch.ops.aten.to.dtype])
-def to_dtype(func, *args, **kwargs):
-    if not args[0][0].is_contiguous():
-        assert args[0][0].t().is_contiguous()
-        return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
-    return args[0][0].get_original_weight().to(args[0][1])
-
-
-@implements([torch.ops.aten.t.default])
-def t_default(func, *args, **kwargs):
-    a = args[0][0]
-    tensor_meta = SubclassTensorArgs(
-        a.size(),
-        (a.stride(1), a.stride(0)),
-        a.storage_offset(),
-        a.dtype,
-        a.device,
-        a.requires_grad,
-    )
-    b = NF4Tensor(
-        tensor_meta,
-        a.block_size,
-        a.n_blocks,
-        a.scaler_block_size,
-        a.quantized_scalers,
-        a.quantization_factor,
-        a.scaler_mean,
-        a.quantized_data,
-        a.nf4,
-    )
-    return b
-
-@implements([torch.ops.aten.detach])
-def noop_detach(func, *args, **kwargs):
-    return args[0][0]
-
-@implements([torch.ops.aten.mm.default])
-def mm_default(func, *args, **kwargs):
-    return linear_nf4(args[0][0], args[0][1])
-
-
 @dataclass
 class SubclassTensorArgs:
     original_shape: torch.Size
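
Note (reviewer sketch, not part of the patch): the relocated functions are ordinary __torch_dispatch__ table entries, so they are exercised by plain tensor ops on an NF4Tensor. A minimal usage sketch follows, assuming NF4Tensor.from_tensor(tensor, block_size, scaler_block_size) exists elsewhere in nf4tensor.py; the shapes and dtypes are illustrative, not taken from this diff.

# Usage sketch (not part of the patch); from_tensor and the sizes below are
# assumptions about the surrounding file, chosen only to illustrate dispatch.
import torch

from torchao.dtypes.nf4tensor import NF4Tensor, linear_nf4

weight = torch.randn(512, 512, dtype=torch.bfloat16)
nf4_weight = NF4Tensor.from_tensor(weight, 64, 256)  # block_size, scaler_block_size

detached = nf4_weight.detach()           # routes to noop_detach / the aten.detach handlers
transposed = nf4_weight.t()              # routes to t_default: new NF4Tensor with swapped strides
dequant = nf4_weight.to(torch.bfloat16)  # routes to the to_dtype / _to_copy handlers
x = torch.randn(8, 512, dtype=torch.bfloat16)
out = torch.mm(x, nf4_weight)            # routes to mm_default, which defers to linear_nf4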