torch.float8_e4m3fn does not support torch.cat (#107256)

Comments
we should support cat, @awgu let us know when you need it
I know that FP8 support is still at a very early stage, but to keep track of what is currently possible, I created a small test script with some basic functional tests. I'm sharing my results so that not everyone wondering about the same questions has to run the tests themselves. Note that the results were generated with a preview (nightly) build (PyTorch version:
Test script:

```python
import torch

cuda = "cuda:0"
dtype = torch.float8_e4m3fn
print("> torch.__version__:", torch.__version__)

print("> torch.empty")
try: print("✔ cpu: ", torch.empty((3,), dtype=dtype))
except Exception as e: print("✘ cpu: ", e)
try: print("✔ cuda:", torch.empty((3,), dtype=dtype, device=cuda))
except Exception as e: print("✘ cuda:", e)

print("> torch.zeros")
try: print("✔ cpu: ", torch.zeros((3,), dtype=dtype))
except Exception as e: print("✘ cpu: ", e)
try: print("✔ cuda:", torch.zeros((3,), dtype=dtype, device=cuda))
except Exception as e: print("✘ cuda:", e)

print("> torch.ones")
try: print("✔ cpu: ", torch.ones((3,), dtype=dtype))
except Exception as e: print("✘ cpu: ", e)
try: print("✔ cuda:", torch.ones((3,), dtype=dtype, device=cuda))
except Exception as e: print("✘ cuda:", e)

print("> torch.zeros_like")
t = torch.empty((3,), dtype=dtype)
try: print("✔ cpu: ", torch.zeros_like(t))
except Exception as e: print("✘ cpu: ", e)
t = t.to(cuda)
try: print("✔ cuda:", torch.zeros_like(t))
except Exception as e: print("✘ cuda:", e)

print("> torch.ones_like")
t = torch.empty((3,), dtype=dtype)
try: print("✔ cpu: ", torch.ones_like(t))
except Exception as e: print("✘ cpu: ", e)
t = t.to(cuda)
try: print("✔ cuda:", torch.ones_like(t))
except Exception as e: print("✘ cuda:", e)

print("> t1 + t2")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.empty((3,), dtype=dtype)
try: print("✔ cpu: ", t1 + t2)
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", t1 + t2)
except Exception as e: print("✘ cuda:", e)

print("> t1 - t2")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.empty((3,), dtype=dtype)
try: print("✔ cpu: ", t1 - t2)
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", t1 - t2)
except Exception as e: print("✘ cuda:", e)

print("> t1 * t2")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.empty((3,), dtype=dtype)
try: print("✔ cpu: ", t1 * t2)
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", t1 * t2)
except Exception as e: print("✘ cuda:", e)

print("> t1 / t2")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.empty((3,), dtype=dtype)
try: print("✔ cpu: ", t1 / t2)
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", t1 / t2)
except Exception as e: print("✘ cuda:", e)

print("> t1 ** t2")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.tensor(2, dtype=dtype)
try: print("✔ cpu: ", t1 ** t2)
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", t1 ** t2)
except Exception as e: print("✘ cuda:", e)

print("> t1 // t2")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.tensor(2, dtype=dtype)
try: print("✔ cpu: ", t1 // t2)
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", t1 // t2)
except Exception as e: print("✘ cuda:", e)

print("> t1 % t2")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.tensor(2, dtype=dtype)
try: print("✔ cpu: ", t1 % t2)
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", t1 % t2)
except Exception as e: print("✘ cuda:", e)

print("> torch.cat([t1, t2], dim=0)")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.empty((2,), dtype=dtype)
try: print("✔ cpu: ", torch.cat([t1, t2], dim=0))
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", torch.cat([t1, t2], dim=0))
except Exception as e: print("✘ cuda:", e)

print("> torch.stack([t1, t2], dim=0)")
t1 = torch.empty((3,), dtype=dtype)
t2 = torch.empty((3,), dtype=dtype)
try: print("✔ cpu: ", torch.stack([t1, t2], dim=0))
except Exception as e: print("✘ cpu: ", e)
t1, t2 = t1.to(cuda), t2.to(cuda)
try: print("✔ cuda:", torch.stack([t1, t2], dim=0))
except Exception as e: print("✘ cuda:", e)
```
According to #107256 (comment), the ops tested in `test_schema_correctness` are not supported with `torch.float8_e4m3fn` yet. Until they are supported, it is best to skip the test. Pull Request resolved: #115757 Approved by: https://github.com/drisspg
One use case of float8_e4m3fn cat:
@awgu any chance this is going to be supported? It makes a few FP8 use cases more difficult without it. Similarly, it would be nice if
@mvpatel2000 seems deterministic mode got fixed for cat in #128733; someone just needs to plumb it through, I think.
Honestly, having basic arithmetic op support for float8 would make a lot of things easier.
A small trick for temporarily using torch.cat() with an fp8 dtype: use view() to reinterpret the tensor as another 8-bit dtype, which bypasses the dtype check.
As of right now, this is not blocking any work. However, I would imagine that concatenating `float8_e4m3fn` tensors should be well-defined.

Repro

Addendum: `torch.float8_e4m3fn` does not support `torch.ones` on CPU.

cc @yanbing-j