7227 refactor transformer and diffusion model unet (#7715)
Part of #7227.

### Description
Refactors the transformer and diffusion model UNet components. Among the 90 changed files, this commit adds a `CrossAttentionBlock` and simplifies the Dockerfile workaround for building `numcodecs` on aarch64 (both diffs are shown below).

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: KumoLiu, kaibo, heyufan1995, YunLiu, binliu, dependabot[bot], axel.vlaminck, monai-bot, Ibrahim Hadzic, Behrooz, Timothy Baker, Mathijs de Boer, Fabian Klopfer, Lucas Robinet, chaoliu, cxlcl, Suraj Pai, Juan Pablo de la Cruz Gutiérrez, elitap, Felix Schnabel, YanxuanLiu, ytl0623, Dženan Zukić, Ishan Dutta, John Zielke, Mingxin Zheng, Vladimir Chernyi, Yiheng Wang, Szabolcs Botond Lorincz Molnar, Mingxin, Han Wang, Konstantin Sukharev, Ben Murray, Matthew Vine, Mark Graham, Peter Kaplinsky, Simon Jensen, NabJa

Co-authored-by: YunLiu, Kaibo Tang, Yufan He, binliunls, Ben Murray, dependabot[bot], Eric Kerfoot, axel.vlaminck, pre-commit-ci[bot], Mingxin Zheng, monai-bot, Ibrahim Hadzic, Dr. Behrooz Hashemian, Timothy J. Baker, Mathijs de Boer, Fabian Klopfer, Yiheng Wang, Lucas Robinet, cxlcl, Suraj Pai, Juampa, elitap, Felix Schnabel, YanxuanLiu, ytl0623, Dženan Zukić, Ishan Dutta, johnzielke, Vladimir Chernyi, Lőrincz-Molnár Szabolcs-Botond, Nic Ma, Han Wang, Konstantin Sukharev, Matthew Vine, Pkaps25, Peter Kaplinsky, Simon Jensen, NabJa
1 parent ba188e2 · commit 1a57b55

Showing 90 changed files with 1,327 additions and 921 deletions.
The Dockerfile's aarch64 workaround for `numcodecs` is simplified: instead of cloning the repository and building a wheel from source, the build flags are exported and the package is installed with pip.

```diff
@@ -18,11 +18,10 @@ LABEL maintainer="[email protected]"

 # TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431)
 RUN if [[ $(uname -m) =~ "aarch64" ]]; then \
-    cd /opt && \
-    git clone --branch v0.12.1 --recursive https://github.com/zarr-developers/numcodecs && \
-    pip wheel numcodecs && \
-    rm -r /opt/*.whl && \
-    rm -rf /opt/numcodecs; \
+    export CFLAGS="-O3" && \
+    export DISABLE_NUMCODECS_SSE2=true && \
+    export DISABLE_NUMCODECS_AVX2=true && \
+    pip install numcodecs; \
     fi

 WORKDIR /opt/monai
```
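To sanity-check the resulting image, a minimal hedged sketch (not part of the commit; assumes `numcodecs` and its `Blosc` codec import cleanly in the built environment):

```python
# Quick check that the pip-installed numcodecs works inside the built image.
import numcodecs
from numcodecs import Blosc

codec = Blosc(cname="zstd", clevel=3)
data = b"monai" * 1000
# Round-trip compression should reproduce the original bytes.
assert bytes(codec.decode(codec.encode(data))) == data
print(numcodecs.__version__)
```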
A new file (166 lines in total) adds `CrossAttentionBlock`:

```python
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")


class CrossAttentionBlock(nn.Module):
    """
    A cross-attention block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
    One can set up relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        hidden_input_size: int | None = None,
        context_input_size: int | None = None,
        dim_head: int | None = None,
        qkv_bias: bool = False,
        save_attn: bool = False,
        causal: bool = False,
        sequence_length: int | None = None,
        rel_pos_embedding: Optional[str] = None,
        input_size: Optional[Tuple] = None,
        attention_dtype: Optional[torch.dtype] = None,
    ) -> None:
        """
        Args:
            hidden_size (int): dimension of hidden layer.
            num_heads (int): number of attention heads.
            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
            hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.
            context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size.
            dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
            qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
            save_attn (bool, optional): to make the attention matrix accessible. Defaults to False.
            causal: whether to use causal attention.
            sequence_length: if causal is True, it is necessary to specify the sequence length.
            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
                For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
                positional parameter size.
            attention_dtype: cast attention operations to this dtype.
        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if dim_head:
            inner_size = num_heads * dim_head
            self.head_dim = dim_head
        else:
            if hidden_size % num_heads != 0:
                raise ValueError("hidden size should be divisible by num_heads.")
            inner_size = hidden_size
            self.head_dim = hidden_size // num_heads

        if causal and sequence_length is None:
            raise ValueError("sequence_length is necessary for causal attention.")

        self.num_heads = num_heads
        self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
        self.context_input_size = context_input_size if context_input_size else hidden_size
        self.out_proj = nn.Linear(inner_size, self.hidden_input_size)
        # key, query, value projections
        self.to_q = nn.Linear(self.hidden_input_size, inner_size, bias=qkv_bias)
        self.to_k = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
        self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
        self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)

        self.out_rearrange = Rearrange("b h l d -> b l (h d)")
        self.drop_output = nn.Dropout(dropout_rate)
        self.drop_weights = nn.Dropout(dropout_rate)

        self.scale = self.head_dim**-0.5
        self.save_attn = save_attn
        self.attention_dtype = attention_dtype

        self.causal = causal
        self.sequence_length = sequence_length

        if causal and sequence_length is not None:
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer(
                "causal_mask",
                torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
            )
            self.causal_mask: torch.Tensor

        self.att_mat = torch.Tensor()
        self.rel_positional_embedding = (
            get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads)
            if rel_pos_embedding is not None
            else None
        )
        self.input_size = input_size
```
The forward pass projects queries from `x` and keys/values from `context`, falling back to `x` itself for self-attention:

```python
    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None):
        """
        Args:
            x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
            context (torch.Tensor, optional): context tensor. B x (s_dim_1 * ... * s_dim_n) x C
        Return:
            torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
        """
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        b, t, c = x.size()  # batch size, sequence length, embedding dimensionality (hidden_size)

        q = self.to_q(x)
        kv = context if context is not None else x
        _, kv_t, _ = kv.size()
        k = self.to_k(kv)
        v = self.to_v(kv)

        if self.attention_dtype is not None:
            q = q.to(self.attention_dtype)
            k = k.to(self.attention_dtype)

        q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs)
        k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
        v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

        # apply relative positional embedding if defined
        att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat

        if self.causal:
            att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))

        att_mat = att_mat.softmax(dim=-1)

        if self.save_attn:
            # no gradients and new tensor;
            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
            self.att_mat = att_mat.detach()

        att_mat = self.drop_weights(att_mat)
        x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
        x = self.out_rearrange(x)
        x = self.out_proj(x)
        x = self.drop_output(x)
        return x
```
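A hedged usage sketch of the new block (the import path and the tensor shapes below are assumptions; neither is shown in this diff):

```python
import torch

# Assumed import path for the new block; it is not visible in this capture.
from monai.networks.blocks.crossattention import CrossAttentionBlock

block = CrossAttentionBlock(hidden_size=128, num_heads=8, dropout_rate=0.1)

x = torch.randn(2, 64, 128)        # batch of 2, 64 tokens, 128 channels
context = torch.randn(2, 77, 128)  # keys/values come from a 77-token context

print(block(x).shape)                   # self-attention: torch.Size([2, 64, 128])
print(block(x, context=context).shape)  # cross-attention: torch.Size([2, 64, 128])

# Causal self-attention requires the sequence length up front so the
# lower-triangular mask buffer can be registered in __init__.
causal_block = CrossAttentionBlock(hidden_size=128, num_heads=8, causal=True, sequence_length=64)
print(causal_block(x).shape)            # torch.Size([2, 64, 128])
```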