[Misc/Testing] Use torch.testing.assert_close #7324

Merged
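torch.testing.assert_close performs the same tolerance-based comparison as torch.allclose, but a failure reports the number of mismatched elements and the largest absolute and relative differences instead of a bare AssertionError, and when no tolerances are given it picks defaults based on the dtype being compared. A minimal sketch of the two patterns (illustrative only, not code from this repository):

```python
import torch

actual = torch.tensor([1.0, 2.0, 3.0009])
expected = torch.tensor([1.0, 2.0, 3.0])

# Old pattern: a failure raises a bare AssertionError with no detail about
# which elements differ or by how much.
assert torch.allclose(actual, expected, atol=1e-2, rtol=0.0)

# New pattern: the same comparison, but a failure message includes the count
# of mismatched elements and the greatest absolute/relative difference.
torch.testing.assert_close(actual, expected, atol=1e-2, rtol=0.0)

# With rtol/atol omitted, assert_close selects defaults appropriate for the
# dtype (e.g. looser tolerances for float16 than for float32).
torch.testing.assert_close(expected, expected.clone())
```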
18 changes: 9 additions & 9 deletions tests/distributed/test_comm_ops.py
@@ -34,7 +34,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_reduce(t)
- assert torch.allclose(t, expected)
+ torch.testing.assert_close(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
@@ -62,7 +62,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
- assert torch.allclose(t, expected)
+ torch.testing.assert_close(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
@@ -96,12 +96,12 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
else:
recv_dict = broadcast_tensor_dict(src=0)
assert len(recv_dict) == len(test_dict)
- assert torch.allclose(recv_dict["a"], test_dict["a"])
- assert torch.allclose(recv_dict["b"], test_dict["b"])
+ torch.testing.assert_close(recv_dict["a"], test_dict["a"])
+ torch.testing.assert_close(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
- assert torch.allclose(recv_dict["f"], test_dict["f"])
+ torch.testing.assert_close(recv_dict["f"], test_dict["f"])


@ray.remote(num_gpus=1, max_calls=1)
@@ -136,12 +136,12 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,

if not get_pp_group().is_first_rank:
assert len(recv_dict) == len(test_dict)
- assert torch.allclose(recv_dict["a"], test_dict["a"])
- assert torch.allclose(recv_dict["b"], test_dict["b"])
+ torch.testing.assert_close(recv_dict["a"], test_dict["a"])
+ torch.testing.assert_close(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
- assert torch.allclose(recv_dict["f"], test_dict["f"])
+ torch.testing.assert_close(recv_dict["f"], test_dict["f"])


@ray.remote(num_gpus=1, max_calls=1)
@@ -163,7 +163,7 @@ def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
get_pp_group().send(test_tensor)

if not get_pp_group().is_first_rank:
- assert torch.allclose(test_tensor, recv_tensor)
+ torch.testing.assert_close(test_tensor, recv_tensor)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
8 changes: 4 additions & 4 deletions tests/distributed/test_custom_all_reduce.py
@@ -72,8 +72,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
out2 = tensor_model_parallel_all_reduce(inp2)
dist.all_reduce(inp2, group=group)
graph.replay()
- assert torch.allclose(out1, inp1)
- assert torch.allclose(out2, inp2)
+ torch.testing.assert_close(out1, inp1)
+ torch.testing.assert_close(out2, inp2)


@ray.remote(num_gpus=1, max_calls=1)
@@ -96,13 +96,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
out = inp
for _ in range(num_communication):
out = fa.all_reduce_unreg(out)
- assert torch.allclose(out, inp * (tp_size**num_communication))
+ torch.testing.assert_close(out, inp * (tp_size**num_communication))

inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
out = inp
for _ in range(num_communication):
out = fa.all_reduce_unreg(out)
- assert torch.allclose(out, inp * (tp_size**num_communication))
+ torch.testing.assert_close(out, inp * (tp_size**num_communication))


@pytest.mark.parametrize("tp_size", [2])
2 changes: 1 addition & 1 deletion tests/kernels/quant_utils.py
@@ -69,4 +69,4 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
ref_iscale = one / ref_scale
ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
- return ref_out, ref_scale
+ return ref_out, ref_scale.view((1, ))
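The reference helper now returns the scale as a 1-element tensor, keeping its shape in line with the scale produced by the kernel: torch.allclose broadcasts its arguments, so a 0-d scale compared fine against a (1,)-shaped one, but torch.testing.assert_close checks shapes strictly. A small sketch of that difference (illustrative only; the shapes and the motivation are inferred, not taken from the PR discussion):

```python
import torch

ref_scale = torch.tensor(0.5)     # 0-d tensor, shape ()
ops_scale = torch.tensor([0.5])   # 1-element tensor, shape (1,)

# torch.allclose broadcasts, so comparing shape () against shape (1,) passes.
assert torch.allclose(ref_scale, ops_scale)

# torch.testing.assert_close checks shapes strictly and raises here.
try:
    torch.testing.assert_close(ref_scale, ops_scale)
except AssertionError as exc:
    print(exc)  # reports that shapes torch.Size([]) and torch.Size([1]) differ

# Reshaping the reference scale to match, as the diff above does, succeeds.
torch.testing.assert_close(ref_scale.view((1, )), ops_scale)
```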
10 changes: 5 additions & 5 deletions tests/kernels/test_activation.py
@@ -47,7 +47,7 @@ def test_act_and_mul(
ref_out = layer.forward_native(x)
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
- assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
+ torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)


@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
@@ -73,7 +73,7 @@ def test_activation(
layer = activation()
out = layer(x)
ref_out = layer.forward_native(x)
- assert torch.allclose(out,
- ref_out,
- atol=get_default_atol(out),
- rtol=get_default_rtol(out))
+ torch.testing.assert_close(out,
+ ref_out,
+ atol=get_default_atol(out),
+ rtol=get_default_rtol(out))
4 changes: 2 additions & 2 deletions tests/kernels/test_attention.py
@@ -276,7 +276,7 @@ def test_paged_attention(
atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8":
atol, rtol = 1e-2, 1e-5
- assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
+ torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


def ref_multi_query_kv_attention(
@@ -379,4 +379,4 @@ def test_multi_query_kv_attention(
)
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5
- assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
+ torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
4 changes: 2 additions & 2 deletions tests/kernels/test_blocksparse_attention.py
@@ -327,7 +327,7 @@ def test_paged_attention(
atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8":
atol, rtol = 1e-2, 1e-5
- assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
+ torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


def ref_multi_query_kv_attention(
@@ -441,4 +441,4 @@ def test_varlen_blocksparse_attention_prefill(
scale,
dtype,
)
- assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2)
+ torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
54 changes: 27 additions & 27 deletions tests/kernels/test_cache.py
@@ -98,10 +98,10 @@ def test_copy_blocks(

# Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
- assert torch.allclose(key_cache, cloned_key_cache)
+ torch.testing.assert_close(key_cache, cloned_key_cache)
for value_cache, cloned_value_cache in zip(value_caches,
cloned_value_caches):
- assert torch.allclose(value_cache, cloned_value_cache)
+ torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -184,17 +184,17 @@ def test_reshape_and_cache(
cloned_value_cache[block_idx, :, :, block_offset] = value[i]

if kv_cache_dtype == "fp8":
- assert torch.allclose(result_key_cache,
- cloned_key_cache,
- atol=0.001,
- rtol=0.1)
- assert torch.allclose(result_value_cache,
- cloned_value_cache,
- atol=0.001,
- rtol=0.1)
+ torch.testing.assert_close(result_key_cache,
+ cloned_key_cache,
+ atol=0.001,
+ rtol=0.1)
+ torch.testing.assert_close(result_value_cache,
+ cloned_value_cache,
+ atol=0.001,
+ rtol=0.1)
else:
- assert torch.allclose(key_cache, cloned_key_cache)
- assert torch.allclose(value_cache, cloned_value_cache)
+ torch.testing.assert_close(key_cache, cloned_key_cache)
+ torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -290,17 +290,17 @@ def test_reshape_and_cache_flash(
cloned_value_cache[block_idx, block_offset, :, :] = value[i]

if kv_cache_dtype == "fp8":
- assert torch.allclose(result_key_cache,
- cloned_key_cache,
- atol=0.001,
- rtol=0.1)
- assert torch.allclose(result_value_cache,
- cloned_value_cache,
- atol=0.001,
- rtol=0.1)
+ torch.testing.assert_close(result_key_cache,
+ cloned_key_cache,
+ atol=0.001,
+ rtol=0.1)
+ torch.testing.assert_close(result_value_cache,
+ cloned_value_cache,
+ atol=0.001,
+ rtol=0.1)
else:
- assert torch.allclose(key_cache, cloned_key_cache)
- assert torch.allclose(value_cache, cloned_value_cache)
+ torch.testing.assert_close(key_cache, cloned_key_cache)
+ torch.testing.assert_close(value_cache, cloned_value_cache)


@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@@ -372,10 +372,10 @@ def test_swap_blocks(
block_mapping_tensor)

for src, dst in block_mapping:
- assert torch.allclose(src_key_caches_clone[src].cpu(),
- dist_key_caches[0][dst].cpu())
- assert torch.allclose(src_value_caches_clone[src].cpu(),
- dist_value_caches[0][dst].cpu())
+ torch.testing.assert_close(src_key_caches_clone[src].cpu(),
+ dist_key_caches[0][dst].cpu())
+ torch.testing.assert_close(src_value_caches_clone[src].cpu(),
+ dist_value_caches[0][dst].cpu())


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -411,4 +411,4 @@ def test_fp8_e4m3_conversion(
converted_cache = torch.empty_like(cache)
ops.convert_fp8(converted_cache, cache_fp8)

- assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
+ torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
23 changes: 13 additions & 10 deletions tests/kernels/test_cutlass.py
@@ -74,7 +74,7 @@ def cutlass_fp8_gemm_helper(m: int,
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

- assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)
+ torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)


def cutlass_int8_gemm_helper(m: int,
@@ -106,7 +106,7 @@ def cutlass_int8_gemm_helper(m: int,
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

- assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
+ torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
@@ -252,7 +252,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding

a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
- assert torch.allclose(a_dq, scale_a * aq_f32 + azp_a)
+ torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)

baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)

@@ -271,8 +271,8 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
scale_b,
out_dtype=out_dtype,
bias=azp_bias[0, :])
- assert torch.allclose(out, baseline_dq, rtol=1e-2, atol=1e0)
- assert torch.allclose(out, baseline_q, rtol=1e-2, atol=1e0)
+ torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
+ torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)


@pytest.mark.parametrize("m", [32, 64, 128])
@@ -302,7 +302,10 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding

a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
- assert torch.allclose(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3)
+ torch.testing.assert_close(a_dq,
+ scale_a * aq_f32 - azp_a,
+ rtol=1e-4,
+ atol=1e-3)

if use_bias:
bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
@@ -335,8 +338,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
atol = 1e-3
- assert torch.allclose(out, baseline_dq, rtol=rtol, atol=atol)
- assert torch.allclose(out, baseline_q, rtol=rtol, atol=atol)
+ torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
+ torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)


# Test working with a subset of A and B
@@ -363,7 +366,7 @@ def test_cutlass_subset():
scale_b,
out_dtype=torch.bfloat16)

- assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
+ torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


# Test to make sure cuda graphs work
@@ -411,4 +414,4 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):

baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
- assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
+ torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
4 changes: 2 additions & 2 deletions tests/kernels/test_flash_attn.py
@@ -126,7 +126,7 @@ def test_flash_attn_with_paged_kv(
scale=scale,
soft_cap=soft_cap,
)
- assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \
+ torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"


@@ -211,5 +211,5 @@ def test_varlen_with_paged_kv(
sliding_window=sliding_window,
soft_cap=soft_cap,
)
- assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \
+ torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
4 changes: 2 additions & 2 deletions tests/kernels/test_flashinfer.py
@@ -144,7 +144,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
- assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
+ torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"


@@ -244,5 +244,5 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
- assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
+ torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
14 changes: 7 additions & 7 deletions tests/kernels/test_fp8_quant.py
@@ -37,9 +37,9 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
scale_ub=scale_ub,
use_per_token_if_dynamic=True)

- assert torch.allclose(ref_scales, ops_scales)
- assert torch.allclose(ref_out.to(dtype=torch.float32),
- ops_out.to(dtype=torch.float32))
+ torch.testing.assert_close(ref_scales, ops_scales)
+ torch.testing.assert_close(ref_out.to(dtype=torch.float32),
+ ops_out.to(dtype=torch.float32))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -57,9 +57,9 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)
ops_out, ops_scale = ops.scaled_fp8_quant(x)

- assert torch.allclose(ref_scale, ops_scale)
- assert torch.allclose(ref_out.to(dtype=torch.float32),
- ops_out.to(dtype=torch.float32))
+ torch.testing.assert_close(ref_scale, ops_scale)
+ torch.testing.assert_close(ref_out.to(dtype=torch.float32),
+ ops_out.to(dtype=torch.float32))


# Regression test for a case with large activations where an int32 index cannot
@@ -84,4 +84,4 @@ def test_fp8_quant_large(seed: int) -> None:
ref_out = ref_out.to(dtype=dtype)
ops_out = ops_out.to(dtype=dtype)

- assert torch.allclose(ref_out, ops_out)
+ torch.testing.assert_close(ref_out, ops_out)
12 changes: 7 additions & 5 deletions tests/kernels/test_int8_quant.py
@@ -29,9 +29,10 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
# kernel
ops_out, ops_scales = scaled_int8_quant(x)

- assert torch.allclose(ops_scales, ref_scales)
- assert torch.allclose(ops_out, ref_out,
- atol=1) # big atol to account for rounding errors
+ torch.testing.assert_close(ops_scales, ref_scales)
+ torch.testing.assert_close(
+ ops_out, ref_out, atol=1,
+ rtol=0.0) # big atol to account for rounding errors


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -54,5 +54,6 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits.max).to(torch.int8)
out2, _ = scaled_int8_quant(x, scale)

- assert torch.allclose(out1, out2,
- atol=1) # big atol to account for rounding errors
+ torch.testing.assert_close(
+ out1, out2, atol=1,
+ rtol=0.0) # big atol to account for rounding errors
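The int8 tests also gain an explicit rtol=0.0 next to the existing atol=1: torch.allclose lets atol be overridden on its own (rtol silently keeps its 1e-5 default), while torch.testing.assert_close expects rtol and atol to be supplied together or both left out. A short illustrative sketch of that constraint (standalone example, not code from this repository):

```python
import torch

a = torch.tensor([10.0, 11.0])
b = torch.tensor([10.0, 12.0])

# torch.allclose accepts atol alone; rtol quietly stays at its 1e-5 default.
assert torch.allclose(a, b, atol=1)

# torch.testing.assert_close expects rtol and atol as a pair (or neither),
# which is why the tests above spell out rtol=0.0 alongside atol=1.
torch.testing.assert_close(a, b, atol=1, rtol=0.0)
```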