diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e90b76dcdd9ad..821c9e1380280 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1158,11 +1158,12 @@ def profile_run(self) -> None: # Trigger compilation for general shape. hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches) - if not get_pp_group().is_last_rank: - return hidden_states - hidden_states = hidden_states[logit_indices] - logits = self.model.compute_logits(hidden_states, None) - # TODO(woosuk): Consider the memory usage of the sampler. + if get_pp_group().is_last_rank: + hidden_states = hidden_states[logit_indices] + logits = self.model.compute_logits(hidden_states, None) + # TODO(woosuk): Consider the memory usage of the sampler. + else: + logits = None torch.cuda.synchronize() del hidden_states, logits self.encoder_cache.clear()