Merge pull request #56 from unionai/liger-kernel-benchmark
add example to fine-tune an LLM with the liger kernel
cosmicBboy authored Sep 25, 2024
2 parents 8d76d1e + 7b3a9ae commit 059d014
Showing 6 changed files with 828 additions and 0 deletions.
8 changes: 8 additions & 0 deletions run_commands.yaml
@@ -40,3 +40,11 @@ tutorials/video_translation/video_translation.py:
- git clone https://github.com/unionai/unionai-examples
- cd unionai-examples/tutorials/video_translation
- union run --remote --copy-all video_translation.py video_translation_wf
tutorials/liger_kernel_finetuning/liger_kernel_finetuning.py:
- git clone https://github.com/unionai/unionai-examples
- cd unionai-examples/tutorials/liger_kernel_finetuning
- "# create a huggingface key: https://huggingface.co/settings/tokens, then run the following command"
- union secrets create huggingface_api_key --value <your_huggingface_api_key>
- "# create a weights and biases key: https://wandb.ai/settings, then run the following command"
- union secrets create wandb_api_key --value <your_wandb_api_key>
- union run --remote liger_kernel_finetuning.py benchmarking_experiment --inputs-file phi3_inputs.yaml
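For reference, here is a minimal sketch of how a Union/Flyte task could request and read the two secrets created above. It is an illustration, not part of this commit: the task name and body are hypothetical, and it assumes a flytekit version that supports requesting and reading secrets by key alone.

from flytekit import Secret, current_context, task


@task(secret_requests=[Secret(key="huggingface_api_key"), Secret(key="wandb_api_key")])
def login() -> None:
    # Read the secrets created via `union secrets create ...` above.
    hf_token = current_context().secrets.get(key="huggingface_api_key")
    wandb_key = current_context().secrets.get(key="wandb_api_key")
    # Hand the tokens to the Hugging Face Hub / Weights & Biases clients as needed.
    ...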
269 changes: 269 additions & 0 deletions tutorials/liger_kernel_finetuning/callback.py
@@ -0,0 +1,269 @@
"""Callback module to track the efficiency of the training process.
This is adapted from the official Liger Kernel repository example:
https://github.com/linkedin/Liger-Kernel/blob/main/examples/huggingface/callback.py
"""

import time
from dataclasses import dataclass

import torch
import transformers
from transformers import TrainerControl, TrainerState, TrainingArguments

# https://simple.wikipedia.org/wiki/Byte
# For memory, we use binary system
M_BIN_UNIT = 2**20
# For metrics (tflops), we use decimal system
T_DEC_UNIT = 10**12
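# Example: dividing a byte count by M_BIN_UNIT yields mebibytes (the MB values
# reported below); dividing a FLOP count by T_DEC_UNIT yields teraflops.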


def round_to_n_decimal(x, n):
return round(x, n)


@dataclass
class Precision:
"""
Precision is a dataclass to store the number of decimal points for each metric.
"""

n_decimal_time: int
n_decimal_memory: int
n_decimal_TPS: int


@dataclass
class State:
"""
State is a dataclass to store the internal state of the efficiency callback.
"""

n_warmup_steps: int = 0
total_peak_memory_allocated: float = float("-inf")
total_peak_memory_reserved: float = float("-inf")

step_start_time: float = 0.0
elapsed_time: float = 0.0

elapsed_step: int = 0

step_start_tokens_seen: int = 0
elapsed_tokens_seen: int = 0

global_start_step: int = 0


@dataclass
class Time:
"""
Time is a dataclass to store the time-related metrics.
"""

step: int = 0
step_time_sec: float = 0.0
avg_step_time_sec: float = 0.0
time_to_completion_sec: float = 0.0
estimated_total_time_sec: float = 0.0


@dataclass
class Memory:
"""
Memory is a dataclass to store the memory-related metrics.
"""

step_peak_memory_allocated_MB: float = 0.0
step_peak_memory_reserved_MB: float = 0.0
total_peak_memory_allocated_MB: float = 0.0
total_peak_memory_reserved_MB: float = 0.0


@dataclass
class TPS:
"""
TPS is a dataclass to store the tokens per second metrics.
"""

step_tokens_per_second: float = 0.0
avg_tokens_per_second: float = 0.0


class EfficiencyCallback(transformers.TrainerCallback):
"""
EfficiencyCallback is a callback to track the efficiency of the training process.
The tracked stats include: step time, memory, and throughput.
It requires including `--include_num_input_tokens_seen` and `logging_steps=1` in the training arguments.
Args:
n_warmup_steps: number of warmup steps
Stats from the first n_warmup_steps are not added to the aggregated stats,
since the first few steps may take longer due to JIT compilation and other initialization overheads.
n_decimal_time: number of decimal points for time
n_decimal_memory: number of decimal points for memory
n_decimal_TPS: number of decimal points for TPS
"""

def __init__(
self, n_warmup_steps=2, n_decimal_time=2, n_decimal_memory=2, n_decimal_TPS=2
):
self.state = State(
n_warmup_steps,
)

self.precision = Precision(n_decimal_time, n_decimal_memory, n_decimal_TPS)

self.time = Time()
self.memory = Memory()
self.tps = TPS()

def on_init_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""
Event called at the end of the initialization of the [`Trainer`].
"""
if not args.include_num_input_tokens_seen:
raise Exception(
'Please pass training argument "--include_num_input_tokens_seen" to track tokens per second'
)
if args.logging_steps != 1:
raise Exception(
"Please set logging_steps=1 to track the efficiency metrics accurately"
)

def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# if loaded from checkpoints, global_start_step is not 1 but state.global_step
self.state.global_start_step = state.global_step

def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs: dict[str, float],
**kwargs,
):
if state.global_step < (self.state.global_start_step + self.state.n_warmup_steps):
return
else:
# spread self.time, self.memory, self.tps to logs
logs.update(self.time.__dict__)
logs.update(self.memory.__dict__)
logs.update(self.tps.__dict__)
if state.log_history and state.log_history[-1]["step"] == logs["step"]:
# override last log history if step is the same with the updated metrics
state.log_history[-1] = logs

def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
"""
Event called at the beginning of a training step. If using gradient accumulation, one training step might take
several inputs.
"""
# memory
torch.cuda.reset_peak_memory_stats()

# time
self.state.step_start_time = time.perf_counter()

def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if state.global_step < (self.state.global_start_step + self.state.n_warmup_steps):
# The token count at the end of this step becomes the starting count for the next step

# tokens
self.state.step_start_tokens_seen = state.num_input_tokens_seen
return

# time
current_time = time.perf_counter()
step_time = current_time - self.state.step_start_time
self.state.elapsed_time += step_time

# step
global_step = state.global_step
self.state.elapsed_step += 1
avg_step_time = self.state.elapsed_time / self.state.elapsed_step

self.time.step = global_step
self.time.step_time_sec = round_to_n_decimal(step_time, self.precision.n_decimal_time)
self.time.avg_step_time_sec = round_to_n_decimal(
avg_step_time, self.precision.n_decimal_time
)
self.time.time_to_completion_sec = round_to_n_decimal(
avg_step_time * (state.max_steps - global_step),
self.precision.n_decimal_time,
)
self.time.estimated_total_time_sec = round_to_n_decimal(
avg_step_time * state.max_steps, self.precision.n_decimal_time
)

# memory
step_peak_memory_allocated = torch.cuda.memory.max_memory_allocated()
step_peak_memory_reserved = torch.cuda.memory.max_memory_reserved()

self.memory.step_peak_memory_allocated_MB = round_to_n_decimal(
step_peak_memory_allocated / M_BIN_UNIT, self.precision.n_decimal_memory
)
self.state.total_peak_memory_allocated = max(
self.state.total_peak_memory_allocated, step_peak_memory_allocated
)
self.memory.total_peak_memory_allocated_MB = round_to_n_decimal(
self.state.total_peak_memory_allocated / M_BIN_UNIT,
self.precision.n_decimal_memory,
)

self.memory.step_peak_memory_reserved_MB = round_to_n_decimal(
step_peak_memory_reserved / M_BIN_UNIT, self.precision.n_decimal_memory
)

self.state.total_peak_memory_reserved = max(
self.state.total_peak_memory_reserved, step_peak_memory_reserved
)

self.memory.total_peak_memory_reserved_MB = round_to_n_decimal(
self.state.total_peak_memory_reserved / M_BIN_UNIT,
self.precision.n_decimal_memory,
)

# tokens
step_tokens_seen = state.num_input_tokens_seen - self.state.step_start_tokens_seen

self.state.elapsed_tokens_seen += step_tokens_seen

self.tps.step_tokens_per_second = round_to_n_decimal(
step_tokens_seen / step_time,
self.precision.n_decimal_TPS,
)

self.tps.avg_tokens_per_second = round_to_n_decimal(
self.state.elapsed_tokens_seen / self.state.elapsed_time,
self.precision.n_decimal_TPS,
)

# The token count at the end of this step becomes the starting count for the next step

# tokens
self.state.step_start_tokens_seen = state.num_input_tokens_seen
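To make the requirements stated in the EfficiencyCallback docstring concrete, here is a minimal usage sketch with a Hugging Face Trainer. It is illustrative and not part of this commit: the model and dataset are assumed to be built elsewhere, and the import assumes the script runs next to callback.py. The two essential settings are include_num_input_tokens_seen=True and logging_steps=1, which the callback validates in on_init_end.

from transformers import Trainer, TrainingArguments

from callback import EfficiencyCallback


def build_trainer(model, train_dataset, output_dir: str = "outputs") -> Trainer:
    # include_num_input_tokens_seen enables token counting for tokens-per-second,
    # and logging_steps=1 ensures the per-step metrics land in every log entry.
    training_args = TrainingArguments(
        output_dir=output_dir,
        logging_steps=1,
        include_num_input_tokens_seen=True,
    )
    return Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        callbacks=[EfficiencyCallback(n_warmup_steps=2)],
    )

After training, the per-step efficiency metrics (step_time_sec, step_peak_memory_allocated_MB, step_tokens_per_second, and so on) appear in trainer.state.log_history alongside the loss, since on_log merges them into each log entry.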