Add support for AWQ-quantized Idefics2 (huggingface#2233)
danieldk authored and yuanwu2017 committed Sep 25, 2024
1 parent 8a223eb commit e955f7b
Showing 2 changed files with 39 additions and 11 deletions.
35 changes: 24 additions & 11 deletions server/text_generation_server/models/custom_modeling/idefics2.py
@@ -34,6 +34,7 @@
     TensorParallelEmbedding,
     TensorParallelRowLinear,
 )
+from text_generation_server.utils.weights import DefaultWeightsLoader


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -682,7 +683,7 @@ def forward(self, image_hidden_states, attention_mask):
 class Idefics2ForConditionalGeneration(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        config.vision_config.quantize = config.quantize
+        config.vision_config.quantize = None
         config.vision_config.speculator = config.speculator
         config.text_config.quantize = config.quantize
         config.text_config.speculator = config.speculator
@@ -695,16 +696,28 @@ def __init__(self, prefix, config, weights):
             name="text_model",
         )
         self.dtype = weights.dtype
-        self.vision_model = Idefics2VisionTransformer(
-            prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
-            config=vision_config,
-            weights=weights,
-        )
-        self.connector = Idefics2Connector(
-            prefix=f"{prefix}.model.connector" if prefix else "model.connector",
-            config=config,
-            weights=weights,
-        )
+
+        # The vision and connector models are not quantized.
+        with weights.use_loader(DefaultWeightsLoader()):
+            self.vision_model = Idefics2VisionTransformer(
+                prefix=(
+                    f"{prefix}.model.vision_model" if prefix else "model.vision_model"
+                ),
+                config=vision_config,
+                weights=weights,
+            )
+
+        quantize = config.quantize
+        try:
+            config.quantize = None
+            self.connector = Idefics2Connector(
+                prefix=f"{prefix}.model.connector" if prefix else "model.connector",
+                config=config,
+                weights=weights,
+            )
+        finally:
+            config.quantize = quantize
+
         self.config = config
         self.image_seq_len = config.perceiver_config.resampler_n_latents
         self.image_token_id = config.image_token_id
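Both the vision tower and the connector are deliberately loaded without quantization: the vision model by temporarily swapping in `DefaultWeightsLoader` via `weights.use_loader(...)`, and the connector by clearing `config.quantize` around its constructor and restoring it in a `finally` block. A hypothetical context manager (not part of this commit) could express the same restore-on-exit pattern for the config flag; a minimal sketch:

from contextlib import contextmanager

@contextmanager
def unquantized_config(config):
    # Clear config.quantize for the duration of the block, then restore the
    # previous value, mirroring the try/finally in the diff above.
    previous = config.quantize
    config.quantize = None
    try:
        yield config
    finally:
        config.quantize = previous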
15 changes: 15 additions & 0 deletions server/text_generation_server/utils/weights.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from safetensors import safe_open
@@ -306,6 +307,20 @@ def get_tensor_shard(self, var, dim):
     def get_weights_row(self, prefix: str):
         return self.weights_loader.get_weights_row(self, prefix)

+    @contextmanager
+    def use_loader(self, weights_loader: WeightsLoader):
+        """
+        This method is a context manager that can be used to use `Weights` with
+        a different loader for the duration of the context.
+        """
+
+        old_loader = self.weights_loader
+        self.weights_loader = weights_loader
+        try:
+            yield
+        finally:
+            self.weights_loader = old_loader
+

 def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
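For reference, `Weights.use_loader` is consumed as a plain `with` block; a minimal usage sketch, assuming a `weights: Weights` instance and a module constructor such as the `Idefics2VisionTransformer` call in the diff above (the helper name `load_unquantized` is hypothetical):

from text_generation_server.utils.weights import DefaultWeightsLoader

def load_unquantized(module_cls, prefix, config, weights):
    # Hypothetical helper: swap in the default (unquantized) loader so this
    # submodule's tensors bypass the quantized loader; use_loader restores
    # the previous loader when the block exits.
    with weights.use_loader(DefaultWeightsLoader()):
        return module_cls(prefix=prefix, config=config, weights=weights)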
