Skip to content

Commit

Permalink
Rotary testing (#323)
Browse files Browse the repository at this point in the history
* Add rotary helper tests

* Fix formatting

Co-authored-by: Chris Yuan <[email protected]>
  • Loading branch information
yuanandonly and Chris Yuan authored Jun 3, 2022
1 parent 37cf11f commit 5ccbcd9
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tests/test_rotary_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
import torch

from xformers.components.positional_embedding import RotaryEmbedding
from xformers.components.positional_embedding.rotary import (
apply_rotary_pos_emb,
rotate_half,
)

DEVICES = (
[torch.device("cpu")]
Expand All @@ -21,6 +25,32 @@
EMB = 32


def test_helper_methods():
# rotate_half
tens = torch.tensor([[0, 1, 2, 3], [3, 1, 2, 0], [0, 1, 0, 1], [1, 0, 1, 0]])
tens_rotated = rotate_half(tens)
assert torch.equal(
tens_rotated,
torch.tensor([[-2, -3, 0, 1], [-2, 0, 3, 1], [0, -1, 0, 1], [-1, 0, 1, 0]]),
)

# apply_rotary_pos_emb
cos_test = torch.ones((1, 1, 4, 4))
sin_test = cos_test.clone()
q_test = 3 * torch.ones((2, 2, 3, 4))
q_applied = apply_rotary_pos_emb(q_test, cos_test, sin_test)
assert torch.equal(
q_applied,
torch.concat(
(
torch.zeros((2, 2, 3, 2), dtype=torch.float),
6 * torch.ones((2, 2, 3, 2), dtype=torch.float),
),
dim=-1,
),
)


@pytest.mark.parametrize("device", DEVICES)
def test_rotary_embeddings(device):
rotary = RotaryEmbedding(EMB).to(device)
Expand Down

0 comments on commit 5ccbcd9

Please sign in to comment.