diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index cc78a0ea3b869..41abdf211e7e7 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -118,8 +118,9 @@ def forward(
                                       sampling_tensors.frequency_penalties,
                                       sampling_tensors.repetition_penalties)
 
-        # Apply temperature scaling.
+        # Use float32 to apply temperature scaling.
         # Use in-place division to avoid creating a new tensor.
+        logits = logits.to(torch.float)
         logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
 
         if do_top_p_top_k:
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 1c1e5f16b5172..04250c682cd23 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -13,6 +13,7 @@
 logger = init_logger(__name__)
 
 _SAMPLING_EPS = 1e-5
+_MAX_TEMP = 1e-2
 
 
 class SamplingType(IntEnum):
@@ -145,6 +146,12 @@ def __init__(
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
         self.repetition_penalty = repetition_penalty
+        if 0 < temperature < _MAX_TEMP:
+            logger.warning(
+                "temperature %s is less than %s, which may cause numerical "
+                "errors nan or inf in tensors. We have maxed it out to %s.",
+                temperature, _MAX_TEMP, _MAX_TEMP)
+            temperature = max(temperature, _MAX_TEMP)
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
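
For context, the failure mode this patch guards against is easy to reproduce. The sketch below (not part of the diff; the logit values and shapes are illustrative) shows half-precision logits overflowing to inf under a very small temperature, producing nan after softmax, and how each half of the patch avoids it:

import torch

# Minimal repro sketch (illustrative values, not taken from the patch).
# float16's largest finite value is ~65504, so dividing by a tiny
# temperature overflows: 10.0 / 1e-4 = 1e5 -> inf in float16.
logits = torch.tensor([10.0, 8.0, 7.0], dtype=torch.float16)
temperature = 1e-4  # below the new _MAX_TEMP floor of 1e-2

scaled = logits / temperature
print(scaled)                           # tensor([inf, inf, inf], dtype=torch.float16)
print(scaled.float().softmax(dim=-1))   # tensor([nan, nan, nan]); inf - inf = nan

# Fix 1 (sampler.py): cast to float32 before the in-place division,
# so the scaled values stay finite.
print(logits.to(torch.float) / temperature)   # ~tensor([100000., 80000., 70000.])

# Fix 2 (sampling_params.py): clamp the temperature to a floor of
# _MAX_TEMP, which bounds the scale factor regardless of dtype.
_MAX_TEMP = 1e-2
print(logits / max(temperature, _MAX_TEMP))   # finite even in float16

The two fixes are complementary: the float32 cast makes the sampler robust to whatever temperature reaches it, while the _MAX_TEMP floor rejects pathological request parameters at the source and warns the caller.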