[Op] Add attention and bias_gelu ops #41

Merged (7 commits) on Feb 7, 2023
21 changes: 12 additions & 9 deletions ci/task_unit_test.sh
@@ -8,15 +8,6 @@ set -o pipefail

nvidia-smi -L

echo "Running unit tests..."
# -r: redirect the output of local rank 1 to None so that
# only local rank 0's output is printed to the console.
# -p "no:randomly": disable randomly plugin for sharding tests.
torchrun --nproc_per_node 2 -r 1:1 -m pytest -p "no:randomly" tests

echo "Downloading test data..."
bash benchmark/download_benchmark_dataset.sh

# Remove this patch when xFormers fixes this issue.
echo "Applying xFormers patch..."
XFORMER_PATH=`python3 -c "import xformers, pathlib; print(pathlib.Path(xformers.__path__[0]).parent)"`
@@ -28,5 +19,17 @@ git apply xformers_patch
git --no-pager diff
popd

echo "Running unit tests..."
# torchrun:
# -r: redirect the output of local rank 1 to None so that
# only local rank 0's output is printed to the console.
# pytest:
# -rxXs: show extra info for each test, including xfailed, xpassed, and skipped.
# -p "no:randomly": disable randomly plugin for sharding tests.
torchrun --nproc_per_node 2 -r 1:1 -m pytest -rxXs -p "no:randomly" tests

echo "Downloading test data..."
bash benchmark/download_benchmark_dataset.sh

echo "Running end-to-end tests..."
python3 -m pytest -s -p "no:randomly" tests/end2end.py
2 changes: 1 addition & 1 deletion docker/Dockerfile
@@ -274,7 +274,7 @@ RUN cd $HOME/xformers && \

# Install flash_attn
RUN git clone https://github.com/jfc4050/flash-attention.git $HOME/flash-attention && \
cd $HOME/flash-attention/ && git checkout f528682
cd $HOME/flash-attention/ && git checkout 528c70e
RUN cd $HOME/flash-attention/ && \
pip3 install -e ".[dev]"

2 changes: 1 addition & 1 deletion examples/README.md
@@ -28,7 +28,7 @@ pip3 install -e ".[dev]"
```
git clone https://github.com/jfc4050/flash-attention.git
cd flash-attention
git checkout f528682
git checkout 528c70e
pip3 install -e ".[dev]"
```

112 changes: 55 additions & 57 deletions examples/gpt/schedule.py
@@ -3,14 +3,12 @@

import inspect

import torch
import torch.nn as nn
from torch.distributed import distributed_c10d as dist

import slapo
from slapo import init_empty_weights
from slapo.pattern import call_module
from slapo.op.linear import FusedQKV
from slapo.op import FlashSelfAttention, FlashAttentionOp, FusedMLP
from slapo import init_empty_weights, get_cuda_rng_tracker


@@ -55,32 +53,35 @@ def replace_and_shard_attention(
disable_flash_attn=False,
sequence_parallel=False,
):
from epoi.inject.policy.gpt import InjectHFGPTAttentionPolicy
from epoi.ops.xformers_attn import GenericSelfAttention, MemoryEfficientAttentionOp

try:
# Backward compatibility
from epoi.ops.flash_attention import FlashSelfAttention, FlashAttentionTritonOp
except ImportError:
FlashSelfAttention = None
FlashAttentionTritonOp = None

cuda_sm = torch.cuda.get_device_capability("cuda")
if not disable_flash_attn and FlashSelfAttention is not None and cuda_sm == (8, 0):
SelfAttentionModule = FlashSelfAttention
AttentionOp = FlashAttentionTritonOp
attn_op_name = "triton"
else:
SelfAttentionModule = GenericSelfAttention
AttentionOp = MemoryEfficientAttentionOp
attn_op_name = "native" if disable_flash_attn else "cutlass"
attn_op_name = "native_xformers" if disable_flash_attn else "triton"
init_config = dict(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
is_decoder=True,
attn_pdrop=config.attention_dropout,
resid_pdrop=config.resid_dropout,
attn_op_name=attn_op_name,
fused_qkv=True,
)

class SelfAttention(nn.Module):
"""A wrapper to align the original GPTNeoAttention forward signature."""

def __init__(self, **kwargs):
super().__init__()
self.module = SelfAttentionModule(**kwargs)
try:
self.module = FlashSelfAttention(**kwargs)
except Exception as err:
if kwargs["attn_op_name"] == "native_xformers":
raise RuntimeError(
f"Failed to create native attention: {err}"
) from None

# Failed to use the triton kernel. This may be due to an unsupported
# GPU (< sm_75) or flash-attention not being installed. Fall back
# to xFormers' cutlass.
kwargs["attn_op_name"] = "cutlass"
self.module = FlashSelfAttention(**kwargs)

def forward(
self,
@@ -91,61 +92,55 @@ def forward(
use_cache=False,
output_attentions=False,
):
outputs = self.module(hidden_states, attention_mask, layer_past, use_cache)
"""Match the original GPTNeoAttention forward signature."""
outputs = self.module(
hidden_states,
layer_past,
attention_mask,
head_mask,
None,
None,
use_cache,
output_attentions,
)
# FIXME: The original output is (hidden_states, None) where the None
# is present_key_value and is only used in inference.
return outputs[:1]

class MemoryEfficientAttentionWithRNGOp(AttentionOp):
class AttentionOpWithRNG(FlashAttentionOp):
def forward(self, query_layer, key_layer, value_layer, attention_mask, p):
with get_cuda_rng_tracker().fork():
return super().forward(
query_layer, key_layer, value_layer, attention_mask, p
)

num_layers, num_heads, hidden_size = (
config.num_layers,
config.num_heads,
config.hidden_size,
)

cnt = 0
for idx in range(num_layers):
for idx in range(config.num_layers):
sub_sch = sch[attn_path.replace("N", str(idx))]
init_config = InjectHFGPTAttentionPolicy.gen_init_config_from_object(
sub_sch.mod, attn_op_name=attn_op_name
)
with init_empty_weights(enable=delay_init):
new_mod = SelfAttention(**init_config)
sub_sch.replace(new_mod)
sub_sch.trace(
tracer="pytorch",
leaf_modules=[AttentionOp.__name__],
leaf_modules=["FlashAttentionOp"],
concrete_args={
"layer_past": None,
"head_mask": None,
"encoder_hidden_states": None,
"encoder_attention_mask": None,
"use_cache": False,
"output_attentions": False,
},
)

def pattern(x: torch.Tensor) -> torch.Tensor:
x = call_module("query|key|value", x)
new_x_shape = x.size()[:-1] + (num_heads, hidden_size)
x = x.view(new_x_shape)
return x

subgraphs = sub_sch["module"].find(pattern)
assert len(subgraphs) == 3
with init_empty_weights(enable=delay_init):
new_fused_qkv = FusedQKV(hidden_size, num_heads, sch.world_size)
sub_sch["module"].replace(new_fused_qkv, subgraphs)
if sch.world_size > 1:
sub_sch["module.FusedQKV_0.fused_linear"].shard("weight", axis=0)
sub_sch["module.FusedQKV_0.fused_linear"].shard("bias", axis=0)
sub_sch["module.qkv"].shard("weight", axis=0)
sub_sch["module.qkv"].shard("bias", axis=0)
sub_sch["module.out_proj"].shard("weight", axis=1)
fix_attention_mask_shape(sub_sch["module"])

if sequence_parallel:
sub_sch["module.FusedQKV_0.fused_linear"].sync(
sub_sch["module.qkv"].sync(
mode="fwd_pre", sync_op_or_fn="all_gather", axis=1
)

@@ -154,18 +149,17 @@ def pattern(x: torch.Tensor) -> torch.Tensor:
)
else:
# Shard qkv and output projection.
sub_sch["module.FusedQKV_0.fused_linear"].sync(
mode="bwd_post", sync_op_or_fn="all_reduce"
)
sub_sch["module.qkv"].sync(mode="bwd_post", sync_op_or_fn="all_reduce")
sub_sch["module.out_proj"].sync(
mode="fwd_post", sync_op_or_fn="all_reduce"
)

# In this case, the attention dropout in between has to
# use a different random seed on each rank.
new_op = MemoryEfficientAttentionWithRNGOp(
new_op = AttentionOpWithRNG(
sub_sch["module"]["attn_op"].mod.attn_op_name,
sub_sch["module"]["attn_op"].mod.apply_causal_mask,
sub_sch["module"]["attn_op"].mod.scale,
)
sub_sch["module"]["attn_op"].replace(new_op)

@@ -235,14 +229,18 @@ def replace_and_shard_mlp(
delay_init=True,
sequence_parallel=False,
):
from epoi.inject.policy.gpt import InjectHFGPTMLPPolicy

for idx in range(config.num_layers):
prefix = path.replace("N", str(idx))
if config.activation_function in ["gelu", "gelu_new"]:
sub_sch = sch[prefix]
inter_size, hidden_size = sub_sch.mod.c_fc.weight.shape
with init_empty_weights(enable=delay_init):
new_mod = InjectHFGPTMLPPolicy.init_from_object(sub_sch.mod)
new_mod = FusedMLP(
hidden_size,
inter_size,
config.activation_function,
config.resid_dropout,
)
sub_sch.replace(new_mod)
sub_sch.trace(leaf_modules=["FusedBiasGELU", "FusedBiasNewGELU"])

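For readers following the schedule above, the sketch below constructs the two new `slapo.op` modules standalone, reusing only the constructor arguments visible in this diff (`FlashSelfAttention`'s `init_config` and the `FusedMLP` call). The model sizes, dtype/device handling, and the unpacking of the returned tuple are illustrative assumptions, not part of this PR.

```
# Minimal sketch, not part of this PR: build the new ops directly with the
# same constructor arguments the schedule passes. Sizes and dtypes are
# placeholders.
import torch
from slapo.op import FlashSelfAttention, FusedMLP

hidden_size, num_heads = 2048, 16  # assumed model dimensions

attn = FlashSelfAttention(
    hidden_size=hidden_size,
    num_attention_heads=num_heads,
    is_decoder=True,
    attn_pdrop=0.1,
    resid_pdrop=0.1,
    attn_op_name="cutlass",  # "triton" needs flash-attention and sm_75+
    fused_qkv=True,
)
mlp = FusedMLP(hidden_size, 4 * hidden_size, "gelu_new", 0.1)

if torch.cuda.is_available():
    attn, mlp = attn.half().cuda(), mlp.half().cuda()
    x = torch.randn(1, 512, hidden_size, dtype=torch.float16, device="cuda")
    # Same positional call the SelfAttention wrapper uses:
    # (hidden_states, layer_past, attention_mask, head_mask, None, None,
    #  use_cache, output_attentions).
    attn_out = attn(x, None, None, None, None, None, False, False)[0]
    mlp_out = mlp(x)
```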
4 changes: 4 additions & 0 deletions slapo/op/__init__.py
@@ -1,5 +1,9 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Custom Ops."""
from .attention import FlashSelfAttention, FlashAttentionOp
from .bias_gelu import FusedBiasGELU, FusedBiasNewGELU
from .cross_entropy import ParallelCrossEntropy
from .dropout import DropoutWithTensorParallel
from .linear import FusedQKV
from .mlp import FusedMLP
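Since this file defines the public surface that the GPT schedule above now relies on, a quick sanity check is to import the newly exposed ops from the package root; the inline comments are informal summaries inferred from the module names, not docstrings from the PR.

```
# After this PR, the fused ops can be imported from slapo.op directly
# instead of going through epoi.
from slapo.op import (
    FlashSelfAttention,  # self-attention module used by the GPT schedule
    FlashAttentionOp,    # underlying attention op ("triton"/"cutlass" backends)
    FusedBiasGELU,       # fused bias + GELU
    FusedBiasNewGELU,    # fused bias + "new" GELU variant
    FusedMLP,            # fused MLP block
    FusedQKV,            # fused QKV projection
)
```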