7227 refactor transformer and diffusion model unet (#7715)
Part of #7227.

### Description

Refactor the transformer and diffusion model UNet networks, moving shared attention blocks (e.g. `CrossAttentionBlock`, `SpatialAttentionBlock`) into `monai.networks.blocks`.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <[email protected]>
Signed-off-by: kaibo <[email protected]>
Signed-off-by: heyufan1995 <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Signed-off-by: binliu <[email protected]>
Signed-off-by: dependabot[bot] <[email protected]>
Signed-off-by: axel.vlaminck <[email protected]>
Signed-off-by: monai-bot <[email protected]>
Signed-off-by: Ibrahim Hadzic <[email protected]>
Signed-off-by: Behrooz <[email protected]>
Signed-off-by: Timothy Baker <[email protected]>
Signed-off-by: Mathijs de Boer <[email protected]>
Signed-off-by: Fabian Klopfer <[email protected]>
Signed-off-by: Lucas Robinet <[email protected]>
Signed-off-by: Lucas Robinet <[email protected]>
Signed-off-by: chaoliu <[email protected]>
Signed-off-by: cxlcl <[email protected]>
Signed-off-by: chaoliu <[email protected]>
Signed-off-by: Suraj Pai <[email protected]>
Signed-off-by: Juan Pablo de la Cruz Gutiérrez <[email protected]>
Signed-off-by: elitap <[email protected]>
Signed-off-by: Felix Schnabel <[email protected]>
Signed-off-by: YanxuanLiu <[email protected]>
Signed-off-by: ytl0623 <[email protected]>
Signed-off-by: Dženan Zukić <[email protected]>
Signed-off-by: Ishan Dutta <[email protected]>
Signed-off-by: John Zielke <[email protected]>
Signed-off-by: Mingxin Zheng <[email protected]>
Signed-off-by: Vladimir Chernyi <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Signed-off-by: Szabolcs Botond Lorincz Molnar <[email protected]>
Signed-off-by: Lucas Robinet <[email protected]>
Signed-off-by: Mingxin <[email protected]>
Signed-off-by: Han Wang <[email protected]>
Signed-off-by: Konstantin Sukharev <[email protected]>
Signed-off-by: Ben Murray <[email protected]>
Signed-off-by: Matthew Vine <[email protected]>
Signed-off-by: Mark Graham <[email protected]>
Signed-off-by: Peter Kaplinsky <[email protected]>
Signed-off-by: Simon Jensen <[email protected]>
Signed-off-by: NabJa <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Kaibo Tang <[email protected]>
Co-authored-by: Yufan He <[email protected]>
Co-authored-by: binliunls <[email protected]>
Co-authored-by: Ben Murray <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: axel.vlaminck <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mingxin Zheng <[email protected]>
Co-authored-by: monai-bot <[email protected]>
Co-authored-by: Ibrahim Hadzic <[email protected]>
Co-authored-by: Dr. Behrooz Hashemian <[email protected]>
Co-authored-by: Timothy J. Baker <[email protected]>
Co-authored-by: Mathijs de Boer <[email protected]>
Co-authored-by: Mathijs de Boer <[email protected]>
Co-authored-by: Fabian Klopfer <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: Lucas Robinet <[email protected]>
Co-authored-by: Lucas Robinet <[email protected]>
Co-authored-by: cxlcl <[email protected]>
Co-authored-by: Suraj Pai <[email protected]>
Co-authored-by: Juampa <[email protected]>
Co-authored-by: elitap <[email protected]>
Co-authored-by: Felix Schnabel <[email protected]>
Co-authored-by: YanxuanLiu <[email protected]>
Co-authored-by: ytl0623 <[email protected]>
Co-authored-by: Dženan Zukić <[email protected]>
Co-authored-by: Ishan Dutta <[email protected]>
Co-authored-by: johnzielke <[email protected]>
Co-authored-by: Vladimir Chernyi <[email protected]>
Co-authored-by: Lőrincz-Molnár Szabolcs-Botond <[email protected]>
Co-authored-by: Nic Ma <[email protected]>
Co-authored-by: Lucas Robinet <[email protected]>
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: Konstantin Sukharev <[email protected]>
Co-authored-by: Matthew Vine <[email protected]>
Co-authored-by: Pkaps25 <[email protected]>
Co-authored-by: Peter Kaplinsky <[email protected]>
Co-authored-by: Simon Jensen <[email protected]>
Co-authored-by: NabJa <[email protected]>
Showing 90 changed files with 1,327 additions and 921 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/pythonapp-min.yml
@@ -9,6 +9,8 @@ on:
- main
- releasing/*
pull_request:
head_ref-ignore:
- dev

concurrency:
# automatically cancel the previously triggered workflows when there's a newer version
6 changes: 4 additions & 2 deletions .github/workflows/pythonapp.yml
@@ -9,6 +9,8 @@ on:
- main
- releasing/*
pull_request:
head_ref-ignore:
- dev

concurrency:
# automatically cancel the previously triggered workflows when there's a newer version
@@ -68,10 +70,10 @@ jobs:
maximum-size: 16GB
disk-root: "D:"
- uses: actions/checkout@v4
- name: Set up Python 3.8
- name: Set up Python 3.9
uses: actions/setup-python@v5
with:
python-version: '3.8'
python-version: '3.9'
- name: Prepare pip wheel
run: |
which python
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -69,7 +69,7 @@ repos:
)$
- repo: https://github.com/hadialqattan/pycln
rev: v2.1.3
rev: v2.4.0
hooks:
- id: pycln
args: [--config=pyproject.toml]
9 changes: 4 additions & 5 deletions Dockerfile
@@ -18,11 +18,10 @@ LABEL maintainer="[email protected]"

# TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431)
RUN if [[ $(uname -m) =~ "aarch64" ]]; then \
cd /opt && \
git clone --branch v0.12.1 --recursive https://github.com/zarr-developers/numcodecs && \
pip wheel numcodecs && \
rm -r /opt/*.whl && \
rm -rf /opt/numcodecs; \
export CFLAGS="-O3" && \
export DISABLE_NUMCODECS_SSE2=true && \
export DISABLE_NUMCODECS_AVX2=true && \
pip install numcodecs; \
fi

WORKDIR /opt/monai
1 change: 1 addition & 0 deletions docs/source/networks.rst
@@ -426,6 +426,7 @@ Layers
.. autoclass:: monai.networks.layers.vector_quantizer.VectorQuantizer
:members:

`ConjugateGradient`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: ConjugateGradient
7 changes: 6 additions & 1 deletion monai/apps/utils.py
@@ -135,7 +135,12 @@ def check_hash(filepath: PathLike, val: str | None = None, hash_type: str = "md5
logger.info(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.")
return True
actual_hash_func = look_up_option(hash_type.lower(), SUPPORTED_HASH_TYPES)
actual_hash = actual_hash_func()

if sys.version_info >= (3, 9):
actual_hash = actual_hash_func(usedforsecurity=False) # allows checks on FIPS enabled machines
else:
actual_hash = actual_hash_func()

try:
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(1024 * 1024), b""):
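The hunk above constructs the hash object with `usedforsecurity=False` on Python 3.9+ so that MD5/SHA-1 checksum verification still works on FIPS-enabled machines. A minimal sketch of the same pattern using the standard `hashlib` module directly (the file path and chunk size are illustrative, not part of the MONAI API):

```python
import hashlib
import sys

def file_md5(path: str) -> str:
    # usedforsecurity is only accepted from Python 3.9 onwards; it marks the
    # digest as a non-security use so FIPS-enabled builds do not reject it.
    if sys.version_info >= (3, 9):
        digest = hashlib.md5(usedforsecurity=False)
    else:
        digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            digest.update(chunk)
    return digest.hexdigest()
```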
10 changes: 7 additions & 3 deletions monai/bundle/workflows.py
@@ -239,6 +239,7 @@ class ConfigWorkflow(BundleWorkflow):
logging_file: config file for `logging` module in the program. for more details:
https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
If None, default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo.
If False, the logging logic for the bundle will not be modified.
init_id: ID name of the expected config expression to initialize before running, default to "initialize".
allow a config to have no `initialize` logic and the ID.
run_id: ID name of the expected config expression to run, default to "run".
@@ -278,7 +279,7 @@ def __init__(
self,
config_file: str | Sequence[str],
meta_file: str | Sequence[str] | None = None,
logging_file: str | None = None,
logging_file: str | bool | None = None,
init_id: str = "initialize",
run_id: str = "run",
final_id: str = "finalize",
@@ -307,15 +308,18 @@ def __init__(
super().__init__(workflow_type=workflow_type, meta_file=meta_file, properties_path=properties_path)
self.config_root_path = config_root_path
logging_file = str(self.config_root_path / "logging.conf") if logging_file is None else logging_file
if logging_file is not None:

if logging_file is False:
logger.warn(f"Logging file is set to {logging_file}, skipping logging.")
else:
if not os.path.isfile(logging_file):
if logging_file == str(self.config_root_path / "logging.conf"):
logger.warn(f"Default logging file in {logging_file} does not exist, skipping logging.")
else:
raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.")
else:
logger.info(f"Setting logging properties based on config: {logging_file}.")
fileConfig(logging_file, disable_existing_loggers=False)
fileConfig(str(logging_file), disable_existing_loggers=False)

self.parser = ConfigParser()
self.parser.read_config(f=config_file)
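With `logging_file` widened to `str | bool | None`, a bundle workflow can now opt out of logging re-configuration entirely by passing `False`, while `None` keeps the old fallback to `configs/logging.conf`. A hedged usage sketch (the config path is a placeholder for an existing bundle config):

```python
from monai.bundle import ConfigWorkflow

workflow = ConfigWorkflow(
    config_file="configs/train.json",  # hypothetical bundle config path
    meta_file=None,
    logging_file=False,  # skip fileConfig(); None would look for configs/logging.conf
    workflow_type="train",
)
workflow.initialize()
workflow.run()
workflow.finalize()
```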
14 changes: 9 additions & 5 deletions monai/fl/client/monai_algo.py
@@ -134,12 +134,14 @@ def initialize(self, extra=None):
Args:
extra: Dict with additional information that should be provided by FL system,
i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`.
i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`.
You can disable the logging logic in the monai bundle by setting {ExtraItems.LOGGING_FILE} to False.
"""
if extra is None:
extra = {}
self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname")
logging_file = extra.get(ExtraItems.LOGGING_FILE, None)
self.logger.info(f"Initializing {self.client_name} ...")

# FL platform needs to provide filepath to configuration files
@@ -149,7 +151,7 @@
if self.workflow is None:
config_train_files = self._add_config_files(self.config_train_filename)
self.workflow = ConfigWorkflow(
config_file=config_train_files, meta_file=None, logging_file=None, workflow_type="train"
config_file=config_train_files, meta_file=None, logging_file=logging_file, workflow_type="train"
)
self.workflow.initialize()
self.workflow.bundle_root = self.bundle_root
@@ -412,13 +414,15 @@ def initialize(self, extra=None):
Args:
extra: Dict with additional information that should be provided by FL system,
i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`.
i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`.
You can disable the logging logic in the monai bundle by setting {ExtraItems.LOGGING_FILE} to False.
"""
self._set_cuda_device()
if extra is None:
extra = {}
self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname")
logging_file = extra.get(ExtraItems.LOGGING_FILE, None)
timestamp = time.strftime("%Y%m%d_%H%M%S")
self.logger.info(f"Initializing {self.client_name} ...")
# FL platform needs to provide filepath to configuration files
@@ -434,7 +438,7 @@
self.train_workflow = ConfigWorkflow(
config_file=config_train_files,
meta_file=None,
logging_file=None,
logging_file=logging_file,
workflow_type="train",
**self.train_kwargs,
)
@@ -459,7 +463,7 @@
self.eval_workflow = ConfigWorkflow(
config_file=config_eval_files,
meta_file=None,
logging_file=None,
logging_file=logging_file,
workflow_type=self.eval_workflow_name,
**self.eval_kwargs,
)
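The same switch is surfaced to federated-learning clients through the new `ExtraItems.LOGGING_FILE` key: whatever the FL system places there is forwarded to `ConfigWorkflow(logging_file=...)`. A sketch under the assumption that a bundle exists at `./bundle_root` (the path and client name are illustrative):

```python
from monai.fl.client import MonaiAlgo
from monai.fl.utils.constants import ExtraItems

algo = MonaiAlgo(bundle_root="./bundle_root")  # hypothetical bundle location
# Setting LOGGING_FILE to False disables the bundle's logging re-configuration;
# omitting it (None) keeps the previous default behaviour.
algo.initialize(
    extra={
        ExtraItems.CLIENT_NAME: "site-1",
        ExtraItems.LOGGING_FILE: False,
    }
)
```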
1 change: 1 addition & 0 deletions monai/fl/utils/constants.py
@@ -30,6 +30,7 @@ class ExtraItems(StrEnum):
CLIENT_NAME = "fl_client_name"
APP_ROOT = "fl_app_root"
STATS_SENDER = "fl_stats_sender"
LOGGING_FILE = "logging_file"


class FlPhase(StrEnum):
2 changes: 1 addition & 1 deletion monai/losses/ds_loss.py
@@ -33,7 +33,7 @@ def __init__(self, loss: _Loss, weight_mode: str = "exp", weights: list[float] |
weight_mode: {``"same"``, ``"exp"``, ``"two"``}
Specifies the weights calculation for each image level. Defaults to ``"exp"``.
- ``"same"``: all weights are equal to 1.
- ``"exp"``: exponentially decreasing weights by a power of 2: 0, 0.5, 0.25, 0.125, etc .
- ``"exp"``: exponentially decreasing weights by a power of 2: 1, 0.5, 0.25, 0.125, etc .
- ``"two"``: equal smaller weights for lower levels: 1, 0.5, 0.5, 0.5, 0.5, etc
weights: a list of weights to apply to each deeply supervised sub-loss, if provided, this will be used
regardless of the weight_mode
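The docstring fix reflects what `DeepSupervisionLoss` actually computes for ``"exp"``: the first (full-resolution) output is weighted 1, then 0.5, 0.25, and so on. A small sketch of the weighting (the wrapped Dice loss and tensor shapes are arbitrary choices for illustration):

```python
import torch
from monai.losses import DeepSupervisionLoss, DiceLoss

loss_fn = DeepSupervisionLoss(DiceLoss(sigmoid=True), weight_mode="exp")
# with three deep-supervision outputs the "exp" weights are [1.0, 0.5, 0.25]
preds = [torch.rand(2, 1, 16, 16) for _ in range(3)]
target = torch.randint(0, 2, (2, 1, 16, 16)).float()
print(loss_fn(preds, target))
```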
2 changes: 2 additions & 0 deletions monai/networks/blocks/__init__.py
@@ -17,6 +17,7 @@
from .backbone_fpn_utils import BackboneWithFPN
from .convolutions import Convolution, ResidualUnit
from .crf import CRF
from .crossattention import CrossAttentionBlock
from .denseblock import ConvDenseBlock, DenseBlock
from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock
from .downsample import MaxAvgPool
@@ -31,6 +32,7 @@
from .segresnet_block import ResBlock
from .selfattention import SABlock
from .spade_norm import SPADE
from .spatialattention import SpatialAttentionBlock
from .squeeze_and_excitation import (
ChannelSELayer,
ResidualSELayer,
166 changes: 166 additions & 0 deletions monai/networks/blocks/crossattention.py
@@ -0,0 +1,166 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")


class CrossAttentionBlock(nn.Module):
"""
A cross-attention block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
One can setup relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
"""

def __init__(
self,
hidden_size: int,
num_heads: int,
dropout_rate: float = 0.0,
hidden_input_size: int | None = None,
context_input_size: int | None = None,
dim_head: int | None = None,
qkv_bias: bool = False,
save_attn: bool = False,
causal: bool = False,
sequence_length: int | None = None,
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
) -> None:
"""
Args:
hidden_size (int): dimension of hidden layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.
context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size.
dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
causal: whether to use causal attention.
sequence_length: if causal is True, it is necessary to specify the sequence length.
rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
attention_dtype: cast attention operations to this dtype.
"""

super().__init__()

if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")

if dim_head:
inner_size = num_heads * dim_head
self.head_dim = dim_head
else:
if hidden_size % num_heads != 0:
raise ValueError("hidden size should be divisible by num_heads.")
inner_size = hidden_size
self.head_dim = hidden_size // num_heads

if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.context_input_size = context_input_size if context_input_size else hidden_size
self.out_proj = nn.Linear(inner_size, self.hidden_input_size)
# key, query, value projections
self.to_q = nn.Linear(self.hidden_input_size, inner_size, bias=qkv_bias)
self.to_k = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)

self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)

self.scale = self.head_dim**-0.5
self.save_attn = save_attn
self.attention_dtype = attention_dtype

self.causal = causal
self.sequence_length = sequence_length

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer(
"causal_mask",
torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
)
self.causal_mask: torch.Tensor

self.att_mat = torch.Tensor()
self.rel_positional_embedding = (
get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads)
if rel_pos_embedding is not None
else None
)
self.input_size = input_size

def forward(self, x: torch.Tensor, context: torch.Tensor | None = None):
"""
Args:
x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
context (torch.Tensor, optional): context tensor. B x (s_dim_1 * ... * s_dim_n) x C
Return:
torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
"""
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size)

q = self.to_q(x)
kv = context if context is not None else x
_, kv_t, _ = kv.size()
k = self.to_k(kv)
v = self.to_v(kv)

if self.attention_dtype is not None:
q = q.to(self.attention_dtype)
k = k.to(self.attention_dtype)

q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs)
k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat

if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))

att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
return x
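
A short usage sketch for the new block, covering both self-attention (no context) and cross-attention against a separate context sequence; the shapes are illustrative:

```python
import torch
from monai.networks.blocks import CrossAttentionBlock

block = CrossAttentionBlock(hidden_size=64, num_heads=4)
x = torch.randn(2, 100, 64)   # B x (flattened spatial dims) x hidden_size
ctx = torch.randn(2, 77, 64)  # context tokens, e.g. text embeddings

out_self = block(x)                 # context defaults to x (self-attention)
out_cross = block(x, context=ctx)   # queries from x attend to the context
print(out_self.shape, out_cross.shape)  # both torch.Size([2, 100, 64])
```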
