Commit

issue #290
blefaudeux committed May 24, 2022
1 parent 3f027f4 commit 2536cbb
Showing 4 changed files with 45 additions and 36 deletions.
requirements-test.txt: 2 changes (1 addition & 1 deletion)
@@ -26,4 +26,4 @@ hydra-core >= 1.1
fairscale >= 0.4.5

# Dependency for fused layers, optional
-triton == 2.0.0.dev20220403
+triton == 2.0.0.dev20220523
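
To double-check that the environment actually has the pinned build before running the fused-layer tests, a minimal standard-library sketch (the expected string is simply the version pinned above):

import importlib.metadata

# Compare the installed triton distribution against the pin in requirements-test.txt.
installed = importlib.metadata.version("triton")
assert installed == "2.0.0.dev20220523", f"expected 2.0.0.dev20220523, got {installed}"
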
tests/test_pickling.py: 18 changes (12 additions & 6 deletions)
@@ -7,7 +7,9 @@
# https://github.com/facebookresearch/xformers/issues/203

import pickle
+from copy import deepcopy

+import pytest
from torch import nn

from xformers.factory import xFormer, xFormerConfig
@@ -30,7 +32,7 @@
},
},
"feedforward_config": {
"name": "MLP",
"name": "FusedMLP",
"dropout": 0.1,
"activation": "gelu",
"hidden_layer_multiplier": 4,
@@ -40,11 +42,15 @@


class ViT(nn.Module):
-def __init__(self):
+def __init__(self, mlp):
super().__init__()
-self.xformer = xFormer.from_config(xFormerConfig(test_config))
+test_config[0]["feedforward_config"]["name"] = mlp
+xformer_config = xFormerConfig(test_config)
+self.xformer = xFormer.from_config(xformer_config)


-def test_pickling():
-test = ViT()
-pickle.dumps(test)
+@pytest.mark.parametrize("mlp", ["MLP", "FusedMLP"])
+def test_pickling(mlp):
+test = ViT(mlp)
+_ = pickle.dumps(test)
+_ = deepcopy(test)
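
The same round trip can be reproduced outside pytest; a minimal sketch of a hypothetical driver, assuming the test module is importable (e.g. run from the repository root) and, for the FusedMLP case, a CUDA device with the pinned triton build:

import pickle
from copy import deepcopy

from tests.test_pickling import ViT  # the small wrapper defined in the diff above

for mlp in ("MLP", "FusedMLP"):
    model = ViT(mlp)
    _ = pickle.dumps(model)  # must not raise for either feedforward variant
    _ = deepcopy(model)      # deepcopy coverage is what this commit adds to the test
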
xformers/triton/dropout.py: 19 changes (11 additions & 8 deletions)
@@ -13,11 +13,8 @@
import triton
from torch.cuda.amp import custom_bwd, custom_fwd

+import xformers.triton.k_activations
from xformers.components.activations import Activation, build_activation
-from xformers.triton.k_activations import (
-get_triton_activation_bwd_kernel,
-get_triton_activation_kernel,
-)
from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw

GROUP_M = 32
@@ -183,8 +180,10 @@ def dropout(
return x

# The normal triton enabled codepath
-act_kernel = get_triton_activation_kernel(activation)
-act_grad_kernel = get_triton_activation_bwd_kernel(activation)
+act_kernel = xformers.triton.k_activations.get_triton_activation_kernel(activation)
+act_grad_kernel = xformers.triton.k_activations.get_triton_activation_bwd_kernel(
+activation
+)
return _dropout.apply(
x,
float(p),
@@ -216,9 +215,13 @@ def __init__(
if bias_shape is not None
else None
)
-self.activation = get_triton_activation_kernel(activation)
+self.activation = xformers.triton.k_activations.get_triton_activation_kernel(
+activation
+)
self.pytorch_activation = build_activation(self.activation_type)
-self.activation_grad = get_triton_activation_bwd_kernel(activation)
+self.activation_grad = (
+xformers.triton.k_activations.get_triton_activation_bwd_kernel(activation)
+)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Convenience, catch a possible type or device mismatch
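
The pattern adopted in this file is to import the k_activations module itself and resolve the helpers through it at call time, instead of binding the helper names with a from-import; a minimal sketch of that access pattern, using only names that appear in this diff:

import xformers.triton.k_activations
from xformers.components.activations import Activation

# Kernels are looked up through the module attribute rather than through names
# bound by `from xformers.triton.k_activations import ...`.
act_kernel = xformers.triton.k_activations.get_triton_activation_kernel(Activation.GeLU)
act_grad_kernel = xformers.triton.k_activations.get_triton_activation_bwd_kernel(
    Activation.GeLU
)

Both spellings resolve to the same callables; the commit only changes how they are referenced.
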
xformers/triton/k_activations.py: 42 changes (21 additions & 21 deletions)
@@ -17,11 +17,11 @@
def get_triton_activation_kernel(activation: Optional[Activation]):
return (
{
-Activation.ReLU: relu,
-Activation.LeakyReLU: leaky_relu,
-Activation.GeLU: gelu,
-Activation.SquaredReLU: squared_relu,
-Activation.SmeLU: smelu,
+Activation.ReLU: _relu,
+Activation.LeakyReLU: _leaky_relu,
+Activation.GeLU: _gelu,
+Activation.SquaredReLU: _squared_relu,
+Activation.SmeLU: _smelu,
}[activation]
if activation
else None
@@ -31,11 +31,11 @@ def get_triton_activation_kernel(activation: Optional[Activation]):
def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
return (
{
-Activation.ReLU: relu_grad,
-Activation.LeakyReLU: leaky_relu_grad,
-Activation.GeLU: gelu_grad,
-Activation.SquaredReLU: squared_relu_grad,
-Activation.SmeLU: smelu_grad,
+Activation.ReLU: _relu_grad,
+Activation.LeakyReLU: _leaky_relu_grad,
+Activation.GeLU: _gelu_grad,
+Activation.SquaredReLU: _squared_relu_grad,
+Activation.SmeLU: _smelu_grad,
}[activation]
if activation
else None
@@ -59,7 +59,7 @@ def cosh(x):

# ReLU
@triton.jit
-def relu(x):
+def _relu(x):
"""
ReLU_ activation function
@@ -70,7 +70,7 @@ def relu(x):


@triton.jit
-def relu_grad(x):
+def _relu_grad(x):
# ReLU is different from other activations
# in that it does not require the input to retrospectively compute its gradient
# here the input is the downstream gradient, and we return the upstream gradient directly
@@ -80,24 +80,24 @@ def relu_grad(x):


@triton.jit
-def squared_relu(x):
+def _squared_relu(x):
"""
Squared ReLU activation, as proposed in the Primer_ paper.
.. _Primer: https://arxiv.org/abs/2109.08668
"""
-x_ = relu(x)
+x_ = _relu(x)
return (x_ * x_).to(x.dtype)


@triton.jit
-def squared_relu_grad(x):
+def _squared_relu_grad(x):
return tl.where(x >= 0, 2.0 * x, 0.0)


# Leaky ReLU
@triton.jit
-def leaky_relu(x):
+def _leaky_relu(x):
"""
LeakyReLU_ activation
@@ -109,7 +109,7 @@ def leaky_relu(x):


@triton.jit
-def leaky_relu_grad(x):
+def _leaky_relu_grad(x):
min_grad = 0.01
max_grad = 1

@@ -120,7 +120,7 @@


@triton.jit
-def gelu(x):
+def _gelu(x):
"""
GeLU_ activation - Gaussian error linear unit
@@ -130,7 +130,7 @@ def gelu(x):


@triton.jit
-def gelu_grad(x):
+def _gelu_grad(x):
# CREDITS: Fast implementation proposed in
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
@@ -140,7 +140,7 @@ def gelu_grad(x):


@triton.jit
-def smelu(x):
+def _smelu(x):
"""
SmeLU_ activation - Smooth ReLU with beta=2.0
@@ -157,7 +157,7 @@ def smelu(x):


@triton.jit
-def smelu_grad(x):
+def _smelu_grad(x):
zero = 0.0
one = 1.0
two = 2.0

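The renames in k_activations.py only add a leading underscore to the kernel names; callers still go through the two public getters, which now hand back the underscored kernels, or None when no activation is requested. A short sketch based on the mapping shown in this diff:

from xformers.components.activations import Activation
from xformers.triton.k_activations import (
    get_triton_activation_bwd_kernel,
    get_triton_activation_kernel,
)

relu_kernel = get_triton_activation_kernel(Activation.ReLU)           # -> _relu, a triton.jit kernel
relu_grad_kernel = get_triton_activation_bwd_kernel(Activation.ReLU)  # -> _relu_grad
no_kernel = get_triton_activation_kernel(None)                        # -> None when no activation is set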