enable hf model or pipeline config in hf_args
* supports all generic `pipeline` args at all times
* adds `do_sample` when the `Callable` accepts a `model` parameter
* adds `low_cpu_mem_usage` and all generic `pipeline` args for `Callable`s without a `model` parameter
* consolidates optimal device selection and sets the device when none is provided by config

A configuration sketch illustrating the new `hf_args` handling follows the changed-files summary below.

Signed-off-by: Jeffrey Martin <[email protected]>
jmartin-tech committed Jun 17, 2024
1 parent fd06da1 commit 62f91a9
Showing 2 changed files with 139 additions and 90 deletions.
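
For illustration, a minimal sketch (assumed usage, not part of the commit) of supplying `hf_args` through the generator config rather than constructor arguments. The structure mirrors the updated test in tests/generators/test_huggingface.py below; the specific keys shown are taken from the defaults in this diff and are not exhaustive.

import garak.generators.huggingface
from garak._config import GarakSubConfig

gen_config = {
    "huggingface": {
        "Pipeline": {
            "hf_args": {
                "device": "cpu",           # omit to let _select_hf_device pick cuda/mps/cpu
                "torch_dtype": "float16",  # resolved to torch.float16 by _gather_hf_params
                "do_sample": True,
            },
        }
    }
}
config_root = GarakSubConfig()
setattr(config_root, "generators", gen_config)

g = garak.generators.huggingface.Pipeline("gpt2", config_root=config_root)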
213 changes: 124 additions & 89 deletions garak/generators/huggingface.py
@@ -16,8 +16,9 @@

import inspect
import logging
import os
import re
from typing import List, Union
from typing import Callable, List, Union
import warnings

import backoff
@@ -33,15 +34,15 @@
models_to_deprefix = ["gpt2"]


class HFRateLimitException(Exception):
class HFRateLimitException(GarakException):
pass


class HFLoadingException(Exception):
class HFLoadingException(GarakException):
pass


class HFInternalServerError(Exception):
class HFInternalServerError(GarakException):
pass


@@ -51,30 +52,106 @@ def _set_hf_context_len(self, config):
if isinstance(config.n_ctx, int):
self.context_len = config.n_ctx

def _gather_hf_params(self, hf_constructor: Callable):
# this may be a bit too naive, as it will pass any parameter valid for the pipeline signature
# this falls over when passed `from_pretrained` methods, since those callables do not declare model params explicitly
params = self.hf_args
if params["device"] is None:
params["device"] = self.device

args = {}

parameters = inspect.signature(hf_constructor).parameters

if "model" in parameters:
args["model"] = self.name
# expand the accepted parameters to include `do_sample`
parameters = {"do_sample": True} | parameters
else:
# callable is for a Pretrained class; also map standard `pipeline` params
from transformers import pipeline

parameters = (
{"low_cpu_mem_usage": True}
| parameters
| inspect.signature(pipeline).parameters
)

for k in parameters:
if k == "model":
continue # special case: `model` comes from `name` in the generator
if k in params:
val = params[k]
if k == "torch_dtype" and hasattr(torch, val):
args[k] = getattr(
torch, val
) # some model-type-specific classes do not yet support direct string representation
continue
if (
k == "device"
and "device_map" in parameters
and "device_map" in params
):
# per transformers convention, prefer `device_map` over `device` when both are provided
continue
args[k] = params[k]

return args

def _select_hf_device(self):
"""Determine the most efficient device for tensor load, hold any existing `device` already selected"""
import torch.cuda

selected_device = None
if self.hf_args["device"] is not None:
if isinstance(self.hf_args["device"], int):
# this assumes that an integer-only device selection means `cuda`
selected_device = torch.device("cuda:" + str(self.hf_args["device"]))
else:
selected_device = torch.device(self.hf_args["device"])

if selected_device is None:
selected_device = torch.device(
"cuda"
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)

if isinstance(selected_device, torch.device) and selected_device.type == "mps":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
logging.debug("Enabled MPS fallback environment variable")

logging.debug(
"Using %s, based on torch environment evaluation", selected_device
)
return selected_device
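
# Usage sketch (illustrative, not part of this diff): generators mixing in HFCompatible
# resolve a device first, then filter hf_args against the target constructor's signature:
#   self.device = self._select_hf_device()
#   kwargs = self._gather_hf_params(hf_constructor=pipeline)
#   self.generator = pipeline("text-generation", **kwargs)
# as the Pipeline generator below does.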


class Pipeline(Generator, HFCompatible):
"""Get text generations from a locally-run Hugging Face pipeline"""

DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
"generations": 10,
"hf_args": {
"torch_dtype": "float16",
"do_sample": True,
"device": None,
},
}
generator_family_name = "Hugging Face 🤗 pipeline"
supports_multiple_generations = True
parallel_capable = False

def __init__(
self, name="", do_sample=True, generations=10, device=0, config_root=_config
):
def __init__(self, name="", config_root=_config):
self.name = name
self.generations = generations
self.do_sample = do_sample
self.device = device

super().__init__(
self.name, generations=self.generations, config_root=config_root
)
super().__init__(self.name, config_root=config_root)

import torch.multiprocessing as mp

mp.set_start_method("spawn", force=True)

self.device = self._select_hf_device()
self._load_client()

def _load_client(self):
@@ -86,18 +163,7 @@ def _load_client(self):
if _config.run.seed is not None:
set_seed(_config.run.seed)

import torch.cuda

# consider how this could be abstracted well
self.device = (
"cuda:" + str(self.device)
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)

logging.debug("Using %s, based on torch environment evaluation", self.device)

pipline_kwargs = self._gather_pipeline_params(pipeline=pipeline)
pipline_kwargs = self._gather_hf_params(hf_constructor=pipeline)
self.generator = pipeline("text-generation", **pipline_kwargs)
if not hasattr(self, "deprefix_prompt"):
self.deprefix_prompt = self.name in models_to_deprefix
Expand All @@ -110,17 +176,6 @@ def _load_client(self):
def _clear_client(self):
self.generator = None

def _gather_pipeline_params(self, pipeline):
# this may be a bit too naive as it will pass any parameter valid for the pipeline signature
args = {}
for k in inspect.signature(pipeline).parameters:
if k == "model":
# special case of known mapping as `model` may be reserved for the class
args[k] = self.name
if hasattr(self, k):
args[k] = getattr(self, k)
return args

def _call_model(
self, prompt: str, generations_this_call: int = 1
) -> List[Union[str, None]]:
@@ -191,7 +246,7 @@ def _load_client(self):
if "use_fp8" in _config.plugins.generators.OptimumPipeline:
self.use_fp8 = True

pipline_kwargs = self._gather_pipeline_params(pipeline=pipeline)
pipline_kwargs = self._gather_hf_params(hf_constructor=pipeline)
self.generator = pipeline("text-generation", **pipline_kwargs)
if not hasattr(self, "deprefix_prompt"):
self.deprefix_prompt = self.name in models_to_deprefix
@@ -219,18 +274,9 @@ def _load_client(self):

import torch.cuda

# consider how this could be abstracted well
self.device = (
"cuda:" + str(self.device)
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)

logging.debug("Using %s, based on torch environment evaluation", self.device)

# Note that with pipeline, in order to access the tokenizer, model, or device, you must get the attribute
# directly from self.generator instead of from the ConversationalPipeline object itself.
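# e.g. self.generator.tokenizer / self.generator.model / self.generator.device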
pipline_kwargs = self._gather_pipeline_params(pipeline=pipeline)
pipline_kwargs = self._gather_hf_params(hf_constructor=pipeline)
self.generator = pipeline("conversational", **pipline_kwargs)
self.conversation = Conversation()
if not hasattr(self, "deprefix_prompt"):
@@ -278,7 +324,7 @@ def _call_model(
return [re.sub("^" + re.escape(prompt), "", _o) for _o in outputs]


class InferenceAPI(Generator, HFCompatible):
class InferenceAPI(Generator):
"""Get text generations from Hugging Face Inference API"""

generator_family_name = "Hugging Face 🤗 Inference API"
@@ -407,7 +453,7 @@ def _pre_generate_hook(self):
self.wait_for_model = False


class InferenceEndpoint(InferenceAPI, HFCompatible):
class InferenceEndpoint(InferenceAPI):
"""Interface for Hugging Face private endpoints
Pass the model URL as the name, e.g. https://xxx.aws.endpoints.huggingface.cloud
"""
@@ -479,35 +525,22 @@ def _load_client(self):
if _config.run.seed is not None:
transformers.set_seed(_config.run.seed)

import torch.cuda

# consider how this could be abstracted well
self.init_device = (
"cuda:" + str(self.device)
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)

if not torch.cuda.is_available():
logging.debug("Using CPU, torch.cuda.is_available() returned False")
self.device = -1
self.init_device = "cpu"

trust_remote_code = self.name.startswith("mosaicml/mpt-")

model_kwargs = self._gather_hf_params(
hf_constructor=transformers.AutoConfig.from_pretrained
) # will defer to `device_map`; if device_map was `auto`, the result may not match self.device

self.config = transformers.AutoConfig.from_pretrained(
self.name, trust_remote_code=trust_remote_code
)
self.config.init_device = (
self.init_device # or "cuda:0" For fast initialization directly on GPU!
self.name, trust_remote_code=trust_remote_code, **model_kwargs
)

self._set_hf_context_len(self.config)
self.config.init_device = self.device # determined by Pipeline `__init__`

self.model = transformers.AutoModelForCausalLM.from_pretrained(
self.name,
config=self.config,
).to(self.init_device)
self.name, config=self.config
).to(self.device)

if not hasattr(self, "deprefix_prompt"):
self.deprefix_prompt = self.name in models_to_deprefix
@@ -537,7 +570,7 @@ def _call_model(
) -> List[Union[str, None]]:
self._load_client()
self.generation_config.max_new_tokens = self.max_tokens
self.generation_config.do_sample = self.do_sample
self.generation_config.do_sample = self.hf_args["do_sample"]
self.generation_config.num_return_sequences = generations_this_call
if self.temperature is not None:
self.generation_config.temperature = self.temperature
@@ -550,7 +583,7 @@ def _call_model(
with torch.no_grad():
inputs = self.tokenizer(
prompt, truncation=True, return_tensors="pt"
).to(self.init_device)
).to(self.device)

try:
outputs = self.model.generate(
@@ -574,21 +607,23 @@ def _call_model(
return [re.sub("^" + re.escape(prompt), "", i) for i in text_output]


class LLaVA(Generator):
class LLaVA(Generator, HFCompatible):
"""Get LLaVA ([ text + image ] -> text) generations"""

DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
"max_tokens": 4000,
# "exist_tokens + max_new_tokens < 4K is the golden rule."
# https://github.com/haotian-liu/LLaVA/issues/1095#:~:text=Conceptually%2C%20as%20long%20as%20the%20total%20tokens%20are%20within%204K%2C%20it%20would%20be%20fine%2C%20so%20exist_tokens%20%2B%20max_new_tokens%20%3C%204K%20is%20the%20golden%20rule.
"max_tokens": 4000,
# consider shifting below to kwargs or llava_kwargs that is a dict to allow more customization
"torch_dtype": torch.float16,
"low_cpu_mem_usage": True,
"device_map": "cuda:0",
"hf_args": {
"torch_dtype": "float16",
"low_cpu_mem_usage": True,
"device_map": "auto",
},
}

# rewrite modality setting
modality = {"in": {"text", "image"}, "out": {"text"}}
parallel_capable = False

# Support Image-Text-to-Text models
# https://huggingface.co/llava-hf#:~:text=Llava-,Models,-9
Expand All @@ -603,20 +638,20 @@ def __init__(self, name="", generations=10, config_root=_config):
super().__init__(name, generations=generations, config_root=config_root)
if self.name not in self.supported_models:
raise ModelNameMissingError(
f"Invalid modal name {self.name}, current support: {self.supported_models}."
f"Invalid model name {self.name}, current support: {self.supported_models}."
)

self.device = self._select_hf_device()
model_kwargs = self._gather_hf_params(
hf_constructor=LlavaNextForConditionalGeneration.from_pretrained
) # will defer to `device_map`; if device_map was `auto`, the result may not match self.device

self.processor = LlavaNextProcessor.from_pretrained(self.name)
self.model = LlavaNextForConditionalGeneration.from_pretrained(
self.name,
torch_dtype=self.torch_dtype,
low_cpu_mem_usage=self.low_cpu_mem_usage,
self.name, **model_kwargs
)
if torch.cuda.is_available():
self.model.to(self.device_map)
else:
raise RuntimeError(
"CUDA is not supported on this device. Please make sure CUDA is installed and configured properly."
)

self.model.to(self.device)

def generate(
self, prompt: str, generations_this_call: int = 1
@@ -630,7 +665,7 @@ def generate(
raise Exception(e)

inputs = self.processor(text_prompt, image_prompt, return_tensors="pt").to(
self.device_map
self.device
)
exist_token_number: int = inputs.data["input_ids"].shape[1]
output = self.model.generate(
16 changes: 15 additions & 1 deletion tests/generators/test_huggingface.py
@@ -1,11 +1,25 @@
import transformers
import garak.generators.huggingface
from garak._config import GarakSubConfig

DEFAULT_GENERATIONS_QTY = 10


def test_pipeline():
g = garak.generators.huggingface.Pipeline("gpt2")
gen_config = {
"huggingface": {
"Pipeline": {
"name": "gpt2",
"hf_args": {
"device": "cpu",
},
}
}
}
config_root = GarakSubConfig()
setattr(config_root, "generators", gen_config)

g = garak.generators.huggingface.Pipeline("gpt2", config_root=config_root)
assert g.name == "gpt2"
assert g.generations == DEFAULT_GENERATIONS_QTY
assert isinstance(g.generator, transformers.pipelines.text_generation.Pipeline)
