Support sequence salience on HF models with PyTorch framework.
PiperOrigin-RevId: 619316810
bdu91 authored and LIT team committed Mar 26, 2024
1 parent 99821d3 commit b9941ed
Showing 3 changed files with 299 additions and 69 deletions.
79 changes: 60 additions & 19 deletions lit_nlp/examples/lm_salience_demo.py
@@ -10,8 +10,20 @@
--port=8890 --alsologtostderr
We strongly recommend a GPU or other accelerator to run this demo, although for
testing, the smaller GPT-2 models run well on CPU. To use the TensorFlow
weights for GPT-2, set the flag values as follows:
--models=gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz
--hf_framework=tensorflow
We also support PyTorch weights for the GPT-2 model; simply set the flag values:
--models=gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2-pt.tar.gz
--hf_framework=pytorch
A few more examples of the flag setup for other supported models (GPU required):
Llama2: --hf_framework=pytorch
--models=llama2:meta-llama/Llama-2-7b-hf
Mistral: --hf_framework=pytorch
--models=mistral:mistralai/Mistral-7B-v0.1
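Putting these pieces together, a full invocation typically looks roughly like
the following (the module path is taken from the file changed in this commit;
adjust the flags for your model and environment):

python -m lit_nlp.examples.lm_salience_demo \
  --hf_framework=pytorch \
  --models=gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2-pt.tar.gz \
  --port=8890 --alsologtostderr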
By default this includes a small set of sample prompts, but you can load your
own examples using the --datasets flag or through the "Configure" menu in the
@@ -43,6 +55,13 @@
from lit_nlp.examples.models import pretrained_lms
from lit_nlp.lib import file_cache

# pytype: disable=import-error
try:
  import torch
except (ModuleNotFoundError, ImportError):
  logging.warning("PyTorch is not available.")
# pytype: enable=import-error

# NOTE: additional flags defined in server_flags.py

FLAGS = flags.FLAGS
@@ -55,8 +74,12 @@
"gemma_instruct_2b_en:gemma_instruct_2b_en",
"gpt2:https://storage.googleapis.com/what-if-tool-resources/lit-models/gpt2.tar.gz",
],
"Models to load, as <name>:<path>. Currently supports Gemma and GPT-2"
" variants.",
"Models to load, as <name>:<path>. Currently supports Gemma (Keras NLP) and"
"HuggingFace models. For HuggingFace models, GPT2, Llama, Mistral have been"
"verified to work with this demo. Thereotically, supported decoder models"
"in `transformers.AutoModelForCausalLM` should work, but adjustments might"
"be needed on their tokenizers (e.g. need to define custom pad_token when"
"eos_token is not available to use as pad_token).",
)
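To illustrate the tokenizer adjustment mentioned in the help string above, a
minimal sketch using the `transformers` API (the model name and pad-token
choice are assumptions for illustration, not part of this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical example model
if tokenizer.pad_token is None:
  # Reuse the EOS token for padding when the model defines one.
  tokenizer.pad_token = tokenizer.eos_token
  # Or register a dedicated pad token instead; the model's embedding matrix
  # would then need resizing via model.resize_token_embeddings(len(tokenizer)):
  # tokenizer.add_special_tokens({"pad_token": "[PAD]"})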

_DATASETS = flags.DEFINE_list(
@@ -76,10 +99,22 @@
),
)

_HF_FRAMEWORK = flags.DEFINE_enum(
"hf_framework",
"tensorflow",
["tensorflow", "pytorch"],
"Deep learning framework for the HuggingFace model.",
)

_PRECISION = flags.DEFINE_enum(
"precision",
"bfloat16",
["bfloat16", "float32"],
"Floating point precision for the HuggingFace (PyTorch) and Keras models,"
"only `bfloat16` and `float32` are supported for now.",
)


# TODO(lit-dev): move these layouts to a separate .py file.
# Custom frontend layout; see api/layout.py
modules = layout.LitModuleName
@@ -168,13 +203,19 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
# Set Keras backend and floating-point precision.
os.environ["KERAS_BACKEND"] = "tensorflow"
if hasattr(keras, "config") and hasattr(keras.config, "set_floatx"):
  keras.config.set_floatx(_PRECISION.value)
else:
  # TODO(b/327281789): remove once we can guarantee Keras 3.
  logging.warn(
      "keras.config.set_floatx() not available; using default precision."
  )

if _HF_FRAMEWORK.value == "pytorch":
  if _PRECISION.value == "bfloat16":
    torch.set_default_dtype(torch.bfloat16)
  else:
    torch.set_default_dtype(torch.float32)
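For reference, torch.set_default_dtype changes the dtype used for
floating-point tensors and module parameters created afterwards, which is
presumably why it is set here before the HuggingFace model is loaded. A
standalone sanity check (illustrative only, not part of this file):

import torch

torch.set_default_dtype(torch.bfloat16)
layer = torch.nn.Linear(2, 2)  # parameters pick up the new default dtype
assert layer.weight.dtype == torch.bfloat16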

plaintextPrompts = functools.partial( # pylint: disable=invalid-name
    lm_data.PlaintextSents, field_name="prompt"
)
@@ -222,17 +263,7 @@ def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
# containing 'https://'
model_name, path = model_string.split(":", 1)
logging.info("Loading model '%s' from '%s'", model_name, path)
if model_name.startswith("gpt2") or model_name in ["distilgpt2"]:
models[model_name] = pretrained_lms.HFGenerativeModel(path)
# Salience wrapper, using same underlying Keras models so as not to
# load the weights twice.
models[f"_{model_name}_salience"] = (
pretrained_lms.HFSalienceModel.from_loaded(models[model_name])
)
models[f"_{model_name}_tokenizer"] = (
pretrained_lms.HFTokenizerModel.from_loaded(models[model_name])
)
elif model_name.startswith("gemma"):
if model_name.startswith("gemma"):
  path = file_cache.cached_path(
      path,
      extract_compressed_file=path.endswith(".tar.gz"),
@@ -248,8 +279,18 @@
  # crashing? Maybe need n > 2 examples.
  models[model_name].output_embeddings = False
else:
  # Assumes a valid decoder model name supported by
  # `transformers.AutoModelForCausalLM` is provided as `path`.
  models[model_name] = pretrained_lms.HFGenerativeModel(
      path, framework=_HF_FRAMEWORK.value, max_new_tokens=512
  )
  # Salience wrapper, using the same underlying HF model so as not to
  # load the weights twice.
  models[f"_{model_name}_salience"] = (
      pretrained_lms.HFSalienceModel.from_loaded(models[model_name])
  )
  models[f"_{model_name}_tokenizer"] = (
      pretrained_lms.HFTokenizerModel.from_loaded(models[model_name])
  )
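For context, the `models` and `datasets` dicts assembled here are what the
demo ultimately hands to the LIT server. A simplified sketch of that final
step, assuming `server_flags` is imported as in other LIT demos (the actual
call in this file may also pass custom layouts):

lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
lit_demo.serve()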

for name in datasets:
