[Misc] Remove unnecessary ModelRunner imports #4703

Merged: 2 commits, May 9, 2024
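The change is mechanical in both test files: each test previously constructed a throwaway `ModelRunner` from all-`None` configs solely to read its `pin_memory` attribute, and now calls `is_pin_memory_available()` from `vllm.utils` instead. A minimal before/after sketch of the pattern, adapted from the hunks below (not a verbatim excerpt of any single test):

```python
# Before: build a ModelRunner only to read its pin_memory flag.
from vllm.worker.model_runner import ModelRunner

model_runner = ModelRunner(model_config=None,
                           parallel_config=None,
                           scheduler_config=None,
                           device_config=None,
                           load_config=None,
                           lora_config=None)
pin_memory = model_runner.pin_memory

# After: ask the utility helper directly; no ModelRunner is needed.
from vllm.utils import is_pin_memory_available

pin_memory = is_pin_memory_available()
```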
81 changes: 24 additions & 57 deletions tests/samplers/test_sampler.py
@@ -11,8 +11,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.utils import Counter
-from vllm.worker.model_runner import ModelRunner
+from vllm.utils import Counter, is_pin_memory_available


class MockLogitsSampler(Sampler):
@@ -26,20 +25,14 @@ def forward(self, *args, **kwargs):


def _prepare_test(
-        batch_size: int
-) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
+        batch_size: int
+) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, VOCAB_SIZE),
1e-2,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(fake_logits)
-    model_runner = ModelRunner(model_config=None,
-                               parallel_config=None,
-                               scheduler_config=None,
-                               device_config=None,
-                               load_config=None,
-                               lora_config=None)
-    return input_tensor, fake_logits, sampler, model_runner
+    return input_tensor, fake_logits, sampler


VOCAB_SIZE = 32000
@@ -53,7 +46,6 @@ def _do_sample(
batch_size: int,
input_tensor: torch.Tensor,
sampler: MockLogitsSampler,
-    model_runner: ModelRunner,
sampling_params: SamplingParams,
device: str,
):
@@ -75,7 +67,7 @@
seq_lens,
query_lens=seq_lens,
device=device,
-        pin_memory=model_runner.pin_memory)
+        pin_memory=is_pin_memory_available())
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)


@@ -85,28 +77,24 @@ def test_sampler_all_greedy(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
+    input_tensor, fake_logits, sampler = _prepare_test(batch_size)

sampling_params = SamplingParams(temperature=0)
-    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+    sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == expected[i].item()

-    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
+    _, fake_logits, sampler = _prepare_test(batch_size)

for i in range(batch_size):
fake_logits[i, i] = 1e2
@@ -115,23 +103,21 @@ def test_sampler_all_random(seed: int, device: str):
temperature=1.0,
n=random.randint(1, 10),
)
-    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+    sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)

for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == i

-    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
-    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
+    _, fake_logits, sampler = _prepare_test(batch_size)

for i in range(batch_size):
fake_logits[i, i] = 1e2
@@ -141,60 +127,54 @@ def test_sampler_all_random_seed(seed: int, device: str):
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
-    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+    sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)

for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == i

-    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
-    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
+    _, fake_logits, sampler = _prepare_test(batch_size)

sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
-                                      model_runner, sampling_params, device)
+                                      sampling_params, device)

second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
-                                       model_runner, sampling_params, device)
+                                       sampling_params, device)

assert first_sampler_output == second_sampler_output

-    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
-    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
+    _, fake_logits, sampler = _prepare_test(batch_size)

sampling_params = SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
)
-    _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params,
-               device)
+    _do_sample(batch_size, fake_logits, sampler, sampling_params, device)
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
# when handling an all-beam search case.
-    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@@ -448,13 +428,13 @@ def run_test_case(*,
("Invalid test case, expected_penalization does not match computed"
"batch size")

-        _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
+        _, fake_logits, sampler = _prepare_test(batch_size)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens=seq_lens if seq_lens else None,
query_lens=seq_lens if seq_lens else None,
device=device,
-            pin_memory=model_runner.pin_memory)
+            pin_memory=is_pin_memory_available())
# the logits tensor is modified in-place by the sampler
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

@@ -480,8 +460,6 @@ def run_test_case(*,
fake_logits[logits_idx, :] ==
-float('inf')) == 0, "No tokens should have been penalized"

-        del model_runner

for test_case in test_cases:
run_test_case(**test_case)

@@ -492,8 +470,7 @@ def test_sampler_mixed(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
+    input_tensor, fake_logits, sampler = _prepare_test(batch_size)

seq_group_metadata_list = []
expected_tokens: List[Optional[List[int]]] = []
@@ -534,13 +511,13 @@ def test_sampler_mixed(seed: int, device: str):
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

-    def test_sampling(model_runner: ModelRunner):
+    def test_sampling():
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
-            pin_memory=model_runner.pin_memory)
+            pin_memory=is_pin_memory_available())
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)

@@ -570,7 +547,7 @@ def test_sampling(model_runner: ModelRunner):
assert nth_output.output_token in expected_tokens[i]

# Test batch
-    test_sampling(model_runner)
+    test_sampling()

# Shuffle the batch and resample
target_index = list(range(batch_size))
@@ -583,9 +560,7 @@ def test_sampling(model_runner: ModelRunner):

# This time, results of seeded random samples will be compared with
# the corresponding sample in the pre-shuffled batch
-    test_sampling(model_runner)
-
-    del model_runner
+    test_sampling()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@@ -605,12 +580,6 @@ def test_sampler_top_k_top_p(seed: int, device: str):
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(fake_logits)
-    model_runner = ModelRunner(model_config=None,
-                               parallel_config=None,
-                               scheduler_config=None,
-                               device_config=None,
-                               load_config=None,
-                               lora_config=None)

generation_model = GenerationMixin()
generation_config = GenerationConfig(top_k=top_k,
@@ -641,7 +610,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
seq_lens,
query_lens=seq_lens,
device=device,
-        pin_memory=model_runner.pin_memory)
+        pin_memory=is_pin_memory_available())

sample_probs = None

@@ -657,5 +626,3 @@ def mock_sample(probs, *args, **kwargs):
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))

-    del model_runner
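For reference, the sampler tests above all build their sampling metadata the same way after this change; a condensed, self-contained sketch is below (the `build_sampling_metadata` wrapper is hypothetical, introduced only for illustration — the argument names follow `SamplingMetadata.prepare` as it appears in the hunks above):

```python
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import is_pin_memory_available


def build_sampling_metadata(seq_group_metadata_list, seq_lens, device):
    # Hypothetical helper mirroring the tests above: pin_memory now comes
    # from the utility function rather than from a ModelRunner instance.
    return SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())
```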
23 changes: 7 additions & 16 deletions tests/test_logits_processor.py
@@ -9,7 +9,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.model_runner import ModelRunner
+from vllm.utils import is_pin_memory_available


class MockLogitsProcessor(LogitsProcessor):
@@ -30,21 +30,15 @@ def forward(self, *args, **kwargs):


def _prepare_test(
-        batch_size: int
-) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]:
+        batch_size: int
+) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
dtype=input_tensor.dtype)
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
-    model_runner = ModelRunner(model_config=None,
-                               parallel_config=None,
-                               scheduler_config=None,
-                               device_config=None,
-                               load_config=None,
-                               lora_config=None)
-    return input_tensor, fake_logits, logits_processor, model_runner
+    return input_tensor, fake_logits, logits_processor


RANDOM_SEEDS = list(range(128))
@@ -59,8 +53,7 @@ def test_logits_processors(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, logits_processor, model_runner = _prepare_test(
-        batch_size)
+    input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)

# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
@@ -87,8 +80,8 @@ def pick_ith(token_ids, logits):
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
-        device=model_runner.device,
-        pin_memory=model_runner.pin_memory)
+        device=device,
+        pin_memory=is_pin_memory_available())
logits_processor_output = logits_processor(
embedding=None,
hidden_states=input_tensor,
@@ -99,5 +92,3 @@ def pick_ith(token_ids, logits):
fake_logits *= logits_processor.scale
assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1],
1e-4)

-    del model_runner
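Note that in this file the `device` argument changes as well: the test now passes the parametrized `device` string straight to `SamplingMetadata.prepare` instead of `model_runner.device`. A sketch of the parametrization these tests rely on, assuming a `CUDA_DEVICES` list along the lines used elsewhere in the vLLM test suite (the `CUDA_DEVICES` definition and `test_example` below are illustrative assumptions, not code from this PR):

```python
import pytest
import torch

RANDOM_SEEDS = list(range(128))
# Assumption for illustration: exercise up to two local CUDA devices.
CUDA_DEVICES = [f"cuda:{i}" for i in range(min(2, torch.cuda.device_count()))]


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_example(seed: int, device: str):
    # The parametrized device string is used directly; no ModelRunner needed.
    torch.set_default_device(device)
```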