Commit
better diff layout
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
weifengpy committed Apr 19, 2024
1 parent 761416a commit c656f1e
Showing 1 changed file with 43 additions and 41 deletions.

torchao/dtypes/nf4tensor.py
@@ -38,6 +38,11 @@ def decorator(func):
     return decorator
 
 
+@implements([torch.ops.aten.detach])
+def noop_detach(func, *args, **kwargs):
+    return args[0][0]
+
+
 @implements(
     [
         aten.detach.default,
@@ -246,6 +251,44 @@ def _to_copy(func, *args, **kwargs):
     return args[0][0].get_original_weight().to(args[1]["dtype"])
 
 
+@implements([torch.ops.aten.to.dtype])
+def to_dtype(func, *args, **kwargs):
+    if not args[0][0].is_contiguous():
+        assert args[0][0].t().is_contiguous()
+        return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
+    return args[0][0].get_original_weight().to(args[0][1])
+
+
+@implements([torch.ops.aten.t.default])
+def t_default(func, *args, **kwargs):
+    a = args[0][0]
+    tensor_meta = SubclassTensorArgs(
+        a.size(),
+        (a.stride(1), a.stride(0)),
+        a.storage_offset(),
+        a.dtype,
+        a.device,
+        a.requires_grad,
+    )
+    b = NF4Tensor(
+        tensor_meta,
+        a.block_size,
+        a.n_blocks,
+        a.scaler_block_size,
+        a.quantized_scalers,
+        a.quantization_factor,
+        a.scaler_mean,
+        a.quantized_data,
+        a.nf4,
+    )
+    return b
+
+
+@implements([torch.ops.aten.mm.default])
+def mm_default(func, *args, **kwargs):
+    return linear_nf4(args[0][0], args[0][1])
+
+
 @implements(
     [
         aten.copy_.default,
@@ -297,47 +340,6 @@ def nf4_copy_(aten_op, args, kwargs=None):
     return original.copy_(same_meta_nf4)
 
 
-@implements([torch.ops.aten.to.dtype])
-def to_dtype(func, *args, **kwargs):
-    if not args[0][0].is_contiguous():
-        assert args[0][0].t().is_contiguous()
-        return torch.ops.aten.to.dtype(args[0][0].t(), args[0][1]).t()
-    return args[0][0].get_original_weight().to(args[0][1])
-
-
-@implements([torch.ops.aten.t.default])
-def t_default(func, *args, **kwargs):
-    a = args[0][0]
-    tensor_meta = SubclassTensorArgs(
-        a.size(),
-        (a.stride(1), a.stride(0)),
-        a.storage_offset(),
-        a.dtype,
-        a.device,
-        a.requires_grad,
-    )
-    b = NF4Tensor(
-        tensor_meta,
-        a.block_size,
-        a.n_blocks,
-        a.scaler_block_size,
-        a.quantized_scalers,
-        a.quantization_factor,
-        a.scaler_mean,
-        a.quantized_data,
-        a.nf4,
-    )
-    return b
-
-@implements([torch.ops.aten.detach])
-def noop_detach(func, *args, **kwargs):
-    return args[0][0]
-
-@implements([torch.ops.aten.mm.default])
-def mm_default(func, *args, **kwargs):
-    return linear_nf4(args[0][0], args[0][1])
-
-
 @dataclass
 class SubclassTensorArgs:
     original_shape: torch.Size
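
For readers of this diff: the handlers moved above are registered through the implements decorator defined earlier in nf4tensor.py (visible as the "def decorator(func):" context in the first hunk). Below is a minimal sketch of how such a registry is typically wired into __torch_dispatch__, assuming a lookup table named NF4_OPS_TABLE; the table name and dispatch body follow this file's conventions but are illustrative, not the exact implementation.

import torch

# Hypothetical registry, assumed from this file's conventions.
NF4_OPS_TABLE = {}

def implements(aten_ops):
    """Register func as the handler for each aten op in aten_ops."""
    def decorator(func):
        for op in aten_ops:
            NF4_OPS_TABLE[op] = func
        return func
    return decorator

class NF4Tensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        # The caller's (args, kwargs) pair is forwarded positionally, so
        # inside a handler args[0] is the original argument tuple and
        # args[0][0] is the NF4Tensor itself -- matching the indexing used
        # by noop_detach, to_dtype, t_default, and mm_default in the diff.
        if func in NF4_OPS_TABLE:
            return NF4_OPS_TABLE[func](func, args, kwargs if kwargs else {})
        raise NotImplementedError(f"NF4Tensor does not implement {func}")

With wiring like this, an op such as torch.ops.aten.mm.default on an NF4Tensor routes to mm_default, which in turn calls linear_nf4 on the dequantized weight.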
