
Replace FeedforwardBlock with a correct implementation (#211)
* Replace FeedforwardBlock with a correct implementation

* Reduce number of classes in test_training
mryab authored Apr 8, 2021
1 parent 1d364b7 commit ca6d87a
Showing 3 changed files with 17 additions and 14 deletions.
25 changes: 14 additions & 11 deletions hivemind/server/layers/common.py
@@ -2,21 +2,24 @@
 from torch import nn as nn
 
 
+# https://github.com/huggingface/transformers/blob/master/src/transformers/activations.py
+@torch.jit.script
+def gelu_fast(x):
+    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
+
+
 class FeedforwardBlock(nn.Module):
     def __init__(self, hid_dim):
         super().__init__()
-        self.layers = nn.Sequential(
-            nn.Linear(hid_dim, 4 * hid_dim),
-            nn.LayerNorm(4 * hid_dim),
-            nn.ReLU(inplace=True),
-            nn.Linear(4 * hid_dim, 4 * hid_dim),
-            nn.LayerNorm(4 * hid_dim),
-            nn.ReLU(inplace=True),
-            nn.Linear(4 * hid_dim, hid_dim),
-        )
+        self.ffn = nn.Linear(hid_dim, 4 * hid_dim)
+        self.ffn_output = nn.Linear(4 * hid_dim, hid_dim)
+        self.layer_norm = nn.LayerNorm(hid_dim, eps=1e-12)
 
     def forward(self, x):
-        return x + self.layers(x)
+        ffn_output = self.ffn(x)
+        ffn_output = gelu_fast(ffn_output)
+        ffn_output = self.ffn_output(ffn_output)
+        return self.layer_norm(x + ffn_output)
 
 
 class TransformerEncoderLayer(nn.Module):
@@ -37,7 +40,7 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
         self.dropout1 = nn.Dropout(dropout)
         self.dropout2 = nn.Dropout(dropout)
 
-        self.activation = torch.nn.GELU()
+        self.activation = gelu_fast
 
     def forward(self, src, src_key_padding_mask=None):
         # (N, S, E) -> (S, N, E)
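For context, the rewritten block is the standard transformer feed-forward sublayer: an expanding Linear, a GELU, a contracting Linear, then a residual connection followed by LayerNorm. A minimal sketch of how it behaves (not part of the commit; it assumes the module is importable as hivemind.server.layers.common and uses made-up tensor sizes):

    import torch
    from hivemind.server.layers.common import FeedforwardBlock, gelu_fast

    block = FeedforwardBlock(hid_dim=64)
    x = torch.randn(8, 64)                 # (batch, hid_dim)
    out = block(x)
    assert out.shape == x.shape            # residual + LayerNorm keep the hidden size

    # gelu_fast is the tanh approximation of GELU borrowed from huggingface/transformers;
    # it should track the exact torch.nn.functional.gelu closely
    print(torch.allclose(gelu_fast(x), torch.nn.functional.gelu(x), atol=1e-3))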
2 changes: 1 addition & 1 deletion tests/test_moe.py
@@ -188,7 +188,7 @@ def test_client_anomaly_detection():
         max_batch_size=16,
     )
 
-    experts['expert.3'].expert.layers[0].weight.data[0, 0] = float('nan')
+    experts['expert.3'].expert.ffn.weight.data[0, 0] = float('nan')
 
     dht = hivemind.DHT(start=True, expiration=999)
     server = hivemind.Server(dht, experts, num_connection_handlers=1)
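The test change follows from the new attribute layout: the first projection of the rewritten block is exposed as ffn rather than layers[0] inside an nn.Sequential, so the NaN used to trigger anomaly detection is injected there. A standalone sketch of the same effect (same import assumption as above, hypothetical sizes):

    import torch
    from hivemind.server.layers.common import FeedforwardBlock

    block = FeedforwardBlock(hid_dim=16)
    block.ffn.weight.data[0, 0] = float('nan')   # poison the first linear projection
    out = block(torch.randn(2, 16))
    print(torch.isnan(out).any())                # the NaN propagates through to the output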
4 changes: 2 additions & 2 deletions tests/test_training.py
@@ -12,15 +12,15 @@

 @pytest.mark.forked
 def test_training(max_steps: int = 100, threshold: float = 0.9):
-    dataset = load_digits()
+    dataset = load_digits(n_class=2)
     X_train, y_train = torch.tensor(dataset['data'], dtype=torch.float), torch.tensor(dataset['target'])
     SGD = partial(torch.optim.SGD, lr=0.05)
 
     with background_server(num_experts=2, device='cpu', optim_cls=SGD, hidden_dim=64, num_handlers=1,
                            no_dht=True) as (server_endpoint, dht_endpoint):
         expert1 = RemoteExpert('expert.0', server_endpoint)
         expert2 = RemoteExpert('expert.1', server_endpoint)
-        model = nn.Sequential(expert2, nn.Tanh(), expert1, nn.Linear(64, 10))
+        model = nn.Sequential(expert2, nn.ReLU(), expert1, nn.Linear(64, 2))
 
         opt = torch.optim.SGD(model.parameters(), lr=0.05)
 
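In the training test, load_digits(n_class=2) restricts the scikit-learn digits dataset to the classes 0 and 1, which is why the classifier head shrinks from nn.Linear(64, 10) to nn.Linear(64, 2). A quick way to confirm the shape of the reduced dataset (illustrative only):

    from sklearn.datasets import load_digits

    dataset = load_digits(n_class=2)                # keep only digits 0 and 1
    print(dataset['data'].shape)                    # roughly (360, 64): flattened 8x8 images
    print(sorted(set(dataset['target'].tolist())))  # [0, 1]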

