diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py
index a1728365cf..0feeb044f3 100644
--- a/monai/networks/blocks/mlp.py
+++ b/monai/networks/blocks/mlp.py
@@ -9,8 +9,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Tuple, Union
+
 import torch.nn as nn
 
+from monai.networks.layers import get_act_layer
+from monai.utils import look_up_option
+
+SUPPORTED_DROPOUT_MODE = {"vit", "swin"}
+
 
 class MLPBlock(nn.Module):
     """
@@ -18,12 +25,26 @@ class MLPBlock(nn.Module):
     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
     """
 
-    def __init__(self, hidden_size: int, mlp_dim: int, dropout_rate: float = 0.0) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        mlp_dim: int,
+        dropout_rate: float = 0.0,
+        act: Union[Tuple, str] = "GELU",
+        dropout_mode="vit",
+    ) -> None:
         """
         Args:
             hidden_size: dimension of hidden layer.
-            mlp_dim: dimension of feedforward layer.
+            mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
             dropout_rate: faction of the input units to drop.
+            act: activation type and arguments. Defaults to GELU.
+            dropout_mode: dropout mode, can be "vit" or "swin".
+                "vit" mode uses two dropout instances as implemented in
+                https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
+                "swin" corresponds to one instance as implemented in
+                https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23
+
 
         """
 
@@ -31,12 +52,18 @@ def __init__(self, hidden_size: int, mlp_dim: int, dropout_rate: float = 0.0) ->
 
         if not (0 <= dropout_rate <= 1):
             raise ValueError("dropout_rate should be between 0 and 1.")
-
+        mlp_dim = mlp_dim or hidden_size
         self.linear1 = nn.Linear(hidden_size, mlp_dim)
         self.linear2 = nn.Linear(mlp_dim, hidden_size)
-        self.fn = nn.GELU()
+        self.fn = get_act_layer(act)
         self.drop1 = nn.Dropout(dropout_rate)
-        self.drop2 = nn.Dropout(dropout_rate)
+        dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
+        if dropout_opt == "vit":
+            self.drop2 = nn.Dropout(dropout_rate)
+        elif dropout_opt == "swin":
+            self.drop2 = self.drop1
+        else:
+            raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")
 
     def forward(self, x):
         x = self.fn(self.linear1(x))
diff --git a/tests/test_mlp.py b/tests/test_mlp.py
index 6fec5b6854..737762cfb1 100644
--- a/tests/test_mlp.py
+++ b/tests/test_mlp.py
@@ -21,7 +21,7 @@
 TEST_CASE_MLP = []
 for dropout_rate in np.linspace(0, 1, 4):
     for hidden_size in [128, 256, 512, 768]:
-        for mlp_dim in [512, 1028, 2048, 3072]:
+        for mlp_dim in [0, 1028, 2048, 3072]:
             test_case = [
                 {"hidden_size": hidden_size, "mlp_dim": mlp_dim, "dropout_rate": dropout_rate},
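
A quick usage sketch (not part of the patch) exercising the new arguments; behavior follows from the diff above, and `MLPBlock` is assumed to be importable from `monai.networks.blocks` as elsewhere in the codebase:

import torch

from monai.networks.blocks import MLPBlock

# mlp_dim=0 now falls back to hidden_size, so linear1 maps 128 -> 128 here;
# "swin" mode shares a single nn.Dropout instance between both call sites.
block = MLPBlock(hidden_size=128, mlp_dim=0, dropout_rate=0.1, act="GELU", dropout_mode="swin")

x = torch.randn(2, 16, 128)  # (batch, tokens, hidden_size)
print(block(x).shape)        # torch.Size([2, 16, 128])
assert block.drop2 is block.drop1  # one shared dropout in "swin" mode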