[Example] Use .fuse() primitive when possible (#42)
chhzh123 authored Feb 10, 2023
1 parent 1189cf3 commit 6c2e235
Showing 28 changed files with 552 additions and 1,018 deletions.
1 change: 1 addition & 0 deletions benchmark/bench_single_node.py
@@ -369,6 +369,7 @@ def list_envs(append_to=None):
"epoi",
"transformers",
"xformers",
"flash_attn",
"megatron",
"deepspeed",
"triton",
12 changes: 12 additions & 0 deletions examples/README.md
@@ -24,6 +24,18 @@ git submodule update --init --recursive
pip3 install -e ".[dev]"
```

Note that we currently need to apply the following patch to the `xformers` library:
```
XFORMER_PATH=`python3 -c "import xformers, pathlib; print(pathlib.Path(xformers.__path__[0]).parent)"`
cp scripts/xformers_patch $XFORMER_PATH
pushd $XFORMER_PATH
git config --global --add safe.directory $XFORMER_PATH
git reset --hard
git apply xformers_patch
git --no-pager diff
popd
```

- flash-attention:
```
git clone https://github.com/jfc4050/flash-attention.git
13 changes: 12 additions & 1 deletion examples/albert/deepspeed_hf.py
@@ -88,14 +88,15 @@ def train(args):
model,
config,
prefix="albert",
+ attn_op_name=args.attn_op_name,
ckpt_ratio=args.checkpoint,
bcast_input=True,
group=group,
pipeline_cuts=pipeline_cuts,
delay_init=enable_pipeline,
)
if SINGLE_DEVICE_FOR_DEBUG:
- slapo.build(sch)
+ slapo.build(sch, init_weights=model._init_weights)
assert False

if enable_pipeline:
@@ -118,6 +119,7 @@ def loss_fn(outputs, labels):
target="deepspeed",
config=ds_config_dict,
loss_fn=loss_fn,
+ init_weights=model._init_weights,
)
else:
if batch_size is not None:
@@ -133,6 +135,7 @@ def loss_fn(outputs, labels):
topology=topology,
target="deepspeed",
config=ds_config_dict,
+ init_weights=model._init_weights,
)
model = model.to(device)
report_memory(msg="After building model")
@@ -216,6 +219,14 @@ def loss_fn(outputs, labels):
default=None,
help="Micro batch size per GPU",
)
+ parser.add_argument(
+     "--attn_op_name",
+     type=str,
+     default="cuda",
+     help="Attention op name {'native_xformers', 'cutlass', 'triton', 'cuda'}. "
+     "'cuda' and 'triton' only support sm_80+; other archs will "
+     "fall back to 'cutlass'",
+ )
parser.add_argument(
"--seq_len",
type=int,
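The `init_weights=model._init_weights` arguments above matter because `delay_init=True` creates submodules on PyTorch's meta device; once they are materialized, their parameters hold uninitialized storage and must be re-initialized. A minimal sketch of that flow using stock PyTorch APIs (the `init_weights` helper here is illustrative, standing in for HF's `model._init_weights`):

```
import torch

# Parameters created under the meta device allocate no real storage.
with torch.device("meta"):
    layer = torch.nn.Linear(1024, 1024)

# Materialize real but uninitialized storage on the target device.
layer = layer.to_empty(device="cpu")

# Re-run a well-defined initialization, as build(init_weights=...) would.
def init_weights(module):
    if isinstance(module, torch.nn.Linear):
        torch.nn.init.normal_(module.weight, std=0.02)
        torch.nn.init.zeros_(module.bias)

layer.apply(init_weights)
```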
4 changes: 2 additions & 2 deletions examples/albert/megatron_hf.py
@@ -56,12 +56,12 @@ def get_model(
sch = schedule_model(
model,
config,
- disable_flash_attn=disable_flash_attn,
+ attn_op_name="native_xformers" if disable_flash_attn else "cuda",
fp16=fp16,
ckpt_ratio=ckpt_ratio,
delay_init=delay_init,
)
- model, _ = slapo.build(sch)
+ model, _ = slapo.build(sch, init_weights=model._init_weights)
report_memory()

elif impl == "torchscript":
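Here the boolean `disable_flash_attn` flag is translated into the new `attn_op_name` string, which lets callers name a specific attention kernel. A rough sketch of the dispatch this enables, following the `--attn_op_name` help text; the fallback logic and the `naive_attention` reference path are illustrative assumptions, not the repo's actual code:

```
import torch

def naive_attention(q, k, v):
    # Unfused reference path, standing in for 'native_xformers'.
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

def resolve_attn_op_name(name):
    # 'cuda' and 'triton' require sm_80+ (Ampere or newer); older
    # architectures fall back to 'cutlass', per the help text.
    if name in ("cuda", "triton") and torch.cuda.get_device_capability()[0] < 8:
        return "cutlass"
    return name
```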
21 changes: 15 additions & 6 deletions examples/albert/model.py
@@ -10,6 +10,7 @@
broadcast_input,
checkpoint,
replace_and_shard_attention,
+ fuse_bias_gelu,
shard_mlp,
shard_word_embedding,
)
@@ -21,12 +22,13 @@ def schedule_model(
model,
config,
prefix="",
- disable_flash_attn=False,
+ attn_op_name="cuda",
fp16=True,
ckpt_ratio=0.0,
group=None,
bcast_input=False,
pipeline_cuts=None,
+ disable_fuse_bias_gelu=True,
delay_init=True,
):
logger.info("Scheduling Albert", ranks=0)
@@ -39,15 +41,22 @@

# Replace self attention with flash attention, and shard QKV/output
# if MP group > 1.
- if disable_flash_attn:
-     logger.info("Disabled Flash Attention", rank=0)
- cnt = replace_and_shard_attention(
+ if attn_op_name == "native_xformers":
+     logger.info("Disabled Flash Attention", ranks=0)
+ cnt, applied_attn_op_name = replace_and_shard_attention(
sch[prefix],
config,
delay_init=delay_init,
- disable_flash_attn=disable_flash_attn,
+ attn_op_name=attn_op_name,
)
logger.info(f"Replace {cnt} attention patterns", ranks=0)
logger.info(
f"Replace {cnt} attention layers with {applied_attn_op_name} op", ranks=0
)

+ # Operator fusion
+ if not disable_fuse_bias_gelu:
+     fuse_bias_gelu(sch[prefix], config)
+     logger.info("Fused Bias+GeLU", ranks=0)

# Shard other parameters if MP group > 1.
if sch.world_size > 1:
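`fuse_bias_gelu` is where the commit title's `.fuse()` primitive pays off: the bias add and the GeLU are elementwise ops that can run as one fused kernel instead of two. A minimal sketch of such an epilogue using TorchScript on stock PyTorch; `FusedBiasGELU` and the bias-splitting trick below are illustrative, not Slapo's actual implementation:

```
import torch
import torch.nn.functional as F

class FusedBiasGELU(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.bias = torch.nn.Parameter(bias)

    def forward(self, x):
        # A single expression the JIT fuser can compile into one
        # elementwise kernel on GPU, instead of separate add + gelu.
        return F.gelu(x + self.bias)

# Split the bias out of a Linear, then fuse the epilogue.
linear = torch.nn.Linear(1024, 4096)
bias = linear.bias.detach().clone()
linear.bias = None  # Linear now computes x @ W.T only
fused = torch.jit.script(FusedBiasGELU(bias))
y = fused(linear(torch.randn(8, 1024)))
```

On CPU, or without a fusing backend, the scripted module still runs correctly; the fusion itself is purely a performance optimization.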