From 2536cbb20d65955ac02ef9ab797a38a54dc7f90f Mon Sep 17 00:00:00 2001 From: Benjamin Lefaudeux Date: Mon, 23 May 2022 18:01:09 -0700 Subject: [PATCH] issue #290 --- requirements-test.txt | 2 +- tests/test_pickling.py | 18 +++++++++----- xformers/triton/dropout.py | 19 +++++++++------ xformers/triton/k_activations.py | 42 ++++++++++++++++---------------- 4 files changed, 45 insertions(+), 36 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index d8157d6e30..a38030d1fd 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -26,4 +26,4 @@ hydra-core >= 1.1 fairscale >= 0.4.5 # Dependency for fused layers, optional -triton == 2.0.0.dev20220403 +triton == 2.0.0.dev20220523 diff --git a/tests/test_pickling.py b/tests/test_pickling.py index 72f19ebd2a..1d73f4c3b5 100644 --- a/tests/test_pickling.py +++ b/tests/test_pickling.py @@ -7,7 +7,9 @@ # https://github.com/facebookresearch/xformers/issues/203 import pickle +from copy import deepcopy +import pytest from torch import nn from xformers.factory import xFormer, xFormerConfig @@ -30,7 +32,7 @@ }, }, "feedforward_config": { - "name": "MLP", + "name": "FusedMLP", "dropout": 0.1, "activation": "gelu", "hidden_layer_multiplier": 4, @@ -40,11 +42,15 @@ class ViT(nn.Module): - def __init__(self): + def __init__(self, mlp): super().__init__() - self.xformer = xFormer.from_config(xFormerConfig(test_config)) + test_config[0]["feedforward_config"]["name"] = mlp + xformer_config = xFormerConfig(test_config) + self.xformer = xFormer.from_config(xformer_config) -def test_pickling(): - test = ViT() - pickle.dumps(test) +@pytest.mark.parametrize("mlp", ["MLP", "FusedMLP"]) +def test_pickling(mlp): + test = ViT(mlp) + _ = pickle.dumps(test) + _ = deepcopy(test) diff --git a/xformers/triton/dropout.py b/xformers/triton/dropout.py index f14099d5ad..8b505350d7 100644 --- a/xformers/triton/dropout.py +++ b/xformers/triton/dropout.py @@ -13,11 +13,8 @@ import triton from torch.cuda.amp import custom_bwd, custom_fwd +import xformers.triton.k_activations from xformers.components.activations import Activation, build_activation -from xformers.triton.k_activations import ( - get_triton_activation_bwd_kernel, - get_triton_activation_kernel, -) from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw GROUP_M = 32 @@ -183,8 +180,10 @@ def dropout( return x # The normal triton enabled codepath - act_kernel = get_triton_activation_kernel(activation) - act_grad_kernel = get_triton_activation_bwd_kernel(activation) + act_kernel = xformers.triton.k_activations.get_triton_activation_kernel(activation) + act_grad_kernel = xformers.triton.k_activations.get_triton_activation_bwd_kernel( + activation + ) return _dropout.apply( x, float(p), @@ -216,9 +215,13 @@ def __init__( if bias_shape is not None else None ) - self.activation = get_triton_activation_kernel(activation) + self.activation = xformers.triton.k_activations.get_triton_activation_kernel( + activation + ) self.pytorch_activation = build_activation(self.activation_type) - self.activation_grad = get_triton_activation_bwd_kernel(activation) + self.activation_grad = ( + xformers.triton.k_activations.get_triton_activation_bwd_kernel(activation) + ) def forward(self, x: torch.Tensor) -> torch.Tensor: # Convenience, catch a possible type or device mismatch diff --git a/xformers/triton/k_activations.py b/xformers/triton/k_activations.py index 3180f943df..aaea263905 100644 --- a/xformers/triton/k_activations.py +++ b/xformers/triton/k_activations.py @@ -17,11 +17,11 @@ def 
get_triton_activation_kernel(activation: Optional[Activation]): return ( { - Activation.ReLU: relu, - Activation.LeakyReLU: leaky_relu, - Activation.GeLU: gelu, - Activation.SquaredReLU: squared_relu, - Activation.SmeLU: smelu, + Activation.ReLU: _relu, + Activation.LeakyReLU: _leaky_relu, + Activation.GeLU: _gelu, + Activation.SquaredReLU: _squared_relu, + Activation.SmeLU: _smelu, }[activation] if activation else None @@ -31,11 +31,11 @@ def get_triton_activation_kernel(activation: Optional[Activation]): def get_triton_activation_bwd_kernel(activation: Optional[Activation]): return ( { - Activation.ReLU: relu_grad, - Activation.LeakyReLU: leaky_relu_grad, - Activation.GeLU: gelu_grad, - Activation.SquaredReLU: squared_relu_grad, - Activation.SmeLU: smelu_grad, + Activation.ReLU: _relu_grad, + Activation.LeakyReLU: _leaky_relu_grad, + Activation.GeLU: _gelu_grad, + Activation.SquaredReLU: _squared_relu_grad, + Activation.SmeLU: _smelu_grad, }[activation] if activation else None @@ -59,7 +59,7 @@ def cosh(x): # ReLU @triton.jit -def relu(x): +def _relu(x): """ ReLU_ activation function @@ -70,7 +70,7 @@ def relu(x): @triton.jit -def relu_grad(x): +def _relu_grad(x): # ReLU is different from other activations # in that it does not require the input to retrospectively compute its gradient # here the input is the downstream gradient, and we return the upstream gradient directly @@ -80,24 +80,24 @@ def relu_grad(x): @triton.jit -def squared_relu(x): +def _squared_relu(x): """ Squared ReLU activation, as proposed in the Primer_ paper. .. _Primer: https://arxiv.org/abs/2109.08668 """ - x_ = relu(x) + x_ = _relu(x) return (x_ * x_).to(x.dtype) @triton.jit -def squared_relu_grad(x): +def _squared_relu_grad(x): return tl.where(x >= 0, 2.0 * x, 0.0) # Leaky ReLU @triton.jit -def leaky_relu(x): +def _leaky_relu(x): """ LeakyReLU_ activation @@ -109,7 +109,7 @@ def leaky_relu(x): @triton.jit -def leaky_relu_grad(x): +def _leaky_relu_grad(x): min_grad = 0.01 max_grad = 1 @@ -120,7 +120,7 @@ def leaky_relu_grad(x): @triton.jit -def gelu(x): +def _gelu(x): """ GeLU_ activation - Gaussian error linear unit @@ -130,7 +130,7 @@ def gelu(x): @triton.jit -def gelu_grad(x): +def _gelu_grad(x): # CREDITS: Fast implementation proposed in # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30 tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x)) @@ -140,7 +140,7 @@ def gelu_grad(x): @triton.jit -def smelu(x): +def _smelu(x): """ SmeLU_ activation - Smooth ReLU with beta=2.0 @@ -157,7 +157,7 @@ def smelu(x): @triton.jit -def smelu_grad(x): +def _smelu_grad(x): zero = 0.0 one = 1.0 two = 2.0
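
Note (not part of the patch): a minimal sketch of the call pattern this change settles on, assuming an environment where xformers and the pinned triton build are importable. The module alias and the choice of Activation.SquaredReLU are illustrative only; any enum member mapped in k_activations.py behaves the same way.

    # Sketch only: mirrors the module-qualified lookup now used in
    # xformers/triton/dropout.py; requires triton to be installed (and a
    # CUDA device to actually launch the returned kernels).
    import xformers.triton.k_activations as k_activations
    from xformers.components.activations import Activation

    # The helpers resolve an Activation enum member to the underscore-prefixed
    # @triton.jit kernels (_squared_relu / _squared_relu_grad here), or return
    # None when no activation is requested.
    act_fw = k_activations.get_triton_activation_kernel(Activation.SquaredReLU)
    act_bw = k_activations.get_triton_activation_bwd_kernel(Activation.SquaredReLU)

The reworked test then builds the same ViT wrapper for both the "MLP" and "FusedMLP" feedforward configs and checks both pickle.dumps and copy.deepcopy on the resulting module.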