diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index 50cae1041bc8f..f49e4537e5945 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -182,6 +182,10 @@ Vision Language Models
     - Models
     - Example HuggingFace Models
     - :ref:`LoRA <lora>`
+  * - :code:`ChameleonForConditionalGeneration`
+    - Chameleon
+    - :code:`facebook/chameleon-7b` etc.
+    -
   * - :code:`FuyuForCausalLM`
     - Fuyu
     - :code:`adept/fuyu-8b` etc.
diff --git a/tests/models/test_chameleon.py b/tests/models/test_chameleon.py
new file mode 100644
index 0000000000000..6e775da24d14e
--- /dev/null
+++ b/tests/models/test_chameleon.py
@@ -0,0 +1,102 @@
+import re
+from typing import List, Optional, Type
+
+import pytest
+
+from vllm.multimodal.utils import rescale_image_size
+
+from ..conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets
+
+pytestmark = pytest.mark.vlm
+
+HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
+    "stop_sign":
+    "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
+    "cherry_blossom":
+    "USER: <image>\nWhat is the season?\nASSISTANT:",
+})
+
+models = ["facebook/chameleon-7b"]
+
+
+#TODO (ywang96): Add correctness test when chameleon is
+# available on transformers.
+def run_test(
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    size_factors: List[float],
+    dtype: str,
+    max_tokens: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    """Test if the model can generate text given
+    a batch of images and prompts.
+
+    """
+    images = [asset.pil_image for asset in image_assets]
+
+    inputs_per_image = [(
+        [prompt for _ in size_factors],
+        [rescale_image_size(image, factor) for factor in size_factors],
+    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
+
+    with vllm_runner(model,
+                     max_model_len=4096,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True) as vllm_model:
+
+        for prompts, images in inputs_per_image:
+            vllm_outputs = vllm_model.generate_greedy(prompts,
+                                                      max_tokens,
+                                                      images=images)
+            for i in range(len(vllm_outputs)):
+
+                # format prompt back to original
+                replacements = {
+                    "<racm3:break>": "",
+                    "<eoss>": "",
+                    "<reserved08706>": ""
+                }
+                pattern = '|'.join(replacements.keys())
+                vllm_result = re.sub(
+                    pattern,
+                    lambda match: replacements[match.group(0)],  #noqa B023
+                    vllm_outputs[i][1])
+                vllm_result = vllm_result.replace("<image>", "", 1023)
+                assert vllm_result[:len(prompts[i])] == prompts[i]
+
+                # assert at least 10 new characters are generated
+                # (to take stop token into account)
+                assert len(vllm_outputs[i][1]) - len(prompts[i]) > 10
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [128])
+def test_models(vllm_runner, image_assets, model, size_factors, dtype: str,
+                max_tokens: int) -> None:
+    run_test(
+        vllm_runner,
+        image_assets,
+        model,
+        size_factors=size_factors,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        tensor_parallel_size=1,
+    )
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 7b5cbbb251b1f..83abc40888137 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -105,7 +105,8 @@ def _image_token_str(model_config: ModelConfig,
         return None
     if model_type.startswith("llava"):
         return tokenizer.decode(model_config.hf_config.image_token_index)
-
+    if model_type == "chameleon":
+        return "<image>"
     raise TypeError("Unknown model type: {model_type}")
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index 8df0a0034c023..31370aebba599 100644
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -16,9 +16,10 @@
     "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
     "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
-    "ChameleonForCausalLM":
-    ("chameleon", "ChameleonForConditionalGeneration"
-     ),  #TODO(ywang96): fix model name when huggingface fixes it
+    #TODO(ywang96): remove this when huggingface fixes the model repo
+    "ChameleonForCausalLM": ("chameleon", "ChameleonForConditionalGeneration"),
+    "ChameleonForConditionalGeneration":
+    ("chameleon", "ChameleonForConditionalGeneration"),
     "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
     "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
     "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index 02a3cb02769f4..d06eb0504079f 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -1,13 +1,17 @@
 from functools import cached_property
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import (Any, Dict, Iterable, List, Literal, Optional, Tuple,
+                    TypedDict)
 
 import torch
 import torch.nn.functional as F
+from PIL import Image
 from torch import nn
 
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -22,10 +26,114 @@
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors, SamplerOutput
-from vllm.transformers_utils.configs import ChameleonConfig
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.image import (cached_get_tokenizer,
+                                   repeat_and_pad_image_tokens)
+from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
+from vllm.transformers_utils.configs import (ChameleonConfig,
+                                             ChameleonVQVAEConfig)
 from vllm.utils import print_warning_once
 
+from .interfaces import SupportsVision
+
+logger = init_logger(__name__)
+
+# These configs are not part of the model config but the preprocessor
+# and processor files, so we hardcode them in the model file for now.
+CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512 +CHAMELEON_IMAGE_SEQ_LENGTH = 1024 +CHAMELEON_IMAGE_TOKEN_ID = 8711 +CHAMELEON_IMAGE_START_TOKEN_ID = 8197 +CHAMELEON_IMAGE_END_TOKEN_ID = 8196 +CHAMELEON_SEP_TOKEN_ID = 8710 + + +class ChameleonImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size, num_channels, height, width)`""" + + +def get_max_chameleon_image_tokens(ctx: InputContext): + return CHAMELEON_IMAGE_SEQ_LENGTH + + +def dummy_seq_data_for_chameleon( + seq_len: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH + else: + image_feature_size = image_feature_size_override + + token_ids = [image_token_id] * image_feature_size + token_ids += [0] * (seq_len - image_feature_size) + return SequenceData(token_ids) + + +def dummy_image_for_chameleon( + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + width = CHAMELEON_CROP_SIZE_WIDTH + height = CHAMELEON_CROP_SIZE_HEIGHT + if image_width_override is not None: + width = image_width_override + if image_height_override is not None: + height = image_height_override + + image = Image.new("RGB", (width, height), color=0) + return {"image": image} + + +def dummy_data_for_chameleon(ctx: InputContext, seq_len: int): + + seq_data = dummy_seq_data_for_chameleon( + seq_len, + image_token_id=CHAMELEON_IMAGE_TOKEN_ID, + ) + + mm_data = dummy_image_for_chameleon() + return seq_data, mm_data + + +def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): + + """ + Processing input prompt to insert required tokens for image placeholder. + + See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58 + """ # noqa + + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer) + new_prompt, new_token_ids = repeat_and_pad_image_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + image_token_id=CHAMELEON_IMAGE_TOKEN_ID, + repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH, + pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID, + pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID, + ) + + # Appending sep token for chat mode to follow default processor + # behavior + new_prompt += tokenizer.sep_token + new_token_ids += [CHAMELEON_SEP_TOKEN_ID] + + # NOTE: Create a defensive copy of the original inputs + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + class ChameleonLayerNorm(nn.LayerNorm): @@ -318,12 +426,333 @@ def forward( return hidden_states, residual +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa +class ChameleonVQVAEVectorQuantizer(nn.Module): + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__() + self.num_embeddings = config.num_embeddings + self.embedding_dim = config.embed_dim + self.beta = getattr(config, "beta", 0.25) + + self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) + self.re_embed = self.num_embeddings + + def forward(self, hidden_state: torch.Tensor): + hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() + hidden_state_flattened = 
hidden_state.view(-1, self.embedding_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + distances = ( + torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) - + 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, + self.embedding.weight.transpose(0, 1))) + + min_encoding_indices = torch.argmin(distances, dim=1) + hidden_state_quant = self.embedding(min_encoding_indices).view( + hidden_state.shape) + + # compute loss for embedding + loss = torch.mean((hidden_state_quant.detach() - hidden_state)** + 2) + self.beta * torch.mean( + (hidden_state_quant - hidden_state.detach())**2) + + # preserve gradients + hidden_state_quant = hidden_state + (hidden_state_quant - + hidden_state).detach() + + # reshape back to match original input shape + hidden_state_quant = hidden_state_quant.permute(0, 3, 1, + 2).contiguous() + + return hidden_state_quant, loss, min_encoding_indices + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa +class ChameleonVQVAEEncoderConvDownsample(nn.Module): + + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, hidden_states: torch.Tensor): + # no asymmetric padding in torch conv, must do it ourselves + hidden_states = F.pad(hidden_states, + pad=(0, 1, 0, 1), + mode="constant", + value=0) + hidden_states = self.conv(hidden_states) + return hidden_states + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa +class ChameleonVQVAEEncoderResnetBlock(nn.Module): + + def __init__( + self, + config: ChameleonVQVAEConfig, + in_channels: int, + out_channels=None, + conv_shortcut=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None \ + else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = torch.nn.GroupNorm(num_groups=32, + num_channels=in_channels, + eps=1e-6, + affine=True) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + self.norm2 = torch.nn.GroupNorm(num_groups=32, + num_channels=out_channels, + eps=1e-6, + affine=True) + self.dropout = torch.nn.Dropout(config.dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, hidden_states: torch.Tensor): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + residual = self.conv_shortcut(residual) + else: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa +class ChameleonVQVAEEncoderAttnBlock(nn.Module): + + def 
__init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=32, + num_channels=in_channels, + eps=1e-6, + affine=True) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, hidden_states: torch.Tensor): + residual = hidden_states + hidden_states = self.norm(hidden_states) + query_states = self.q(hidden_states) + key_states = self.k(hidden_states) + value_states = self.v(hidden_states) + + # compute attention + batch_size, channels, height, width = query_states.shape + query_states = query_states.reshape(batch_size, channels, + height * width).permute(0, 2, 1) + key_states = key_states.reshape(batch_size, channels, height * width) + attn_weights = torch.bmm(query_states, key_states) + attn_weights = attn_weights * (int(channels)**(-0.5)) + attn_weights = F.softmax(attn_weights, dim=2) + + # attend to values + value_states = value_states.reshape(batch_size, channels, + height * width) + attn_weights = attn_weights.permute(0, 2, 1) + attn_output = torch.bmm(value_states, + attn_weights).reshape(batch_size, channels, + height, width) + + attn_output = self.proj_out(attn_output) + return residual + attn_output + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa +class ChameleonVQVAEEncoder(nn.Module): + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + resolution = config.resolution + in_channels = config.in_channels + double_latent = config.double_latent + latent_channels = config.latent_channels + channel_multiplier = config.channel_multiplier + + self.conv_in = torch.nn.Conv2d(in_channels, + base_channels, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_channel_multiplier = (1, ) + tuple(channel_multiplier) + self.in_channel_multiplier = in_channel_multiplier + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = base_channels * in_channel_multiplier[i_level] + block_out = base_channels * channel_multiplier[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_out, + )) + block_in = block_out + if (config.attn_resolutions is not None + and curr_res in config.attn_resolutions + and config.attn_type == "vanilla"): + attn.append(ChameleonVQVAEEncoderAttnBlock(block_in)) + + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + self.mid = nn.Module() + self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + ) + self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock( + block_in) if config.attn_type == "vanilla" else nn.Identity() + self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock( + config=config, + 
in_channels=block_in, + out_channels=block_in, + ) + + self.norm_out = torch.nn.GroupNorm(num_groups=32, + num_channels=block_in, + eps=1e-6, + affine=True) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * latent_channels if double_latent else latent_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, pixel_values: torch.Tensor): + # downsampling + hidden_states = [self.conv_in(pixel_values)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + hidden_state = self.down[i_level].block[i_block]( + hidden_states[-1], ) + if len(self.down[i_level].attn) > 0: + hidden_state = self.down[i_level].attn[i_block]( + hidden_state) + hidden_states.append(hidden_state) + if i_level != self.num_resolutions - 1: + hidden_states.append(self.down[i_level].downsample( + hidden_states[-1])) + + # middle + last_hidden_state = hidden_states[-1] + last_hidden_state = self.mid.block_1(last_hidden_state) + last_hidden_state = self.mid.attn_1(last_hidden_state) + last_hidden_state = self.mid.block_2(last_hidden_state) + + # end + last_hidden_state = self.norm_out(last_hidden_state) + last_hidden_state *= torch.sigmoid(last_hidden_state) + last_hidden_state = self.conv_out(last_hidden_state) + return last_hidden_state + + +# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa +class ChameleonVQVAE(nn.Module): + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__() + self.encoder = ChameleonVQVAEEncoder(config) + self.quantize = ChameleonVQVAEVectorQuantizer(config) + self.quant_conv = torch.nn.Conv2d(config.latent_channels, + config.embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, + config.latent_channels, 1) + self.eval() # Chameleon's VQ model is frozen + + def encode( + self, pixel_values: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states = self.encoder(pixel_values) + hidden_states = self.quant_conv(hidden_states) + quant, emb_loss, indices = self.quantize(hidden_states) + return quant, emb_loss, indices + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa class ChameleonImageVocabularyMapping: """ A class for mapping discrete image tokens from VQGAN to BPE tokens. """ - def __init__(self, vocab_map): + def __init__(self, vocab_map: Dict[str, int]): self.vocab_map = vocab_map self.image_token_id = vocab_map.get("") @@ -401,13 +830,23 @@ def __init__( for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - # TODO: Support image input - # self.vqmodel = ChameleonVQModel(config.vq_config) + self.vqmodel = ChameleonVQVAE(config.vq_config) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) + def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + Tokenizes images into discrete tokens with VQGAN module. Converts + obtained image tokens into BPE tokens and wraps with "boi" and "eoi" + special tokens. 
+ """ + batch_size = pixel_values.shape[0] + _, _, image_toks = self.vqmodel.encode(pixel_values) + bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks) + bpe_toks = bpe_toks.view(batch_size, -1) + return bpe_toks + def forward( self, input_ids: Optional[torch.Tensor], @@ -434,16 +873,22 @@ def forward( return hidden_states -class ChameleonForConditionalGeneration(nn.Module): +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon) +@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon) +class ChameleonForConditionalGeneration(nn.Module, SupportsVision): def __init__( self, config: ChameleonConfig, + multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config + self.multimodal_config = multimodal_config self.model = ChameleonModel(config, cache_config, quant_config) self.unpadded_vocab_size = config.vocab_size self.lm_head = ParallelLMHead( @@ -458,6 +903,36 @@ def __init__( config.vocab_size, logit_scale) self.sampler = Sampler() + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + + expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT, + CHAMELEON_CROP_SIZE_WIDTH) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + + if pixel_values is None: + return None + + if not isinstance(pixel_values, torch.Tensor): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + return ChameleonImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(pixel_values), + ) + def forward( self, input_ids: torch.Tensor, @@ -468,10 +943,17 @@ def forward( **kwargs, ) -> torch.Tensor: - # TODO (ywang96): Support image input - # image_tokens = self.process_image_input(**kwargs) - # image_mask = input_ids == self.vocabulary_mapping.image_token_id - # input_ids[special_image_mask] = image_tokens.flatten().to(input_ids.dtype) #noqa + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + assert self.model.vqmodel is not None + image_tokens = self.model.get_image_tokens(image_input["data"].to( + self.config.torch_dtype)) + image_token_id = self.model.vocabulary_mapping.image_token_id + special_image_mask = input_ids == image_token_id + image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) + input_ids = input_ids.masked_scatter(special_image_mask, + image_tokens) hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata) @@ -511,43 +993,52 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if "rotary_emb.inv_freq" in name: continue - # Skip loading vqgan - # TODO: add support for the vision model - if "vqmodel" in name: - continue if ("rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break + use_default_weight_loading = False + if "vqmodel" in name: + if self.model.vqmodel is not None: + # We only do sharding for language model and + # not vqvae for now. + use_default_weight_loading = True else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - if name.endswith("kv_scale"): - remapped_kv_scale_name = name.replace( - ".kv_scale", ".attn.kv_scale") - if remapped_kv_scale_name not in params_dict: - print_warning_once( - f"Found kv scale in the checkpoint (e.g. {name}), " - "but not found the expected name in the model " - f"(e.g. {remapped_kv_scale_name}). kv-scale is " - "not loaded.") + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue - else: - name = remapped_kv_scale_name + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + "Found kv scale in the checkpoint (e.g. " + f"{name}), but not found the expected name in " + f"the model (e.g. {remapped_kv_scale_name}). 
" + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + if use_default_weight_loading and name in params_dict: param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index d177fb2e49178..080c0777ebdcc 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,4 +1,5 @@ -from vllm.transformers_utils.configs.chameleon import ChameleonConfig +from vllm.transformers_utils.configs.chameleon import (ChameleonConfig, + ChameleonVQVAEConfig) from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and @@ -12,6 +13,7 @@ __all__ = [ "ChameleonConfig", + "ChameleonVQVAEConfig", "ChatGLMConfig", "DbrxConfig", "MPTConfig", diff --git a/vllm/transformers_utils/configs/chameleon.py b/vllm/transformers_utils/configs/chameleon.py index 73f0e0c33989d..c1ac1182e14c4 100644 --- a/vllm/transformers_utils/configs/chameleon.py +++ b/vllm/transformers_utils/configs/chameleon.py @@ -1,3 +1,5 @@ +from typing import List, Optional + from transformers import PretrainedConfig @@ -5,9 +7,7 @@ # transformers once the new release with Chameleon support # is available. class ChameleonConfig(PretrainedConfig): - model_type = "chameleon" - is_composition = True keys_to_ignore_at_inference = ["past_key_values"] def __init__( @@ -31,7 +31,7 @@ def __init__( rope_scaling=None, attention_bias=False, attention_dropout=0.0, - qk_layernorm=False, + model_parallel_size=1, swin_norm=False, vq_config=None, vocabulary_map=None, @@ -46,10 +46,6 @@ def __init__( self.num_attention_heads = num_attention_heads self.mlp_bias = mlp_bias - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range @@ -60,10 +56,14 @@ def __init__( self._rope_scaling_validation() self.attention_bias = attention_bias self.attention_dropout = attention_dropout - self.qk_layernorm = qk_layernorm + self.model_parallel_size = model_parallel_size self.swin_norm = swin_norm - # vq config is currently ignored - # self.vq_config = ChameleonVQConfig(**vq_config) + + if vq_config is None: + vq_config = {} + + self.vq_config = ChameleonVQVAEConfig(**vq_config) + self.vocabulary_map = vocabulary_map super().__init__( @@ -99,3 +99,40 @@ def _rope_scaling_validation(self): raise ValueError( "`rope_scaling`'s factor field must be a float > 1, " f"got {rope_scaling_factor}") + + +class ChameleonVQVAEConfig(PretrainedConfig): + + model_type = "chameleon_vqgan" + + def __init__( + self, + embed_dim: int = 256, + num_embeddings: int = 8192, + double_latent: bool = False, + latent_channels: int = 256, + resolution: int = 512, + in_channels: int = 3, + base_channels: int = 128, + channel_multiplier: List[int] = [1, 1, 2, 2, 4], #noqa + num_res_blocks: int = 2, + attn_resolutions: Optional[List[int]] = None, + dropout: float = 0.0, + attn_type: str = "vanilla", + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.num_embeddings = num_embeddings + self.double_latent = 
double_latent + self.latent_channels = latent_channels + self.resolution = resolution + self.in_channels = in_channels + self.base_channels = base_channels + self.channel_multiplier = channel_multiplier + self.num_res_blocks = num_res_blocks + self.attn_resolutions = attn_resolutions + self.dropout = dropout + self.attn_type = attn_type + self.initializer_range = initializer_range
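
Usage sketch (not part of the diff above): a minimal offline-inference example of how the Chameleon support added in this PR is expected to be exercised. This is a sketch under a few assumptions: it uses vLLM's dict-style prompt with `multi_modal_data`, the image path is a placeholder, and the prompt template simply mirrors tests/models/test_chameleon.py. The registered input processor expands the single <image> placeholder into CHAMELEON_IMAGE_SEQ_LENGTH (1024) image tokens plus start/end tokens for the 512x512 crop, so max_model_len=4096 leaves ample room for the text portion of the prompt.

# Hypothetical usage sketch for the Chameleon support in this PR; not part of
# the diff. Assumes vLLM's dict-style multimodal prompt format and a
# placeholder image path.
from PIL import Image

from vllm import LLM, SamplingParams

# Mirrors the settings used by the new test above.
llm = LLM(model="facebook/chameleon-7b",
          max_model_len=4096,
          enforce_eager=True)

image = Image.open("/path/to/image.jpg").convert("RGB")  # placeholder path
prompt = "USER: <image>\nWhat's the content of the image?\nASSISTANT:"

outputs = llm.generate(
    {
        "prompt": prompt,
        # The registered input mapper/processor consumes the PIL image and
        # expands the single <image> placeholder to 1024 image tokens.
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)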