add torch mps support & enable passing pipeline params
* detect cuda vs mps vs cpu in a common way
* guard import of OptimumPipeline

Signed-off-by: Jeffrey Martin <[email protected]>
jmartin-tech committed Jun 17, 2024
1 parent 11ebb73 commit fd06da1
Showing 1 changed file with 58 additions and 33 deletions.
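
The two changes called out in the commit message above boil down to a common device-resolution step (cuda vs mps vs cpu) and a guarded import of the optional optimum.nvidia dependency. The sketch below is illustrative only, not garak code: resolve_device and load_optimum_pipeline are hypothetical names, and it assumes a torch build recent enough to expose torch.backends.mps.

# Illustrative sketch only; these helper names are not part of garak.
import logging


def resolve_device(device_index: int) -> str:
    """Pick cuda, mps, or cpu in a common way, mirroring the diff below."""
    import torch

    if torch.cuda.is_available():
        return f"cuda:{device_index}"
    if torch.backends.mps.is_available():  # assumes torch >= 1.12
        return "mps"
    return "cpu"


def load_optimum_pipeline():
    """Guard the optional optimum.nvidia import so a missing extra fails with a clear error."""
    try:
        from optimum.nvidia.pipelines import pipeline  # optional dependency
    except Exception as e:
        logging.exception(e)
        raise RuntimeError("Missing required dependencies for OptimumPipeline") from e
    return pipeline
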
garak/generators/huggingface.py: 58 additions & 33 deletions
@@ -14,6 +14,7 @@
https://huggingface.co/docs/api-inference/quicktour
"""

import inspect
import logging
import re
from typing import List, Union
@@ -25,7 +26,7 @@
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

from garak import _config
from garak.exception import ModelNameMissingError
from garak.exception import ModelNameMissingError, GarakException
from garak.generators.base import Generator


@@ -70,6 +71,10 @@ def __init__(
self.name, generations=self.generations, config_root=config_root
)

import torch.multiprocessing as mp

mp.set_start_method("spawn", force=True)

self._load_client()

def _load_client(self):
@@ -83,16 +88,17 @@ def _load_client(self):

import torch.cuda

if not torch.cuda.is_available():
logging.debug("Using CPU, torch.cuda.is_available() returned False")
self.device = -1

self.generator = pipeline(
"text-generation",
model=self.name,
do_sample=self.do_sample,
device=self.device,
# consider how this could be abstracted well
self.device = (
"cuda:" + str(self.device)
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)

logging.debug("Using %s, based on torch environment evaluation", self.device)

pipeline_kwargs = self._gather_pipeline_params(pipeline=pipeline)
self.generator = pipeline("text-generation", **pipeline_kwargs)
if not hasattr(self, "deprefix_prompt"):
self.deprefix_prompt = self.name in models_to_deprefix
if _config.loaded:
@@ -104,6 +110,17 @@ def _load_client(self):
def _clear_client(self):
self.generator = None

def _gather_pipeline_params(self, pipeline):
# this may be a bit too naive as it will pass any parameter valid for the pipeline signature
args = {}
for k in inspect.signature(pipeline).parameters:
if k == "model":
# special case of known mapping as `model` may be reserved for the class
args[k] = self.name
if hasattr(self, k):
args[k] = getattr(self, k)
return args

def _call_model(
self, prompt: str, generations_this_call: int = 1
) -> List[Union[str, None]]:
@@ -150,8 +167,14 @@ def _load_client(self):
if hasattr(self, "generator") and self.generator is not None:
return

from optimum.nvidia.pipelines import pipeline
from transformers import set_seed
try:
from optimum.nvidia.pipelines import pipeline
from transformers import set_seed
except Exception as e:
logging.exception(e)
raise GarakException(
f"Missing required dependencies for {self.__class__.__name__}"
)

if _config.run.seed is not None:
set_seed(_config.run.seed)
@@ -161,20 +184,15 @@ def _load_client(self):
if not torch.cuda.is_available():
message = "OptimumPipeline needs CUDA, but torch.cuda.is_available() returned False; quitting"
logging.critical(message)
raise ValueError(message)
raise GarakException(message)

use_fp8 = False
self.use_fp8 = False
if _config.loaded:
if "use_fp8" in _config.plugins.generators.OptimumPipeline:
use_fp8 = True

self.generator = pipeline(
"text-generation",
model=self.name,
do_sample=self.do_sample,
device=self.device,
use_fp8=use_fp8,
)
self.use_fp8 = True

pipeline_kwargs = self._gather_pipeline_params(pipeline=pipeline)
self.generator = pipeline("text-generation", **pipeline_kwargs)
if not hasattr(self, "deprefix_prompt"):
self.deprefix_prompt = self.name in models_to_deprefix
if _config.loaded:
@@ -201,18 +219,19 @@ def _load_client(self):

import torch.cuda

if not torch.cuda.is_available():
logging.debug("Using CPU, torch.cuda.is_available() returned False")
self.device = -1
# consider how this could be abstracted well
self.device = (
"cuda:" + str(self.device)
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)

logging.debug("Using %s, based on torch environment evaluation", self.device)

# Note that with pipeline, in order to access the tokenizer, model, or device, you must get the attribute
# directly from self.generator instead of from the ConversationalPipeline object itself.
self.generator = pipeline(
"conversational",
model=self.name,
do_sample=self.do_sample,
device=self.device,
)
pipeline_kwargs = self._gather_pipeline_params(pipeline=pipeline)
self.generator = pipeline("conversational", **pipeline_kwargs)
self.conversation = Conversation()
if not hasattr(self, "deprefix_prompt"):
self.deprefix_prompt = self.name in models_to_deprefix
@@ -460,9 +479,15 @@ def _load_client(self):
if _config.run.seed is not None:
transformers.set_seed(_config.run.seed)

self.init_device = "cuda:" + str(self.device)
import torch.cuda

# consider how this could be abstracted well
self.init_device = (
"cuda:" + str(self.device)
if torch.cuda.is_available()
else "mps" if torch.backends.mps.is_available() else "cpu"
)

if not torch.cuda.is_available():
logging.debug("Using CPU, torch.cuda.is_available() returned False")
self.device = -1
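
For reference, the new _gather_pipeline_params helper forwards only those instance attributes whose names appear in the target callable's signature, mapping model to the generator's name. Below is a simplified, self-contained illustration of that behaviour; FakeGenerator and fake_pipeline are made-up stand-ins, not garak or transformers classes.

# Standalone illustration of the signature-filtering approach; nothing here is garak code.
import inspect


class FakeGenerator:
    def __init__(self, name):
        self.name = name
        self.do_sample = True
        self.device = "cpu"
        self.unrelated = "ignored"  # not in the pipeline signature, so never forwarded

    def _gather_pipeline_params(self, pipeline):
        args = {}
        for k in inspect.signature(pipeline).parameters:
            if k == "model":
                # `model` is reserved by the pipeline factory; map the generator name onto it
                args[k] = self.name
            if hasattr(self, k):
                args[k] = getattr(self, k)
        return args


def fake_pipeline(task, model=None, do_sample=False, device=None, **kwargs):
    return {"task": task, "model": model, "do_sample": do_sample, "device": device}


gen = FakeGenerator("gpt2")
print(gen._gather_pipeline_params(pipeline=fake_pipeline))
# -> {'model': 'gpt2', 'do_sample': True, 'device': 'cpu'}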
