DeepSpeed-ZeRO v DeepSpeed-Inference #4234

Closed
lashoun opened this issue Aug 30, 2023 · 11 comments
@lashoun commented Aug 30, 2023

Hi, first of all, sorry if this information is already documented somewhere. I searched the docs and the issues but did not find a clear answer to my (simple) question.

I have several nodes of 8x80GB A100 GPUs at my disposal. If my goal is to run a large inference job (no training required for now) with 70B-parameter Llama models, what is the most efficient / fastest way to run it? I make the distinction in case there is a method slower than the fastest one that saves a lot of computational power.

I am still having trouble understanding the difference between DeepSpeed-ZeRO and DeepSpeed-Inference (and DeepSpeed-Chat while we're at it), especially since ZeRO also seems to perform inference tasks.

I have been trying to use DeepSpeed-Inference but, from my understanding, before using deepspeed.init_inference, I have to load the whole model on GPU or CPU and that always triggers an out of memory error. Since meta tensors are not yet supported for Llama models on the latest DeepSpeed release, I'm a bit stumped.

If someone could explain the rationale behind the different DeepSpeed variants and which would be best suited to my use case, I would be very grateful. A minimal working example would be the icing on the cake, but I'd be fine with just knowing where to look first. Surely, if the Open LLM Leaderboard is full of different 70B LLMs, there must be a simple way to do what I want, but I'm not sure how everyone does it.

For reference, below is the script I'm working with at the moment:

import os

import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

world_size = int(os.getenv("WORLD_SIZE", "1"))
model_path = "<path-to-local-llama-70b-checkpoint>"  # placeholder for my local path

tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False)
config = AutoConfig.from_pretrained(model_path)

dtype = torch.float16

# meta tensors not yet supported for llama models...
# with deepspeed.OnDevice(dtype=dtype, device="meta"):
#     model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)

# ... but I get an OOM error here if the model is too big to fit in CPU memory
# (I think it's loaded on the CPU? Not even sure, and it's odd since I think
# there is enough memory on the CPU, but apparently not)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
)

model = model.eval()
infer_dtype = "float16"
tp_config = deepspeed.inference.config.DeepSpeedTPConfig()
tp_config.tp_size = world_size

checkpoints_json = os.path.join(model_path, "ds_inference_config.json")
model = deepspeed.init_inference(
    model,
    tensor_parallel=tp_config,
    base_dir=model_path,
    dtype=getattr(torch, infer_dtype),
    checkpoint=checkpoints_json,
    replace_with_kernel_inject=True,
)

model = model.module

I get

slurmstepd: error: Detected 1 oom_kill event in StepId=1509929.0. Some of the step tasks have been OOM Killed.
srun: error: [REDACTED]: task 0: Out Of Memory
srun: Terminating StepId=1509929.0 
@mrwyattii self-assigned this Aug 30, 2023
@mrwyattii (Contributor) commented Aug 30, 2023

> what is the most efficient way / fastest way to run it?

DeepSpeed-Inference will provide the best latency. While we do not have kernel injection support for the 70B model yet (but we do for the smaller variants!), you can still split the model across several GPUs with Auto Tensor Parallelism.

> Since meta tensors are not yet supported for Llama models on the latest DeepSpeed release, I'm a bit stumped.

We don't support meta tensor loading of the Llama models with kernel injection, but we can still load it with AutoTP (see example below).

> If someone could explain the rationale behind the different DeepSpeed variants and which would be best suited to my use case, I would be very grateful.

ZeRO-Inference primarily targets cases where you want to run inference with very large models on very limited GPU memory. It takes advantage of ZeRO's offloading capabilities to move most of the model weights to CPU memory (or even NVMe storage). Because there is overhead associated with offloading weights, it is typically not well suited for use cases where low-latency inference is a priority.
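
For reference, offloading to NVMe instead of CPU is only a small change to the ZeRO config used in the example below; a minimal sketch, assuming a local NVMe mount at the placeholder path /local_nvme:

# Sketch of a ZeRO stage-3 config that offloads parameters to NVMe storage.
# "/local_nvme" is a placeholder path, not something from this thread.
ds_config_nvme = {
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/local_nvme",
            "pin_memory": True,
        },
    },
    "train_micro_batch_size_per_gpu": 1,
}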

DeepSpeed-Inference is a separate engine that introduces many optimizations for running inference. For example, we support custom kernel injection for tens of thousands of models, which can significantly improve latency and throughput. This will likely be your best bet for the lowest latency, but at the cost of needing much more GPU memory.

> A minimal working example would be the icing on the cake

Here is your icing! You will need to make some changes if you have the checkpoint saved locally (this script will try to download from the HF Hub). It runs both a ZeRO-Inference and a DeepSpeed-Inference example with the 70B model.

# llama-70b-example.py
# Launch with `deepspeed llama-70b-example.py`

import torch
import deepspeed
import os
import time
from transformers.deepspeed import HfDeepSpeedConfig
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
model_name = "meta-llama/Llama-2-70b-hf"
hf_token = "<your hf token>"


def run_zero_inference():
    ds_config = {
        "fp16": {"enabled": True},
        "bf16": {"enabled": False},
        "zero_optimization": {
            "stage": 3,
            "offload_param": {
                "device": "cpu",
            },
        },
        "train_micro_batch_size_per_gpu": 1,
    }
    # Share the DeepSpeed config with HuggingFace so we can properly load the
    # large model with zero stage 3
    hfdsc = HfDeepSpeedConfig(ds_config)

    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, token=hf_token, torch_dtype=torch.float16
    )

    # Initialize DeepSpeed
    ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
    ds_engine.module.eval()
    model = ds_engine.module

    # Run inference
    start_time = time.time()
    inputs = tokenizer.encode("DeepSpeed is", return_tensors="pt").to(
        f"cuda:{local_rank}"
    )
    outputs = model.generate(inputs, max_new_tokens=20)
    output_str = tokenizer.decode(outputs[0])
    end_time = time.time()
    print("ZeRO-inference time:", end_time - start_time)


def run_deepspeed_inference():
    # Load the model on meta tensors
    config = AutoConfig.from_pretrained(model_name, token=hf_token)
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    with deepspeed.OnDevice(dtype=torch.float16, device="meta", enabled=True):
        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

    # Define the checkpoint dict. You may need to convert *.safetensors to
    # *.bin for this to work. Make sure you get all the *.bin and *.pt files in
    # the checkpoint_files list.
    checkpoint_dir = os.path.expanduser(
        "~/.cache/huggingface/hub/models--meta-llama--Llama-2-70b-hf/snapshots/cc8aa03a000ff08b4d5c5b39673321a2a396c396"
    )
    checkpoint_files = [
        os.path.join(checkpoint_dir, f"model-{i:05d}-of-00015.bin")
        for i in range(1, 16)
    ]
    checkpoint_dict = {
        "type": "DS_MODEL",
        "checkpoints": checkpoint_files,
        "version": 1.0,
    }

    # Initialize DeepSpeed
    model = deepspeed.init_inference(
        model,
        replace_with_kernel_inject=False,
        mp_size=world_size,
        dtype=torch.float16,
        checkpoint=checkpoint_dict,
    )

    # Run inference
    start_time = time.time()
    inputs = tokenizer.encode("DeepSpeed is", return_tensors="pt").to(
        f"cuda:{local_rank}"
    )
    outputs = model.generate(inputs, max_new_tokens=20)
    output_str = tokenizer.decode(outputs[0])
    end_time = time.time()
    print("DeepSpeed-inference time:", end_time - start_time)


if __name__ == "__main__":
    run_zero_inference()
    run_deepspeed_inference()

@lashoun (Author) commented Aug 30, 2023

You are a lifesaver. Thank you so much for your clear and detailed answer!

It should be put in the docs, I'm probably not the only one who's confused (but maybe I am, who knows).

If that's all right, could you also quickly explain what kernel injection is, please?

@iamsile commented Aug 31, 2023

> (quoting @mrwyattii's full reply above, including the example script)

Hi @mrwyattii. If I use ZeRO-3 + offload to train a model, can I use this code for inference? I have run into a very strange problem: when I use my custom inference code, the weights of some layers are always randomly initialized.

[screenshot]

self.bias_k = Parameter(torch.empty((1, 1, config.hidden_size)))
self.bias_v = Parameter(torch.empty((1, 1, config.hidden_size)))

self.bias_k and self.bias_v are part of my custom network. Training and saving the model with ZeRO-3 + offload works, but when I load the model for inference, self.bias_k and self.bias_v do not pick up the saved weights.

After loading finishes, the weight of self.bias_v looks like this (it seems both of them fail to load):

[screenshot]

The weight of self.bias_v in the model file (pytorch.bin):

[screenshot]

@lashoun (Author) commented Aug 31, 2023

@iamsile would you mind opening a new issue for your unrelated problem?

@mrwyattii (Contributor) commented

> It should be put in the docs

You're right, we should include short tutorials like this in the docs. I'll talk with the team about organizing this.

> Could you also quickly explain what kernel injection is, please?

Sure! Kernel injection refers to the use of custom kernels (essentially pieces of code that are executed on your GPU) in DeepSpeed-Inference that provide better performance than the baseline kernels provided by PyTorch/transformers. For example, we replace self-attention layers with DeepSpeedSelfAttention, which uses custom QKV, linear, etc. kernels to decrease latency and increase throughput for inference!
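
To make that concrete, here is a minimal sketch of enabling kernel injection on one of the smaller Llama variants (the model name and prompt are only illustrative, and HF access to the weights may be required; per the discussion above, the 70B model does not get kernel injection yet):

# Illustrative sketch: DeepSpeed-Inference with kernel injection on a smaller
# Llama variant. Launch with the `deepspeed` launcher.
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
model_name = "meta-llama/Llama-2-7b-hf"  # illustrative smaller variant

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# replace_with_kernel_inject=True swaps supported modules (e.g. self attention)
# for DeepSpeed's optimized kernels; mp_size shards the model across GPUs.
model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    dtype=torch.float16,
    replace_with_kernel_inject=True,
)

inputs = tokenizer.encode("DeepSpeed is", return_tensors="pt").to(f"cuda:{local_rank}")
outputs = model.generate(inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))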

@mrwyattii (Contributor) commented

@iamsile please open a new issue and I can assist. If possible, please provide a minimal example to reproduce the problem you are facing. Thanks!

@lashoun (Author) commented Aug 31, 2023

Thanks again, closing the issue!

@lashoun closed this as completed Aug 31, 2023
@lashoun reopened this Sep 5, 2023
@lashoun (Author) commented Sep 5, 2023

Hi @mrwyattii, I've adapted your script to my use case, but I noticed that the output always ends up degenerate. The meta tensor loading seems to be the cause: loading a small model directly on GPU and using init_inference works as expected. Any idea what I could do?
For instance (I've removed the beginning of the input):

### Assistant:ker Ker sharing slide slide baseballeach each aliveoveove introduction function
functionsincePe Pe pe pe slide slide slidesigmaroy ice Ice Ice Den Dennumunu attend
attendtakeorenaren Marian opinion Ps Ps ps pla pla pla pla pla pla pla silence sale посс
Cov Cov Em Em Em Em Em consumer sole soleGrGr Finfin fin fin finfinfinpen Pen Peniry
Bu Bu Bu Contcont Cont Cont Cont pl pl pl Pul Pulvon Part Part Part Wor Sor sorkuku cu pu

@ZaVang commented Sep 8, 2023

> (quoting @lashoun's previous comment and the degenerate output above)

try #4259

@mrwyattii (Contributor) commented

Thank you @ZaVang. @lashoun can you confirm if that PR fixes the output?

@lashoun (Author) commented Sep 12, 2023

It does. Thanks @ZaVang.

@lashoun closed this as completed Sep 12, 2023