
Commit: Cache
muellerzr committed Aug 21, 2024
1 parent 1cb5649 commit 323ee76
Showing 1 changed file with 25 additions and 66 deletions.
examples/inference/pippy/t5.py (91 changes: 25 additions & 66 deletions)
@@ -1,34 +1,20 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import os
import torch
from transformers import AutoModelForSeq2SeqLM

from accelerate import PartialState, prepare_pippy
from accelerate.utils import set_seed
from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline


rank = int(os.environ.get("RANK", -1))

# Set the random seed to have reproducible outputs
set_seed(42)
torch.distributed.init_process_group(
    backend="nccl",
    rank=rank,
    world_size=2
)

# Create an example model
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
model.eval()

# Input configs
# Create example inputs for the model
input = torch.randint(
    low=0,
    high=model.config.vocab_size,
@@ -38,52 +24,25 @@
    requires_grad=False,
)

example_inputs = {"input_ids": input, "decoder_input_ids": input}

# Create a pipeline stage from the model
# Using `auto` is equivalent to letting `device_map="auto"` figure
# out device mapping and will also split the model according to the
# number of total GPUs available if it fits on one GPU
model = prepare_pippy(
example_args = ()
example_kwargs = {"input_ids": input, "decoder_input_ids": input}
num_chunks = 2
split_points = ['decoder.block.0']
print(f'Using split points: {split_points} on num_chunks: {num_chunks}')
split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
pipe = pipeline(
    model,
    no_split_module_classes=["T5Block"],
    example_kwargs=example_inputs,
    mb_args=example_args,
    mb_kwargs=example_kwargs,
    split_spec=split_spec
)
stage = pipe.build_stage(rank, torch.device("cuda"))
schedule = ScheduleGPipe(stage, num_chunks)

# You can pass `gather_output=True` to have the output from the model
# available on all GPUs
# model = prepare_pippy(
# model,
# no_split_module_classes=["T5Block"],
# example_kwargs=example_inputs,
# gather_outputs=True
# )

# The model expects a tuple during real inference
# with the data on the first device
args = (example_inputs["input_ids"].to("cuda:0"), example_inputs["decoder_input_ids"].to("cuda:0"))
batch_size = 2

# Take an average of 5 times
# Measure first batch
torch.cuda.synchronize()
start_time = time.time()
with torch.no_grad():
    output = model(*args)
torch.cuda.synchronize()
end_time = time.time()
first_batch = end_time - start_time
kwargs = {"input_ids": example_kwargs["input_ids"].to("cuda:0"), "decoder_input_ids": example_kwargs["decoder_input_ids"].to("cuda:0")}

# Now that CUDA is init, measure after
torch.cuda.synchronize()
start_time = time.time()
for i in range(5):
    with torch.no_grad():
        output = model(*args)
torch.cuda.synchronize()
end_time = time.time()

# The outputs are only on the final process by default
if PartialState().is_last_process:
    output = torch.stack(tuple(output[0]))
    print(f"Time of first pass: {first_batch}")
    print(f"Average time per batch: {(end_time - start_time) / 5}")
with torch.no_grad():
    output = model(**kwargs)
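
Since the rewritten script reads RANK from the environment and hard-codes world_size=2 in init_process_group, it has to be started by a distributed launcher with exactly two processes, one per pipeline stage. A minimal launch sketch, assuming two local GPUs and that the file stays at examples/inference/pippy/t5.py (torchrun exports the RANK, MASTER_ADDR, and MASTER_PORT values that init_process_group expects):

torchrun --nproc_per_node=2 examples/inference/pippy/t5.py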

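As committed, the script builds a pipeline stage and a ScheduleGPipe but then calls model(**kwargs) directly, so the schedule is never actually stepped. The sketch below is not part of the commit; it shows how a GPipe schedule from torch.distributed.pipelining is typically driven, reusing the rank, schedule, and kwargs variables defined above, and it is only an assumption about how the example might be completed:

# Sketch only -- not part of this commit. Assumes the `rank`, `schedule`,
# and `kwargs` variables defined in t5.py above.
with torch.no_grad():
    if rank == 0:
        # The first stage consumes the real inputs; ScheduleGPipe splits them
        # into `num_chunks` microbatches along the batch dimension.
        schedule.step(**kwargs)
        output = None
    else:
        # Later stages receive activations from the previous stage; only the
        # last stage's step() returns the assembled model output.
        output = schedule.step()

# With two stages, rank 1 is the last stage and the only process that sees
# a non-None result.
if output is not None:
    print(output)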