torch.float8_e4m3fn does not support torch.cat #107256

Closed
awgu opened this issue Aug 15, 2023 · 9 comments
Labels
module: float8 (For torch.float8_e5m2 and torch.float8_e4m3)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@awgu
Contributor

awgu commented Aug 15, 2023

As of right now, this is not blocking any work. However, I would imagine that concatenating float8_e4m3fn tensors should be well-defined.

Repro
import torch

def test_cat_fp8_cpu():
    t1 = torch.empty((3,), dtype=torch.float8_e4m3fn)
    t2 = torch.empty((2,), dtype=torch.float8_e4m3fn)
    t = torch.cat([t1, t2], dim=0)
    print(t)

def test_cat_fp8_cuda():
    t1 = torch.ones((3,), dtype=torch.float8_e4m3fn, device="cuda")
    t2 = torch.ones((2,), dtype=torch.float8_e4m3fn, device="cuda")
    t = torch.cat([t1, t2], dim=0)
    print(t)

try:
    test_cat_fp8_cpu()
except Exception as e:
    print(e)
try:
    test_cat_fp8_cuda()
except Exception as e:
    print(e)
Unsupported TypeMeta in ATen: nullptr (uninitialized) (please report this error)
Unsupported TypeMeta in ATen: nullptr (uninitialized) (please report this error)

Addendum: torch.float8_e4m3fn does not support torch.ones on CPU:

"fill_cpu" not implemented for 'Float8_e4m3fn'

cc @yanbing-j

@yanbing-j
Collaborator

Hi @awgu,

FP8 support is at a very early stage right now; only the data type and tensor factories have been added. Most operators, including the cat kernel, don't support FP8 yet, so this failure is expected. cc @vkuzo @albanD
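
To illustrate what that means in practice, here is a minimal sketch (my own, not an official recommendation): factories and dtype casts work, while compute kernels such as cat do not yet.

import torch

# Works today: the dtype itself, tensor factories (on CUDA), and casts.
x = torch.zeros((4,), dtype=torch.float8_e4m3fn, device="cuda")
x_fp32 = x.to(torch.float32)   # upcast back to a fully supported dtype

# Not yet wired up: most compute kernels, e.g. torch.cat.
# torch.cat([x, x])            # raises "Unsupported TypeMeta in ATen: nullptr"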

@vkuzo
Contributor

vkuzo commented Aug 16, 2023

we should support cat, @awgu let us know when you need it

albanD added the triaged and module: float8 labels on Aug 16, 2023
@HeinrichAD

I know that FP8 support is still at a very early stage, but to keep track of what is possible at the moment, I created a small test script with some basic functional tests.

So that everyone wondering about the same questions doesn't have to run the tests themselves, I'm sharing my results. Note that the results were generated with a preview (nightly) build (PyTorch version: 2.1.0.dev20230904+cu121); the test script can be found below the table.

Command           Device  Support  Output
torch.empty       cpu     ✔        tensor([ 5.2000e+01, -4.0000e+00, -3.1250e-02], dtype=torch.float8_e4m3fn)
                  cuda    ✔        tensor([0., 0., 0.], device='cuda:0', dtype=torch.float8_e4m3fn)
torch.zeros       cpu     ✔        tensor([0., 0., 0.], dtype=torch.float8_e4m3fn)
                  cuda    ✔        tensor([0., 0., 0.], device='cuda:0', dtype=torch.float8_e4m3fn)
torch.ones        cpu     ✘        "fill_cpu" not implemented for 'Float8_e4m3fn'
                  cuda    ✔        tensor([1., 1., 1.], device='cuda:0', dtype=torch.float8_e4m3fn)
torch.zeros_like  cpu     ✔        tensor([0., 0., 0.], dtype=torch.float8_e4m3fn)
                  cuda    ✔        tensor([0., 0., 0.], device='cuda:0', dtype=torch.float8_e4m3fn)
torch.ones_like   cpu     ✘        "fill_cpu" not implemented for 'Float8_e4m3fn'
                  cuda    ✔        tensor([1., 1., 1.], device='cuda:0', dtype=torch.float8_e4m3fn)
t1 + t2           cpu     ✘        "add_stub" not implemented for 'Float8_e4m3fn'
                  cuda    ✘        "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn'
t1 - t2           cpu     ✘        "add_stub" not implemented for 'Float8_e4m3fn'
                  cuda    ✘        "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn'
t1 * t2           cpu     ✔        tensor([ 0.0000, -2.0000, -0.2188], dtype=torch.float8_e4m3fn)
                  cuda    ✘        "mul_cuda" not implemented for 'Float8_e4m3fn'
t1 / t2           cpu     ✘        "div_cpu" not implemented for 'Float8_e4m3fn'
                  cuda    ✘        "div_true_cuda" not implemented for 'Float8_e4m3fn'
t1 ** t2          cpu     ✘        "pow" not implemented for 'Float8_e4m3fn'
                  cuda    ✘        "pow_cuda" not implemented for 'Float8_e4m3fn'
t1 // t2          cpu     ✘        "div_floor_cpu_reduced_float" not implemented for 'Float8_e4m3fn'
                  cuda    ✘        "div_floor_cuda" not implemented for 'Float8_e4m3fn'
t1 % t2           cpu     ✘        "remainder_cpu" not implemented for 'Float8_e4m3fn'
                  cuda    ✘        "remainder_cuda" not implemented for 'Float8_e4m3fn'
torch.cat         cpu     ✘        Unsupported TypeMeta in ATen: nullptr (uninitialized) (please report this error)
                  cuda    ✘        Unsupported TypeMeta in ATen: nullptr (uninitialized) (please report this error)
torch.stack       cpu     ✘        Unsupported TypeMeta in ATen: nullptr (uninitialized) (please report this error)
                  cuda    ✘        Unsupported TypeMeta in ATen: nullptr (uninitialized) (please report this error)
Test script
import torch

cuda = "cuda:0"
dtype = torch.float8_e4m3fn
print("> torch.__version__:", torch.__version__)

print("> torch.empty")
try: print("✔ cpu: ", torch.empty((3,), dtype=dtype))
except Exception as e: print("✘ cpu: ", e)
try: print("✔ cuda:", torch.empty((3,), dtype=dtype, device=cuda))
except Exception as e: print("✘ cuda:", e)

print("> torch.zeros")
try: print("✔ cpu: ", torch.zeros((3,), dtype=dtype))
except Exception as e: print("✘ cpu: ", e)
try: print("✔ cuda:", torch.zeros((3,), dtype=dtype, device=cuda))
except Exception as e: print("✘ cuda:", e)

print("> torch.ones")
try: print("✔ cpu: ", torch.ones((3,), dtype=dtype))
except Exception as e: print("✘ cpu: ", e)
try: print("✔ cuda:", torch.ones((3,), dtype=dtype, device=cuda))
except Exception as e: print("✘ cuda:", e)

print("> torch.zeros_like")
t = torch.empty((3,), dtype=dtype)
try: print("✔ cpu: ", torch.zeros_like(t))
except Exception as e: print("✘ cpu: ", e)
t = t.to(cuda)
try: print("✔ cuda:", torch.zeros_like(t))
except Exception as e: print("✘ cuda:", e)

print("> torch.ones_like")
t = torch.empty((3,), dtype=dtype)
try: print("✔ cpu: ", torch.ones_like(t))
except Exception as e: print("✘ cpu: ", e)
t = t.to(cuda)
try: print("✔ cuda:", torch.ones_like(t))
except Exception as e: print("✘ cuda:", e)

print("> t1 + t2")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.empty((3,), dtype=dtype)
try: print("✔ cpu: ", t1 + t2)
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", t1 + t2)
except Exception as e: print("✘ cuda:", e)

print("> t1 - t2")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.empty((3,), dtype=dtype)
try: print("✔ cpu: ", t1 - t2)
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", t1 - t2)
except Exception as e: print("✘ cuda:", e)

print("> t1 * t2")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.empty((3,), dtype=dtype)
try: print("✔ cpu: ", t1 * t2)
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", t1 * t2)
except Exception as e: print("✘ cuda:", e)

print("> t1 / t2")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.empty((3,), dtype=dtype)
try: print("✔ cpu: ", t1 / t2)
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", t1 / t2)
except Exception as e: print("✘ cuda:", e)

print("> t1 ** t2")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.tensor(2, dtype=dtype)
try: print("✔ cpu: ", t1 ** t2)
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", t1 ** t2)
except Exception as e: print("✘ cuda:", e)

print("> t1 // t2")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.tensor(2, dtype=dtype)
try: print("✔ cpu: ", t1 // t2)
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", t1 // t2)
except Exception as e: print("✘ cuda:", e)

print("> t1 % t2")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.tensor(2, dtype=dtype)
try: print("✔ cpu: ", t1 % t2)
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", t1 % t2)
except Exception as e: print("✘ cuda:", e)

print("> torch.cat([t1, t2], dim=0)")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.empty((2,), dtype=dtype)
try: print("✔ cpu: ", torch.cat([t1, t2], dim=0))
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", torch.cat([t1, t2], dim=0))
except Exception as e: print("✘ cuda:", e)

print("> torch.stack([t1, t2], dim=0)")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.empty((3,), dtype=dtype)
try: print("✔ cpu: ", torch.stack([t1, t2], dim=0))
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", torch.stack([t1, t2], dim=0))
except Exception as e: print("✘ cuda:", e)

pytorchmergebot pushed a commit that referenced this issue Dec 15, 2023
According to #107256 (comment), the ops tested in `test_schema_correctness` are not supported with `torch.float8_e4m3fn` yet. As long as they are not supported, it is best to skip the test.

Pull Request resolved: #115757
Approved by: https://github.com/drisspg
guilhermeleobas pushed a commit to guilhermeleobas/pytorch that referenced this issue Dec 18, 2023
According to pytorch#107256 (comment), the ops tested in `test_schema_correctness` are not supported with `torch.float8_e4m3fn` yet. As long as they are not supported, it is best to skip the test.

Pull Request resolved: pytorch#115757
Approved by: https://github.com/drisspg
dmenig pushed a commit to dmenig/pytorch that referenced this issue Dec 21, 2023
According to pytorch#107256 (comment), the ops tested in `test_schema_correctness` are not supported with `torch.float8_e4m3fn` yet. As long as they are not supported, it is best to skip the test.

Pull Request resolved: pytorch#115757
Approved by: https://github.com/drisspg
@LSC527

LSC527 commented Jul 9, 2024

One use case of float8_e4m3fn cat:
vllm-project/vllm#6249

@mvpatel2000
Contributor

mvpatel2000 commented Oct 21, 2024

@awgu any chance this is going to be supported? It makes a few FP8 use cases more difficult without it.

Similarly, it would be nice if torch.empty supported FP8 in deterministic mode. As is, I have to disable it every time I write a unit test, which is a bit annoying.
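
Until then, one pattern that seems to work (a sketch, assuming nothing in the surrounding test relies on determinism at that point; non_deterministic is just a local helper, not a PyTorch API) is to toggle deterministic algorithms off only around the fp8 allocation:

import torch
from contextlib import contextmanager

@contextmanager
def non_deterministic():
    # Temporarily disable deterministic algorithms, then restore the old setting.
    prev = torch.are_deterministic_algorithms_enabled()
    torch.use_deterministic_algorithms(False)
    try:
        yield
    finally:
        torch.use_deterministic_algorithms(prev)

with non_deterministic():
    t = torch.empty((3,), dtype=torch.float8_e4m3fn, device="cuda")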

@awgu
Contributor Author

awgu commented Oct 21, 2024

@mvpatel2000 it seems deterministic mode got fixed in #128733.

For cat, someone just needs to plumb it through, I think.
cc: @vkuzo

@Xynonners

Honestly, having basic arithmetic op support for float8 would make a lot of things easier.
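
In the meantime, the usual workaround (a sketch, not a real fp8 kernel; fp8_add is a hypothetical helper) is to upcast to a wider dtype, do the math there, and cast back, accepting that the result is re-rounded to fp8:

import torch

def fp8_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Compute in float32, then requantize to the input dtype.
    return (a.to(torch.float32) + b.to(torch.float32)).to(a.dtype)

a = torch.ones((3,), dtype=torch.float8_e4m3fn, device="cuda")
b = torch.ones((3,), dtype=torch.float8_e4m3fn, device="cuda")
print(fp8_add(a, b))  # tensor([2., 2., 2.], device='cuda:0', dtype=torch.float8_e4m3fn)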

@cxiliao

cxiliao commented Jan 7, 2025

A small trick for temporarily using torch.cat() with an fp8 dtype:

Just use view() to reinterpret the tensor as another 8-bit dtype to bypass the type check.

demo

import torch

def torch_test():
    a = torch.ones((1, 5), dtype=torch.float8_e5m2)
    b = torch.ones((1, 5), dtype=torch.float8_e5m2)
    c = torch.cat([a.view(torch.int8), b.view(torch.int8)], dim=0)

    print(c.view(torch.float8_e5m2))

torch_test()
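
This works because view() only reinterprets the underlying bytes (both dtypes are one byte wide) and cat just copies bytes without doing any arithmetic. Wrapped up as a small helper (a sketch; cat_fp8 is not a PyTorch API):

import torch

def cat_fp8(tensors, dim=0):
    # Bit-cast each fp8 tensor to int8, concatenate, then bit-cast back.
    dtype = tensors[0].dtype  # e.g. torch.float8_e5m2 or torch.float8_e4m3fn
    return torch.cat([t.view(torch.int8) for t in tensors], dim=dim).view(dtype)

a = torch.ones((1, 5), dtype=torch.float8_e5m2)
b = torch.ones((1, 5), dtype=torch.float8_e5m2)
print(cat_fp8([a, b], dim=0))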

@danielvegamyhre
Contributor

@vkuzo looks like this was implemented in #138046

Building from source:

>>> x, y = torch.ones([1], dtype=torch.float8_e4m3fn), torch.ones([1], dtype=torch.float8_e4m3fn)
>>> torch.cat([x,y])
tensor([1., 1.], dtype=torch.float8_e4m3fn)

awgu closed this as completed on Jan 8, 2025