Skip to content

Commit

Permalink
[Bugfix] Lower gemma's unloaded_params exception to warning (#7002)
Browse files · Browse the repository at this point in the history
  • Loading branch information
mgoin authored Aug 1, 2024
1 parent fb3db61 commit f4fd390
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
6 changes: 3 additions & 3 deletions vllm/model_executor/models/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}")
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)
9 changes: 6 additions & 3 deletions vllm/model_executor/models/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
Expand All @@ -41,6 +42,8 @@

from .interfaces import SupportsLoRA

logger = init_logger(__name__)


class Gemma2MLP(nn.Module):

Expand Down Expand Up @@ -390,6 +393,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}")
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)
6 changes: 3 additions & 3 deletions vllm/model_executor/models/paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

unloaded_params = params_dict.keys() - loaded_params
if unloaded_params:
raise RuntimeError(
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}")
logger.warning(
"Some weights are not initialized from checkpoints: %s",
unloaded_params)

0 comments on commit f4fd390

Please sign in to comment.