[Example] Use .fuse() primitive when possible #42

Merged 36 commits on Feb 10, 2023.

Commits (all authored by chhzh123)
2564827  Support fuse bias layernorm (Feb 4, 2023)
59177d1  Add fuse_bias_gelu to bert (Feb 4, 2023)
9ed1f02  Fix bias_ln (Feb 4, 2023)
34f2b24  Add separate op_fusion (Feb 5, 2023)
7bb9562  Fix pylint (Feb 6, 2023)
87d03b4  Add fuse_bias_gelu to albert & Fix shape (Feb 7, 2023)
9a2b23d  Add flash_attn to list_envs (Feb 7, 2023)
71edf5e  Add gelu act to albert (Feb 7, 2023)
d7e2827  Add bias_gelu_fusion to roberta (Feb 7, 2023)
8b98681  Add separate_fusion to opt (Feb 7, 2023)
5138b75  Update README (Feb 7, 2023)
b21727b  Uncomment fuse_bias_gelu in albert (Feb 7, 2023)
63fb4f8  Add flag to albert (Feb 7, 2023)
4318a89  Fix albert (Feb 7, 2023)
2381d33  Disable fuse_bias_gelu (Feb 7, 2023)
4c4fe5d  Refactor opt schedule (Feb 7, 2023)
9239270  Remove timing test (Feb 7, 2023)
29e26d0  Fix format (Feb 8, 2023)
6f3ae92  Fix flag (Feb 8, 2023)
770019f  Reset default value of disable_fuse_bias_gelu (Feb 8, 2023)
1c8e277  Delete demo (Feb 9, 2023)
94bd8bf  Fix attention name and signature (Feb 10, 2023)
407814b  Use slapo.op to schedule albert (Feb 10, 2023)
c247d6f  Fix GPT schedule (Feb 10, 2023)
bec0c7d  Fix BERT (Feb 10, 2023)
0b7b202  Fix bert & roberta (Feb 10, 2023)
9d83269  Fix OPT (Feb 10, 2023)
7b0ff5c  Update OPT schedule (Feb 10, 2023)
6a54c62  Fix albert (Feb 10, 2023)
010361f  Fix bert (Feb 10, 2023)
aac2444  Fix gpt opt (Feb 10, 2023)
85be556  Fix roberta (Feb 10, 2023)
7bfd769  Fix format (Feb 10, 2023)
0437c09  Add mlp to opt (Feb 10, 2023)
eacf24c  Update op test (Feb 10, 2023)
86a8b8a  Fix gpt2 (Feb 10, 2023)
Files changed
1 change: 1 addition & 0 deletions benchmark/bench_single_node.py
@@ -369,6 +369,7 @@ def list_envs(append_to=None):
"epoi",
"transformers",
"xformers",
"flash_attn",
"megatron",
"deepspeed",
"triton",
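This change adds `flash_attn` to the list of packages whose versions `bench_single_node.py` reports. A minimal sketch of how such a version report could be collected with the standard library (illustrative only; the actual body of `list_envs` is not shown in this diff, and distribution names are assumptions):

```
from importlib.metadata import PackageNotFoundError, version

def report_versions(packages):
    # Distribution names may differ from import names (e.g. flash_attn vs. flash-attn),
    # so try both spellings before reporting the package as missing.
    for name in packages:
        for candidate in (name, name.replace("_", "-")):
            try:
                print(f"{name}: {version(candidate)}")
                break
            except PackageNotFoundError:
                continue
        else:
            print(f"{name}: not installed")

report_versions(["epoi", "transformers", "xformers", "flash_attn", "megatron", "deepspeed", "triton"])
```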
12 changes: 12 additions & 0 deletions examples/README.md
@@ -24,6 +24,18 @@ git submodule update --init --recursive
pip3 install -e ".[dev]"
```

Note that we currently need to apply the following patch to the `xformers` library:
```
XFORMER_PATH=`python3 -c "import xformers, pathlib; print(pathlib.Path(xformers.__path__[0]).parent)"`
cp scripts/xformers_patch $XFORMER_PATH
pushd $XFORMER_PATH
git config --global --add safe.directory $XFORMER_PATH
git reset --hard
git apply xformers_patch
git --no-pager diff
popd
```

- flash-attention:
```
git clone https://github.com/jfc4050/flash-attention.git
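After applying the `xformers` patch above, a quick import check can confirm the installation is still usable. A small sanity check (it only assumes the patch keeps the public `xformers` API importable; `memory_efficient_attention` is xformers' fused-attention entry point):

```
import xformers
import xformers.ops as xops

# Print where the patched package was imported from and whether the fused
# attention entry point is present.
print("xformers imported from:", xformers.__path__[0])
print("memory_efficient_attention available:", hasattr(xops, "memory_efficient_attention"))
```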
13 changes: 12 additions & 1 deletion examples/albert/deepspeed_hf.py
@@ -88,14 +88,15 @@ def train(args):
model,
config,
prefix="albert",
attn_op_name=args.attn_op_name,
ckpt_ratio=args.checkpoint,
bcast_input=True,
group=group,
pipeline_cuts=pipeline_cuts,
delay_init=enable_pipeline,
)
if SINGLE_DEVICE_FOR_DEBUG:
slapo.build(sch)
slapo.build(sch, init_weights=model._init_weights)
assert False

if enable_pipeline:
@@ -118,6 +119,7 @@ def loss_fn(outputs, labels):
target="deepspeed",
config=ds_config_dict,
loss_fn=loss_fn,
init_weights=model._init_weights,
)
else:
if batch_size is not None:
@@ -133,6 +135,7 @@ def loss_fn(outputs, labels):
topology=topology,
target="deepspeed",
config=ds_config_dict,
init_weights=model._init_weights,
)
model = model.to(device)
report_memory(msg="After building model")
@@ -216,6 +219,14 @@ def loss_fn(outputs, labels):
default=None,
help="Micro batch size per GPU",
)
parser.add_argument(
"--attn_op_name",
type=str,
default="cuda",
help="Attention op name {'native_xformers', 'cutlass', 'triton', 'cuda'}. "
"'cuda' and 'triton' only support sm_80+, and other archs will "
"fall back to 'cutlass'",
)
parser.add_argument(
"--seq_len",
type=int,
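The new `--attn_op_name` flag exposes the attention backend choice to the launcher. Per the help text, `'cuda'` and `'triton'` require sm_80+; a hypothetical helper (not part of this PR) that picks a safe default from the GPU's compute capability could look like:

```
import torch

def pick_attn_op_name() -> str:
    # Hypothetical default selection, mirroring the help text above:
    # 'cuda'/'triton' need sm_80+ (Ampere or newer); older GPUs fall back to 'cutlass'.
    if not torch.cuda.is_available():
        return "native_xformers"  # assumption: CPU-only runs use the plain xformers path
    major, _minor = torch.cuda.get_device_capability()
    return "cuda" if major >= 8 else "cutlass"

print(pick_attn_op_name())
```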
4 changes: 2 additions & 2 deletions examples/albert/megatron_hf.py
@@ -56,12 +56,12 @@ def get_model(
sch = schedule_model(
model,
config,
disable_flash_attn=disable_flash_attn,
attn_op_name="native_xformers" if disable_flash_attn else "cuda",
fp16=fp16,
ckpt_ratio=ckpt_ratio,
delay_init=delay_init,
)
model, _ = slapo.build(sch)
model, _ = slapo.build(sch, init_weights=model._init_weights)
report_memory()

elif impl == "torchscript":
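Both launcher scripts now forward the HuggingFace model's `_init_weights` to `slapo.build`, following the call shape in the hunks above. A minimal sketch of that pattern on a toy module (the model and initializer here are stand-ins, and `slapo.create_schedule` is assumed to be available as in Slapo's examples):

```
import torch.nn as nn
import slapo

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(16, 16)

    def _init_weights(self, module):
        # HuggingFace-style initializer, invoked per submodule when weights are materialized.
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            nn.init.zeros_(module.bias)

model = ToyModel()
sch = slapo.create_schedule(model)
# ... apply schedule primitives (e.g. .fuse(), .shard()) here ...
built_model, _ = slapo.build(sch, init_weights=model._init_weights)
```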
21 changes: 15 additions & 6 deletions examples/albert/model.py
@@ -10,6 +10,7 @@
broadcast_input,
checkpoint,
replace_and_shard_attention,
fuse_bias_gelu,
shard_mlp,
shard_word_embedding,
)
@@ -21,12 +22,13 @@ def schedule_model(
model,
config,
prefix="",
disable_flash_attn=False,
attn_op_name="cuda",
fp16=True,
ckpt_ratio=0.0,
group=None,
bcast_input=False,
pipeline_cuts=None,
disable_fuse_bias_gelu=True,
delay_init=True,
):
logger.info("Scheduling Albert", ranks=0)
@@ -39,15 +41,22 @@

# Replace self attention with flash attention, and shard QKV/output
# if MP group > 1.
if disable_flash_attn:
logger.info("Disabled Flash Attention", rank=0)
cnt = replace_and_shard_attention(
if attn_op_name == "native_xformers":
logger.info("Disabled Flash Attention", ranks=0)
cnt, applied_attn_op_name = replace_and_shard_attention(
sch[prefix],
config,
delay_init=delay_init,
disable_flash_attn=disable_flash_attn,
attn_op_name=attn_op_name,
)
logger.info(f"Replace {cnt} attention patterns", ranks=0)
logger.info(
f"Replace {cnt} attention layers with {applied_attn_op_name} op", ranks=0
)

# Operator fusion
if not disable_fuse_bias_gelu:
fuse_bias_gelu(sch[prefix], config)
logger.info(f"Fused Bias+GeLU", ranks=0)

# Shard other parameters if MP group > 1.
if sch.world_size > 1:
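The new `fuse_bias_gelu` step targets the bias-add followed by GeLU that typically sits after the intermediate (FFN) dense layer. As a self-contained illustration of the computation being fused (plain PyTorch with TorchScript, not Slapo's `.fuse()` primitive), the bias add and the tanh-approximated GeLU can be expressed as one scripted elementwise function applied to the bias-free matmul output:

```
import torch
import torch.nn.functional as F

@torch.jit.script
def bias_gelu(bias: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # One elementwise pass over (x + bias) using the tanh approximation of GeLU.
    y = x + bias
    return 0.5 * y * (1.0 + torch.tanh(0.7978845608 * y * (1.0 + 0.044715 * y * y)))

dense = torch.nn.Linear(1024, 4096)
x = torch.randn(8, 1024)
# Run the matmul without the bias, then the fused bias+GeLU epilogue.
out = bias_gelu(dense.bias, F.linear(x, dense.weight))
print(out.shape)  # torch.Size([8, 4096])
```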