DeepSpeed-ZeRO v DeepSpeed-Inference #4234
Comments
DeepSpeed-Inference will provide the best latency. While we do not have kernel injection support for the 70B model yet (we do for the smaller variants!), you can still split the model across several GPUs with Auto Tensor Parallelism (AutoTP).
We don't support meta tensor loading of the Llama models with kernel injection, but we can still load them with AutoTP (see the example below).
ZeRO-Inference primarily targets cases where you want to run inference with very large models on very limited GPU memory. It takes advantage of ZeRO's offloading capabilities to move most of the model weights to CPU memory (or even NVMe storage). Because there is overhead associated with offloading weights, it is typically not well suited for use cases where low-latency inference is a priority. DeepSpeed-Inference is a separate engine that introduces many optimizations for running inference. For example, we support custom kernel injection for tens of thousands of models, which can significantly improve latency and throughput. This will likely be your best bet for getting the lowest latency when doing inference, but at the cost of needing much more GPU memory.
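As a side note on the NVMe option: offloading parameters to NVMe instead of CPU is just a change to the `offload_param` section of the ZeRO config. A minimal sketch is below; the `nvme_path` value is a placeholder that should point at fast local NVMe storage:

```python
# Sketch of a ZeRO-Inference config that offloads parameters to NVMe instead of CPU.
ds_config = {
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/local_nvme",  # placeholder; use a fast local NVMe mount
            "pin_memory": True,
        },
    },
    "train_micro_batch_size_per_gpu": 1,
}
```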
Here is your icing! You will need to make some changes if you have the checkpoint saved locally (this will try to download from HF). It runs both ZeRO-Inference and DeepSpeed-Inference examples with the 70B model:

```python
# llama-70b-example.py
# Launch with `deepspeed llama-70b-example.py`
import torch
import deepspeed
import os
import time
from transformers.deepspeed import HfDeepSpeedConfig
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
model_name = "meta-llama/Llama-2-70b-hf"
hf_token = "<your hf token>"


def run_zero_inference():
    ds_config = {
        "fp16": {"enabled": True},
        "bf16": {"enabled": False},
        "zero_optimization": {
            "stage": 3,
            "offload_param": {
                "device": "cpu",
            },
        },
        "train_micro_batch_size_per_gpu": 1,
    }
    # Share the DeepSpeed config with HuggingFace so we can properly load the
    # large model with ZeRO stage 3
    hfdsc = HfDeepSpeedConfig(ds_config)

    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, token=hf_token, torch_dtype=torch.float16
    )

    # Initialize DeepSpeed
    ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
    ds_engine.module.eval()
    model = ds_engine.module

    # Run inference
    start_time = time.time()
    inputs = tokenizer.encode("DeepSpeed is", return_tensors="pt").to(
        f"cuda:{local_rank}"
    )
    outputs = model.generate(inputs, max_new_tokens=20)
    output_str = tokenizer.decode(outputs[0])
    end_time = time.time()
    print("ZeRO-inference time:", end_time - start_time)


def run_deepspeed_inference():
    # Load the model on meta tensors
    config = AutoConfig.from_pretrained(model_name, token=hf_token)
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    with deepspeed.OnDevice(dtype=torch.float16, device="meta", enabled=True):
        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

    # Define the checkpoint dict. You may need to convert *.safetensors to
    # *.bin for this to work. Make sure you get all the *.bin and *.pt files in
    # the checkpoint_files list.
    checkpoint_dir = "~/.cache/huggingface/hub/models--meta-llama--Llama-2-70b-hf/snapshots/cc8aa03a000ff08b4d5c5b39673321a2a396c396"
    checkpoint_files = [
        os.path.join(checkpoint_dir, f"model-{i:05d}-of-000015.bin")
        for i in range(1, 16)
    ]
    checkpoint_dict = {
        "type": "DS_MODEL",
        "checkpoints": checkpoint_files,
        "version": 1.0,
    }

    # Initialize DeepSpeed
    model = deepspeed.init_inference(
        model,
        replace_with_kernel_inject=False,
        mp_size=world_size,
        dtype=torch.float16,
        checkpoint=checkpoint_dict,
    )

    # Run inference
    start_time = time.time()
    inputs = tokenizer.encode("DeepSpeed is", return_tensors="pt").to(
        f"cuda:{local_rank}"
    )
    outputs = model.generate(inputs, max_new_tokens=20)
    output_str = tokenizer.decode(outputs[0])
    end_time = time.time()
    print("DeepSpeed-inference time:", end_time - start_time)


if __name__ == "__main__":
    run_zero_inference()
    run_deepspeed_inference()
```
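On the note in the script about converting `*.safetensors` shards to `*.bin`: a minimal conversion sketch could look like the following. This is an untested helper, assuming the `safetensors` package is installed and that the snapshot path is adjusted to your local cache:

```python
# convert_shards.py -- hypothetical helper: re-save safetensors shards as PyTorch
# .bin files so they can be listed in checkpoint_files above.
import glob
import os

import torch
from safetensors.torch import load_file

checkpoint_dir = os.path.expanduser(
    "~/.cache/huggingface/hub/models--meta-llama--Llama-2-70b-hf/snapshots/cc8aa03a000ff08b4d5c5b39673321a2a396c396"
)

for shard in sorted(glob.glob(os.path.join(checkpoint_dir, "*.safetensors"))):
    state_dict = load_file(shard)  # dict of tensor name -> torch.Tensor
    out_path = shard.replace(".safetensors", ".bin")
    torch.save(state_dict, out_path)  # standard PyTorch checkpoint format
    print(f"wrote {out_path}")
```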
You are a lifesaver. Thank you so much for your clear and detailed answer! It should be put in the docs; I'm probably not the only one who's confused (but maybe I am, who knows). If that's all right, could you also quickly explain what kernel injection is, please?
Hi @mrwyattii. If I use ZeRO-3 + offload to train a model, can I use this code for inference? I have run into a very strange problem: when I use my custom code for inference, the weights of some layers are always randomly initialized. [screenshot] In my custom network I have `self.bias_k = Parameter(torch.empty((1, 1, config.hidden_size)))` and a corresponding `self.bias_v`. I can train and save the model with ZeRO-3 + offload, but when I load the model for inference, `self.bias_k` and `self.bias_v` fail to pick up the saved weights. After loading finishes, the weight of `self.bias_v` looks like this (it seems both of them fail to load): [screenshot] And this is the weight of `self.bias_v` in the model file (pytorch.bin):
@iamsile would you mind opening a new issue for your unrelated problem?
You're right, we should include short tutorials like this in the docs. I'll talk with the team about organizing this.
Sure! Kernel injection refers to the use of custom kernels (essentially pieces of code that get executed on your GPU) in DeepSpeed-Inference that provide better performance than the baseline kernels provided by pytorch/transformers. For example, we replace self attention layers with DeepSpeedSelfAttention, which uses custom QKV, linear, etc. kernels to decrease latency and increase throughput for inference!
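For a concrete picture, here is a rough (untested) sketch of enabling kernel injection on a smaller Llama variant that fits in GPU memory; the 7B model name and launch setup are assumptions for illustration, not something verified in this thread:

```python
# Hypothetical sketch: DeepSpeed-Inference with kernel injection on Llama-2-7B.
# Launch with the `deepspeed` launcher, e.g. `deepspeed --num_gpus 2 this_script.py`.
import os

import torch
import deepspeed
from transformers import AutoTokenizer, AutoModelForCausalLM

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
model_name = "meta-llama/Llama-2-7b-hf"  # assumed smaller variant with kernel support

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# replace_with_kernel_inject=True swaps supported layers (e.g. self attention)
# for DeepSpeed's custom inference kernels.
model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    dtype=torch.float16,
    replace_with_kernel_inject=True,
)

inputs = tokenizer.encode("DeepSpeed is", return_tensors="pt").to(f"cuda:{local_rank}")
print(tokenizer.decode(model.generate(inputs, max_new_tokens=20)[0]))
```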
@iamsile please open a new issue and I can assist. If possible, please provide a minimal example to reproduce the problem you are facing. Thanks!
Thanks again, closing the issue!
Hi @mrwyattii, I've adapted your script to my use case but I noticed that the output always ends up degenerate. It is the
Try #4259.
It does. Thanks @ZaVang.
Hi, first of all, sorry if that information is already written somewhere. I have tried to search in the docs and in the issues but I did not find a clear answer to my simple question.
I have several nodes of 8x80GB A100 GPUs at my disposal. If my goal is to run a large inference job (no training required for now) with 70B-parameter Llama models, what is the most efficient / fastest way to run it? I make this distinction in case there is a slower method than the fastest one that saves a lot of computational power.
I am still having trouble understanding the difference between DeepSpeed-ZeRO and DeepSpeed-Inference (and DeepSpeed-Chat while we're at it), especially since ZeRO also seems to perform inference tasks.
I have been trying to use DeepSpeed-Inference but, from my understanding, before calling `deepspeed.init_inference` I have to load the whole model on GPU or CPU, and that always triggers an out-of-memory error. Since `meta` tensors are not yet supported for Llama models in the latest DeepSpeed release, I'm a bit stumped.
If someone could explain the rationale behind the different DeepSpeed variants and which would be best suited to my use case, I would be very grateful. A minimal working example would be the icing on the cake, but I'd be fine with just knowing where to look first. Surely if the Open LLM Leaderboard is full of different 70B LLMs, there must be a simple way to do what I want, but I'm not sure how everyone does it.
For reference, below is the script I'm working with at the moment:
I get