Sn1 437 implement gemma 3 27 b it #644

Merged: 24 commits, Mar 18, 2025
3,353 changes: 1,652 additions & 1,701 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions prompting/datasets/sn13.py
@@ -38,9 +38,12 @@ def sample(self) -> ChatEntry:
             raise self.exception
         # Randomly select a sample from the dataset.
         messages = []
-        for _ in range(4):
+        for i in range(4):
             sample_idx = random.randint(0, len(self.dataset) - 1)
             if message := self.dataset[sample_idx]["text"]:
-                messages.append({"role": random.choice(["user", "assistant"]), "content": message})
+                if i % 2 == 0:
+                    messages.append({"role": "user", "content": message})
+                else:
+                    messages.append({"role": "assistant", "content": message})

         return ChatEntry(messages=messages, organic=False, source=self._url)
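The net effect of this change is that sampled conversations now alternate roles deterministically by index parity instead of drawing each role at random, which could previously yield two consecutive turns from the same role. A minimal sketch of the resulting shape, using placeholder strings instead of dataset rows:

# Illustration only (not part of the diff): the message shape produced by the new loop,
# with placeholder strings standing in for dataset text.
samples = ["first text", "second text", "third text", "fourth text"]

messages = []
for i, text in enumerate(samples):
    role = "user" if i % 2 == 0 else "assistant"
    messages.append({"role": role, "content": text})

print([m["role"] for m in messages])  # ['user', 'assistant', 'user', 'assistant']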
81 changes: 34 additions & 47 deletions prompting/llms/hf_llm.py
@@ -1,66 +1,58 @@
 import random
+from abc import abstractmethod

 import numpy as np
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, pipeline
+from loguru import logger
+
+try:
+    import torch
+except ImportError:
+    logger.warning("torch is not installed. This module will not be available.")


 class ReproducibleHF:
-    def __init__(
-        self,
-        model_id: str = "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
-        device: str = "cuda:0",
-        sampling_params: dict[str, str | float | int | bool] | None = None,
-    ):
-        """Deterministic HuggingFace model."""
+    def __init__(self, model_id: str, device: str, sampling_params: dict[str, str | float | int | bool] | None = None):
         self.model_id = model_id
         self._device = device
-        self.sampling_params = {} if sampling_params is None else sampling_params
-        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            torch_dtype=torch.float16,
-            low_cpu_mem_usage=True,
-            device_map=self._device,
-        )
+        self.sampling_params = sampling_params if sampling_params else {}

-        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-        self.valid_generation_params = set(
-            AutoModelForCausalLM.from_pretrained(model_id).generation_config.to_dict().keys()
-        )
-        self.llm = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)
+    @staticmethod
+    @abstractmethod
+    def format_messages(messages: list[str] | list[dict[str, str]]) -> list[dict[str, str | list[dict[str, str]]]]:
+        raise NotImplementedError("This method must be implemented by the subclass")

-    @torch.inference_mode()
     def generate(
         self,
         messages: list[str] | list[dict[str, str]],
         sampling_params: dict[str, str | float | int | bool] | None = None,
         seed: int | None = None,
     ) -> str:
         """Generate text with optimized performance."""
-        self.set_random_seeds(seed)
+        with torch.inference_mode():
+            self.set_random_seeds(seed)

-        inputs = self.tokenizer.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=True,
-            return_tensors="pt",
-            return_dict=True,
-        ).to(self._device)
+            inputs = self.tokenizer.apply_chat_template(
+                self.message_formater(messages),
+                tokenize=True,
+                add_generation_prompt=True,
+                return_tensors="pt",
+                return_dict=True,
+            ).to(self._device)

-        params = sampling_params if sampling_params else self.sampling_params
-        filtered_params = {k: v for k, v in params.items() if k in self.valid_generation_params}
+            params = sampling_params if sampling_params else self.sampling_params
+            filtered_params = {k: v for k, v in params.items() if k in self.valid_generation_params}

-        outputs = self.model.generate(
-            **inputs,
-            **filtered_params,
-            eos_token_id=self.tokenizer.eos_token_id,
-        )
+            outputs = self.model.generate(
+                **inputs,
+                **filtered_params,
+            )

-        results = self.tokenizer.batch_decode(
-            outputs[:, inputs["input_ids"].shape[1] :],
-            skip_special_tokens=True,
-        )[0]
+            results = self.tokenizer.batch_decode(
+                outputs[:, inputs["input_ids"].shape[1] :],
+                skip_special_tokens=True,
+            )[0]

-        return results if len(results) > 1 else results[0]
+            return results if len(results) > 1 else results[0]

     def set_random_seeds(self, seed: int | None = 42):
         """Set random seeds for reproducibility across all relevant libraries."""
@@ -72,8 +64,3 @@ def set_random_seeds(self, seed: int | None = 42):
         torch.cuda.manual_seed_all(seed)
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.benchmark = False
-
-
-# if __name__ == "__main__":
-#     llm = ReproducibleHF(model="Qwen/Qwen2-0.5B", tensor_parallel_size=1, seed=42)
-#     llm.generate({"role": "user", "content": "Hello, world!"})
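One behavior worth noting in the refactored generate(): the sampling parameters passed per call (or stored at construction time) are filtered against the keys of the model's generation_config, so unsupported keys are dropped silently rather than raising. A standalone sketch of that filtering step, with made-up parameter names and values rather than anything taken from the PR:

# Standalone illustration of the filtering done inside ReproducibleHF.generate();
# the parameter names and values below are invented for the example.
valid_generation_params = {"temperature", "top_p", "max_new_tokens"}  # keys of model.generation_config

sampling_params = {"temperature": 0.7, "top_p": 0.95, "not_a_real_knob": 1}
filtered_params = {k: v for k, v in sampling_params.items() if k in valid_generation_params}

print(filtered_params)  # {'temperature': 0.7, 'top_p': 0.95}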
32 changes: 32 additions & 0 deletions prompting/llms/hf_text.py
@@ -0,0 +1,32 @@
from loguru import logger

try:
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
except ImportError:
    logger.warning("Transformers or torch is not installed. This module will not be available.")

from .hf_llm import ReproducibleHF


class HFTextGeneration(ReproducibleHF):
    def __init__(
        self,
        model_id: str = "meta-llama/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
        device: str = "cuda:0",
        sampling_params: dict[str, str | float | int | bool] | None = None,
    ):
        super().__init__(model_id, device, sampling_params)
        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map=self._device,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.valid_generation_params = set(self.model.generation_config.to_dict().keys())
        self.message_formater = self.format_messages

    @staticmethod
    def format_messages(messages: list[str] | list[dict[str, str]]) -> list[dict[str, str | list[dict[str, str]]]]:
        return messages
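A rough usage sketch for the new text-only wrapper, assuming a CUDA device and an installed prompting package; the small model id below is a placeholder chosen to keep the example cheap, not the class default:

# Hypothetical usage; the model id is a placeholder, not the default from the PR.
from prompting.llms.hf_text import HFTextGeneration

llm = HFTextGeneration(
    model_id="Qwen/Qwen2-0.5B-Instruct",
    device="cuda:0",
    sampling_params={"max_new_tokens": 64, "temperature": 0.7},
)
print(llm.generate([{"role": "user", "content": "Say hello in one sentence."}], seed=42))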
61 changes: 61 additions & 0 deletions prompting/llms/hf_text_image.py
@@ -0,0 +1,61 @@
from loguru import logger

try:
    import torch
    from transformers import AutoModelForImageTextToText, AutoProcessor
except ImportError:
    logger.warning("Transformers or torch is not installed. This module will not be available.")

from prompting.llms.hf_llm import ReproducibleHF


class HFTextImageToText(ReproducibleHF):
    def __init__(
        self,
        model_id: str = "google/gemma-3-27b-it",
        device: str = "cuda:0",
        sampling_params: dict[str, str | float | int | bool] | None = None,
    ):
        super().__init__(model_id, device, sampling_params)
        self.model: AutoModelForImageTextToText = AutoModelForImageTextToText.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map=self._device,
        )
        self.tokenizer = AutoProcessor.from_pretrained(model_id)
        self.valid_generation_params = set(self.model.generation_config.to_dict().keys())
        self.message_formater = HFTextImageToText.format_messages

    @staticmethod
    def format_messages(messages: list[str] | list[dict[str, str]]) -> list[dict[str, str | list[dict[str, str]]]]:
        """Format the messages for the gemma model.

        Converts message content strings to dictionaries with type and text fields.
        Example:
            Input: [{"role": "user", "content": "Hello"}]
            Output: [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]
        """
        formatted_messages = []
        # Check if the message is a list of only one element and that element is a list
        if isinstance(messages, list) and len(messages) == 1 and isinstance(messages[0], list):
            messages = messages[0]
        for message in messages:
            if isinstance(message, dict) and "content" in message:
                # If content is a string, convert it to a list with a dictionary
                if isinstance(message["content"], str):
                    formatted_message = message.copy()
                    formatted_message["content"] = [{"type": "text", "text": message["content"]}]
                    formatted_messages.append(formatted_message)
                else:
                    # If content is already in the correct format, keep it as is
                    formatted_messages.append(message)
            else:
                # Handle other message formats if needed
                formatted_messages.append(message)

        return formatted_messages


if __name__ == "__main__":
    model = HFTextImageToText(model_id="google/gemma-3-27b-it", device="cuda:0")
    print(model.generate([{"role": "user", "content": "What's ur name?"}]))
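Besides the __main__ snippet above, format_messages can be exercised on its own since it is a @staticmethod; the sketch below (assuming the package imports cleanly) also shows the unwrap of a conversation passed as a single nested list, which the docstring example does not cover:

# Calls only the static formatter; no model weights are downloaded.
from prompting.llms.hf_text_image import HFTextImageToText

wrapped = [[{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]]
print(HFTextImageToText.format_messages(wrapped))
# [{'role': 'user', 'content': [{'type': 'text', 'text': 'Hi'}]},
#  {'role': 'assistant', 'content': [{'type': 'text', 'text': 'Hello!'}]}]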