Patch summary (free-form preamble; patch tools skip text before the first
"diff --git" header):
  * hivemind/server/layers/common.py: add a TorchScript-compiled tanh-based
    GELU (gelu_fast); FeedforwardBlock becomes Linear -> gelu_fast -> Linear
    with a residual connection followed by LayerNorm;
    TransformerEncoderLayer.activation switches from torch.nn.GELU() to
    gelu_fast.
  * tests/test_moe.py: the NaN injection targets the renamed `ffn` layer.
  * tests/test_training.py: the test uses load_digits(n_class=2), ReLU
    between experts, and a 2-way output layer.
NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Line breaks follow the hunk headers' line counts; leading indentation of
each line is reconstructed from Python convention -- confirm against the
target files before applying (or apply with whitespace-tolerant options).

diff --git a/hivemind/server/layers/common.py b/hivemind/server/layers/common.py
index ee41e737e..c1c590cfc 100644
--- a/hivemind/server/layers/common.py
+++ b/hivemind/server/layers/common.py
@@ -2,21 +2,24 @@
 from torch import nn as nn
 
 
+# https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py
+@torch.jit.script
+def gelu_fast(x):
+    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
+
+
 class FeedforwardBlock(nn.Module):
     def __init__(self, hid_dim):
         super().__init__()
-        self.layers = nn.Sequential(
-            nn.Linear(hid_dim, 4 * hid_dim),
-            nn.LayerNorm(4 * hid_dim),
-            nn.ReLU(inplace=True),
-            nn.Linear(4 * hid_dim, 4 * hid_dim),
-            nn.LayerNorm(4 * hid_dim),
-            nn.ReLU(inplace=True),
-            nn.Linear(4 * hid_dim, hid_dim),
-        )
+        self.ffn = nn.Linear(hid_dim, 4 * hid_dim)
+        self.ffn_output = nn.Linear(4 * hid_dim, hid_dim)
+        self.layer_norm = nn.LayerNorm(hid_dim, eps=1e-12)
 
     def forward(self, x):
-        return x + self.layers(x)
+        ffn_output = self.ffn(x)
+        ffn_output = gelu_fast(ffn_output)
+        ffn_output = self.ffn_output(ffn_output)
+        return self.layer_norm(x + ffn_output)
 
 
 class TransformerEncoderLayer(nn.Module):
@@ -37,7 +40,7 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
         self.dropout1 = nn.Dropout(dropout)
         self.dropout2 = nn.Dropout(dropout)
 
-        self.activation = torch.nn.GELU()
+        self.activation = gelu_fast
 
     def forward(self, src, src_key_padding_mask=None):
         # (N, S, E) -> (S, N, E)
diff --git a/tests/test_moe.py b/tests/test_moe.py
index c273611d8..1d43446e5 100644
--- a/tests/test_moe.py
+++ b/tests/test_moe.py
@@ -188,7 +188,7 @@ def test_client_anomaly_detection():
             max_batch_size=16,
         )
 
-    experts['expert.3'].expert.layers[0].weight.data[0, 0] = float('nan')
+    experts['expert.3'].expert.ffn.weight.data[0, 0] = float('nan')
 
     dht = hivemind.DHT(start=True, expiration=999)
     server = hivemind.Server(dht, experts, num_connection_handlers=1)
diff --git a/tests/test_training.py b/tests/test_training.py
index 1295ca12b..c9cd8574d 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -12,7 +12,7 @@
 
 @pytest.mark.forked
 def test_training(max_steps: int = 100, threshold: float = 0.9):
-    dataset = load_digits()
+    dataset = load_digits(n_class=2)
     X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
     SGD = partial(torch.optim.SGD, lr=0.05)
 
@@ -20,7 +20,7 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
                            no_dht=True) as (server_endpoint, dht_endpoint):
         expert1 = RemoteExpert('expert.0', server_endpoint)
         expert2 = RemoteExpert('expert.1', server_endpoint)
-        model = nn.Sequential(expert2, nn.Tanh(), expert1, nn.Linear(64, 10))
+        model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
 
         opt = torch.optim.SGD(model.parameters(), lr=0.05)
 