Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support per-request seed #2514

Merged
merged 12 commits into from
Feb 21, 2024
222 changes: 147 additions & 75 deletions tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import random
from typing import Tuple
from typing import Tuple, List
from unittest.mock import patch

import pytest
import torch
from transformers import GenerationConfig, GenerationMixin
from typing import Optional

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
Expand Down Expand Up @@ -46,15 +47,13 @@ def _prepare_test(
]


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)

def _do_sample(
batch_size: int,
input_tensor: torch.Tensor,
sampler: MockLogitsSampler,
model_runner: ModelRunner,
sampling_params: SamplingParams,
):
seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
Expand All @@ -63,17 +62,31 @@ def test_sampler_all_greedy(seed: int, device: str):
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(temperature=0, ),
sampling_params=sampling_params,
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
return sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)

sampling_params = SamplingParams(temperature=0)
sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)
expected = torch.argmax(fake_logits, dim=-1)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
Expand All @@ -94,35 +107,72 @@ def test_sampler_all_random(seed: int, device: str):
for i in range(batch_size):
fake_logits[i, i] = 1e2

seq_group_metadata_list = []
prompt_lens = []
sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
)
sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)

for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == i

del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)

for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
fake_logits[i, i] = 1e2

sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
for i, sequence_output in enumerate(sampler_output):
for nth_output in sequence_output.samples:
assert nth_output.output_token == i

del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)

sampling_params = SamplingParams(
temperature=1.0,
n=random.randint(1, 10),
seed=random.randint(0, 10000),
)
first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)

second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
model_runner, sampling_params)

assert first_sampler_output == second_sampler_output

del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
Expand All @@ -131,29 +181,13 @@ def test_sampler_all_beam(seed: int, device: str):
batch_size = random.randint(1, 256)
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
sampling_params = SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
)
_do_sample(batch_size, input_tensor, sampler, model_runner,
sampling_params)
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
Expand All @@ -171,14 +205,15 @@ def test_sampler_mixed(seed: int, device: str):
batch_size)

seq_group_metadata_list = []
expected_tokens = []
expected_tokens: List[Optional[List[int]]] = []
prompt_lens = []
for i in range(batch_size):
n = 1
sampling_type = random.randint(0, 2)
expected: Optional[List[int]] = None
sampling_type = random.randint(0, 3)
if sampling_type == 0:
sampling_params = SamplingParams(temperature=0)
elif sampling_type == 1:
expected = [torch.argmax(fake_logits[i], dim=-1).item()]
elif sampling_type in (1, 2):
n = random.randint(1, 10)
sampling_params = SamplingParams(
temperature=random.random() + 0.1,
Expand All @@ -187,13 +222,17 @@ def test_sampler_mixed(seed: int, device: str):
n=n,
presence_penalty=random.randint(0, 1),
)
if sampling_type == 2:
sampling_params.seed = random.randint(0, 10000)
else:
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected = list(range(i, i + n))
else:
sampling_params = SamplingParams(temperature=0,
use_beam_search=True,
best_of=2)
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected_tokens.append(i + idx)
expected_tokens.append(expected)
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
Expand All @@ -204,17 +243,50 @@ def test_sampler_mixed(seed: int, device: str):
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
for i, sequence_output in enumerate(sampler_output):
if seq_group_metadata_list[i].sampling_params.use_beam_search:
continue
for nth_output in sequence_output.samples:
assert nth_output.output_token in expected_tokens
def test_sampling(model_runner: ModelRunner):
sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
sampler_output = sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)

for i, (sequence_output, metadata) in enumerate(
zip(sampler_output, seq_group_metadata_list)):
if metadata.sampling_params.use_beam_search:
continue

if metadata.sampling_params.seed is not None \
and expected_tokens[i] is None:
# Record seeded random result to compare with results of second invocation
expected_tokens[i] = [
nth_output.output_token
for nth_output in sequence_output.samples
]
continue

for n, nth_output in enumerate(sequence_output.samples):
if metadata.sampling_params.temperature == 0 or metadata.sampling_params.seed is not None:
# Ensure exact matches for greedy or random with seed
assert nth_output.output_token == expected_tokens[i][n]
else:
# For non-seeded random check that one of the high-logit tokens were chosen
assert nth_output.output_token in expected_tokens[i]

# Test batch
test_sampling(model_runner)

# Shuffle the batch and resample
target_index = list(range(batch_size))
for list_to_shuffle in (target_index, seq_group_metadata_list,
expected_tokens, prompt_lens):
random.Random(seed).shuffle(list_to_shuffle)
target_index = torch.tensor(target_index)
input_tensor.data = input_tensor.index_select(0, target_index)
fake_logits.data = fake_logits.index_select(0, target_index)

# This time, results of seeded random samples will be compared with the corresponding
# sample in the pre-shuffled batch
test_sampling(model_runner)

del model_runner

Expand Down
82 changes: 82 additions & 0 deletions tests/samplers/test_seeded_generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Verify that seeded random sampling is deterministic.

Run `pytest tests/samplers/test_seeded_generate.py --forked`.
"""
import copy
import random
from itertools import combinations

import pytest

from vllm.model_executor.utils import set_random_seed
from vllm import SamplingParams

MODEL = "facebook/opt-125m"
RANDOM_SEEDS = list(range(5))


@pytest.fixture
def vllm_model(vllm_runner):
vllm_model = vllm_runner(MODEL, dtype="half")
yield vllm_model
del vllm_model


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_random_sample_with_seed(
vllm_model,
example_prompts,
seed: int,
) -> None:
set_random_seed(seed)

sampling_params = SamplingParams(
# Parameters to ensure sufficient randomness
temperature=2.0,
top_p=min(random.random() + 0.3, 1),
top_k=random.randint(5, 20),
n=random.randint(1, 10),
presence_penalty=random.randint(0, 1),
max_tokens=8,
ignore_eos=True,
)

sampling_params_seed_1 = copy.deepcopy(sampling_params)
sampling_params_seed_1.seed = 100
sampling_params_seed_2 = copy.deepcopy(sampling_params)
sampling_params_seed_2.seed = 200

llm = vllm_model.model

for prompt in example_prompts:
for params in (
sampling_params,
sampling_params_seed_1,
sampling_params_seed_2,
sampling_params,
sampling_params_seed_1,
sampling_params_seed_2,
):
llm._add_request(
prompt=prompt,
prompt_token_ids=None,
sampling_params=params,
)

results = llm._run_engine(use_tqdm=False)
all_outputs = [[out.token_ids for out in output.outputs]
for output in results]

for i in range(0, len(example_prompts), 6):
outputs = all_outputs[i:i + 6]

# verify all non-seeded requests differ
for output_a, output_b in combinations(
(outputs[0], outputs[1], outputs[2], outputs[3]),
2,
):
assert output_a != output_b

# verify requests with the same seed match
assert outputs[1] == outputs[4]
assert outputs[2] == outputs[5]
1 change: 1 addition & 0 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
block_tables=block_tables,
lora_request=seq_group.lora_request,
prefix=seq_group.prefix,
state=seq_group.state,
)
seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs
Expand Down
Loading
Loading