checking in the benchmark from Sean
blefaudeux committed Mar 24, 2021
1 parent 7012f0c commit 635071d
Showing 5 changed files with 227 additions and 4 deletions.
8 changes: 8 additions & 0 deletions .circleci/config.yml
@@ -44,6 +44,7 @@ install_dep_171: &install_dep_171
      command: |
        pip install --progress-bar off torch==1.7.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
        pip install --progress-bar off -r requirements-test.txt
        pip install --progress-bar off -r requirements-benchmark.txt
        python -c 'import torch; print("Torch version:", torch.__version__)'
        python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "7"], "wrong torch version"'
        python -m torch.utils.collect_env
@@ -104,6 +105,11 @@ run_unittests: &run_unittests
      command: |
        pytest --junitxml=test-results/junit.xml --verbose --timeout 600

run_benchmarks: &run_benchmarks
  - run:
      name: Run Benchmarks
      command: |
        python3 benchmarks/benchmark_attention.py
# -------------------------------------------------------------------------------------
# Jobs to run
@@ -173,6 +179,8 @@ jobs:

      - <<: *run_coverage

      - <<: *run_benchmarks

      - store_test_results:
          path: test-results

2 changes: 1 addition & 1 deletion .isort.cfg
@@ -1,2 +1,2 @@
[settings]
known_third_party =attrdict,pytest,setuptools,torch
known_third_party =attrdict,pytest,setuptools,sklearn,torch,tqdm
16 changes: 13 additions & 3 deletions CONTRIBUTING.md
@@ -64,14 +64,16 @@ mypy --ignore-missing-imports --scripts-are-modules --pretty .

```
pytest
# single test
python -m pytest tests/hierarchy/single_test::test_target
```
or
```
python -m pytest
```

### Check test coverage

```
python -m pytest --cov-report term --cov=template tests/my_test_implementation::test_target
python -m pytest --cov-report term --cov=template tests
```

### CircleCI status
@@ -120,6 +122,14 @@ Must be one of the following:
generation
* **docs**: Documentation only changes

## Benchmarking
Eventually we'll probably have a launcher for a full benchmark suite.

For now, thanks to Sean Naren, you can start with
```
python3 benchmarks/benchmark_attention.py
```
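The script sweeps a grid of attention/feedforward configurations and prints its measurements as a single JSON list on stdout (the tqdm progress bar goes to stderr), so the output can be captured for later comparison. A minimal sketch, assuming you want the results saved to a hypothetical `results.json`:
```python
import json
import subprocess

# Run the benchmark; tqdm writes its progress bar to stderr,
# so stdout contains only the JSON list the script prints at the end.
completed = subprocess.run(
    ["python3", "benchmarks/benchmark_attention.py"],
    capture_output=True,
    text=True,
    check=True,
)

results = json.loads(completed.stdout)
with open("results.json", "w") as handle:
    json.dump(results, handle, indent=2)

print(f"Benchmarked {len(results)} configurations")
```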

## License
By contributing to *template*, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
203 changes: 203 additions & 0 deletions benchmarks/benchmark_attention.py
@@ -0,0 +1,203 @@
import json
import time
from typing import Dict, Optional

import torch
import torch.nn.functional as F
from sklearn.model_selection import ParameterGrid
from tqdm import tqdm

from xformers.block_factory import xFormerEncoderBlock, xFormerEncoderConfig
from xformers.components import (
    ATTENTION_REGISTRY,
    AttentionConfig,
    MultiHeadDispatchConfig,
)
from xformers.components.feedforward import (
    FEEDFORWARD_REGISTRY,
    Activations,
    FeedforwardConfig,
)
from xformers.components.positional_encoding import PositionEncodingConfig

# Credits: Sean Naren


def _train_for_several_steps(
    block: xFormerEncoderBlock,
    num_steps: int,
    batch_size: int,
    sequence_length: int,
    embed_dim: int,
    autocast: bool,
    device: torch.device,
    lr: float = 0.01,
    norm_type: Optional[float] = None,
) -> Dict[str, float]:
    # use SGD with momentum instead of Adam, since Adam is scale invariant
    # and this makes it bad for tests
    optim = torch.optim.SGD(block.parameters(), lr=lr, momentum=0.9)

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

    start_time = time.time()
    for _ in range(num_steps):
        optim.zero_grad()
        with torch.cuda.amp.autocast(enabled=autocast):
            input = torch.rand(batch_size, sequence_length, embed_dim)
            input = input.to(device)
            output = block(input)
            loss = F.mse_loss(input, output, reduction="sum")

        loss.backward()

        if norm_type is not None:
            clip_norm = 0.3
            torch.nn.utils.clip_grad_norm_(block.parameters(), clip_norm, norm_type)
        optim.step()

    torch.cuda.synchronize()
    max_memory = torch.cuda.max_memory_allocated() / 2 ** 20
    run_time = time.time() - start_time

    return {"run_time": run_time, "max_memory": max_memory}


def benchmark_model(num_warmup: int, num_steps: int, **kwargs) -> Dict[str, float]:
    # Run warm-up first
    _train_for_several_steps(num_steps=num_warmup, **kwargs)

    return _train_for_several_steps(num_steps=num_steps, **kwargs)


def test_xformer_encoder_block(
    attention_name: str,
    feedforward_name: str,
    heads: int,
    attn_dropout: float,
    residual_dropout: float,
    causal: bool,
    activation: Activations,
    autocast: bool,
    batch_size: int,
    sequence_length: int,
    embed_dim: int,
    dropout: float,
    num_steps: int,
    num_warmup: int,
    device: torch.device,
) -> Dict[str, float]:

    block = instantiate_xformer(
        activation=activation,
        attention_name=attention_name,
        attn_dropout=attn_dropout,
        causal=causal,
        feedforward_name=feedforward_name,
        heads=heads,
        residual_dropout=residual_dropout,
        sequence_length=sequence_length,
        embed_dim=embed_dim,
        dropout=dropout,
    )

    block.to(device)

    return benchmark_model(
        num_steps=num_steps,
        num_warmup=num_warmup,
        block=block,
        batch_size=batch_size,
        sequence_length=sequence_length,
        embed_dim=embed_dim,
        autocast=autocast,
        device=device,
    )


def instantiate_xformer(
    activation: Activations,
    attention_name: str,
    attn_dropout: float,
    causal: bool,
    feedforward_name: str,
    heads: int,
    residual_dropout: float,
    sequence_length: int,
    embed_dim: int,
    dropout: float,
) -> xFormerEncoderBlock:

    attention_config = {
        "name": attention_name,
        "dropout": attn_dropout,
        "causal": causal,
        "window_size": sequence_length // 8,
    }

    multi_head_config = {
        "n_heads": heads,
        "dim_seq": sequence_length,
        "dim_model": embed_dim,
        "residual_dropout": residual_dropout,
    }

    feedforward_config = {
        "name": feedforward_name,
        "dim_latent": embed_dim,
        "dropout": dropout,
        "activation": activation,
        "hidden_layer_multiplier": 4,
    }

    position_encoding_config = {
        "name": "sine",
        "dim_model": embed_dim,
        "seq_len": sequence_length,
    }

    block_config = xFormerEncoderConfig(
        embed_dim,
        AttentionConfig(**attention_config),
        MultiHeadDispatchConfig(**multi_head_config),
        FeedforwardConfig(**feedforward_config),
        PositionEncodingConfig(**position_encoding_config),
    )

    block = xFormerEncoderBlock.from_config(block_config)
    return block


if __name__ == "__main__":
    constants = {
        "device": torch.device("cuda"),
        "num_warmup": 5,
        "num_steps": 10,
        "dropout": 0.0,
        "attn_dropout": 0.0,
        "residual_dropout": 0.0,
    }

    param_grid = {
        "autocast": [False, True],
        "causal": [False, True],
        "heads": [8, 16],
        "activation": [a.value for a in Activations],
        "attention_name": ATTENTION_REGISTRY.keys(),
        "feedforward_name": FEEDFORWARD_REGISTRY.keys(),
        "sequence_length": [128, 512, 768],
        "embed_dim": [64, 128, 512],
        "batch_size": [8, 16, 32],
    }

    grid = ParameterGrid(param_grid)

    grid_outputs = []

    for params in tqdm(grid, total=len(grid)):
        outputs = test_xformer_encoder_block(**constants, **params)  # type: ignore
        results = {**outputs, **params}
        grid_outputs.append(results)

    print(json.dumps(grid_outputs))
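Each record in `grid_outputs` merges the measured `run_time` (seconds) and `max_memory` (MiB) with the grid parameters that produced them, so comparing configurations only needs plain dict handling. A rough sketch of averaging the measurements per attention mechanism, assuming the printed JSON list was saved to a hypothetical `results.json`:
```python
import json
from collections import defaultdict

# Hypothetical dump of the JSON list that benchmark_attention.py prints
with open("results.json") as handle:
    records = json.load(handle)

# Accumulate run time and peak memory per attention implementation
totals = defaultdict(lambda: {"run_time": 0.0, "max_memory": 0.0, "count": 0})
for record in records:
    bucket = totals[record["attention_name"]]
    bucket["run_time"] += record["run_time"]
    bucket["max_memory"] += record["max_memory"]
    bucket["count"] += 1

for name, bucket in sorted(totals.items()):
    mean_time = bucket["run_time"] / bucket["count"]
    mean_memory = bucket["max_memory"] / bucket["count"]
    print(f"{name}: {mean_time:.3f}s avg, {mean_memory:.1f} MiB peak avg over {bucket['count']} runs")
```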
2 changes: 2 additions & 0 deletions requirements-benchmark.txt
@@ -1,3 +1,5 @@
# Example requirement, can be anything that pip knows
# install with `pip install -r requirements.txt`, and make sure that CI does the same
torch >= 1.5.1
scikit-learn == 0.24.1
tqdm == 4.59.0
