From 1a57b551474cd1740065862113331a9dddce84ca Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 10 May 2024 13:16:32 +0100 Subject: [PATCH] 7227 refactor transformer and diffusion model unet (#7715) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part of #7227 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu Signed-off-by: kaibo Signed-off-by: heyufan1995 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: binliu Signed-off-by: dependabot[bot] Signed-off-by: axel.vlaminck Signed-off-by: monai-bot Signed-off-by: Ibrahim Hadzic Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> Signed-off-by: Timothy Baker Signed-off-by: Mathijs de Boer Signed-off-by: Fabian Klopfer Signed-off-by: Lucas Robinet Signed-off-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Signed-off-by: chaoliu Signed-off-by: cxlcl Signed-off-by: chaoliu Signed-off-by: Suraj Pai Signed-off-by: Juan Pablo de la Cruz Gutiérrez Signed-off-by: elitap Signed-off-by: Felix Schnabel Signed-off-by: YanxuanLiu Signed-off-by: ytl0623 Signed-off-by: Dženan Zukić Signed-off-by: Ishan Dutta Signed-off-by: John Zielke Signed-off-by: Mingxin Zheng Signed-off-by: Vladimir Chernyi <57420464+scalyvladimir@users.noreply.github.com> Signed-off-by: Yiheng Wang Signed-off-by: Szabolcs Botond Lorincz Molnar Signed-off-by: Lucas Robinet Signed-off-by: Mingxin Signed-off-by: Han Wang Signed-off-by: Konstantin Sukharev Signed-off-by: Ben Murray Signed-off-by: Matthew Vine <32849887+MattTheCuber@users.noreply.github.com> Signed-off-by: Mark Graham Signed-off-by: Peter Kaplinsky Signed-off-by: Simon Jensen <61684806+simojens@users.noreply.github.com> Signed-off-by: NabJa Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Kaibo Tang Co-authored-by: Yufan He <59374597+heyufan1995@users.noreply.github.com> Co-authored-by: binliunls <107988372+binliunls@users.noreply.github.com> Co-authored-by: Ben Murray Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: axel.vlaminck Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Co-authored-by: monai-bot <64792179+monai-bot@users.noreply.github.com> Co-authored-by: Ibrahim Hadzic Co-authored-by: Dr. Behrooz Hashemian <3968947+drbeh@users.noreply.github.com> Co-authored-by: Timothy J. Baker <62781117+tim-the-baker@users.noreply.github.com> Co-authored-by: Mathijs de Boer <8137653+MathijsdeBoer@users.noreply.github.com> Co-authored-by: Mathijs de Boer Co-authored-by: Fabian Klopfer Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Co-authored-by: Lucas Robinet Co-authored-by: cxlcl Co-authored-by: Suraj Pai Co-authored-by: Juampa <1523654+juampatronics@users.noreply.github.com> Co-authored-by: elitap Co-authored-by: Felix Schnabel Co-authored-by: YanxuanLiu <104543031+YanxuanLiu@users.noreply.github.com> Co-authored-by: ytl0623 Co-authored-by: Dženan Zukić Co-authored-by: Ishan Dutta Co-authored-by: johnzielke Co-authored-by: Vladimir Chernyi <57420464+scalyvladimir@users.noreply.github.com> Co-authored-by: Lőrincz-Molnár Szabolcs-Botond Co-authored-by: Nic Ma Co-authored-by: Lucas Robinet Co-authored-by: Han Wang Co-authored-by: Konstantin Sukharev <50718389+k-sukharev@users.noreply.github.com> Co-authored-by: Matthew Vine <32849887+MattTheCuber@users.noreply.github.com> Co-authored-by: Pkaps25 <43655728+Pkaps25@users.noreply.github.com> Co-authored-by: Peter Kaplinsky Co-authored-by: Simon Jensen <61684806+simojens@users.noreply.github.com> Co-authored-by: NabJa <32510324+NabJa@users.noreply.github.com> --- .github/workflows/pythonapp-min.yml | 2 + .github/workflows/pythonapp.yml | 6 +- .pre-commit-config.yaml | 2 +- Dockerfile | 9 +- docs/source/networks.rst | 1 + monai/apps/utils.py | 7 +- monai/bundle/workflows.py | 10 +- monai/fl/client/monai_algo.py | 14 +- monai/fl/utils/constants.py | 1 + monai/losses/ds_loss.py | 2 +- monai/networks/blocks/__init__.py | 2 + monai/networks/blocks/crossattention.py | 166 +++++ monai/networks/blocks/selfattention.py | 53 +- monai/networks/blocks/spade_norm.py | 2 +- monai/networks/blocks/spatialattention.py | 82 +++ monai/networks/blocks/transformerblock.py | 28 +- monai/networks/nets/attentionunet.py | 12 +- monai/networks/nets/autoencoderkl.py | 66 +- monai/networks/nets/controlnet.py | 9 - monai/networks/nets/diffusion_model_unet.py | 578 ++++++------------ monai/networks/nets/resnet.py | 1 - monai/networks/nets/spade_autoencoderkl.py | 8 +- .../nets/spade_diffusion_model_unet.py | 123 ++-- monai/networks/nets/transformer.py | 267 ++------ monai/networks/utils.py | 1 + monai/utils/misc.py | 2 +- requirements-dev.txt | 4 +- tests/test_attentionunet.py | 20 + tests/test_autoencoderkl.py | 37 +- tests/test_bundle_ckpt_export.py | 6 +- tests/test_bundle_get_data.py | 15 +- tests/test_bundle_trt_export.py | 12 +- tests/test_bundle_workflow.py | 6 +- tests/test_clip_intensity_percentilesd.py | 2 +- tests/test_component_store.py | 8 +- tests/test_compute_ho_ver_maps.py | 4 +- tests/test_compute_ho_ver_maps_d.py | 4 +- tests/test_compute_regression_metrics.py | 10 +- tests/test_concat_itemsd.py | 8 +- tests/test_config_parser.py | 2 +- tests/test_controlnet.py | 5 + tests/test_controlnet_inferers.py | 19 + tests/test_crossattention.py | 131 ++++ tests/test_cucim_dict_transform.py | 16 +- tests/test_cucim_transform.py | 16 +- tests/test_detect_envelope.py | 2 +- tests/test_diffusion_inferer.py | 10 + tests/test_diffusion_model_unet.py | 50 ++ tests/test_ensure_typed.py | 32 +- tests/test_flipd.py | 2 +- tests/test_freeze_layers.py | 8 +- tests/test_generalized_dice_loss.py | 4 +- tests/test_get_package_version.py | 6 +- tests/test_grid_patch.py | 6 +- tests/test_handler_stats.py | 16 +- tests/test_integration_bundle_run.py | 6 +- tests/test_inverse_collation.py | 2 +- tests/test_invertd.py | 2 +- tests/test_latent_diffusion_inferer.py | 12 + tests/test_load_imaged.py | 2 +- tests/test_load_spacing_orientation.py | 4 +- tests/test_look_up_option.py | 2 +- tests/test_matshow3d.py | 2 +- tests/test_median_filter.py | 2 +- tests/test_mednistdataset.py | 2 +- tests/test_meta_affine.py | 4 +- tests/test_meta_tensor.py | 4 +- tests/test_mmar_download.py | 2 +- tests/test_persistentdataset.py | 2 +- tests/test_rand_affined.py | 2 +- tests/test_rand_bias_field.py | 2 +- tests/test_rand_weighted_cropd.py | 2 +- tests/test_recon_net_utils.py | 2 +- tests/test_reg_loss_integration.py | 2 +- tests/test_resnet.py | 10 +- tests/test_selfattention.py | 55 ++ tests/test_sobel_gradient.py | 4 +- tests/test_sobel_gradientd.py | 4 +- tests/test_spade_diffusion_model_unet.py | 16 + tests/test_spatialattention.py | 55 ++ tests/test_threadcontainer.py | 2 +- tests/test_to_cupy.py | 16 +- tests/test_to_numpy.py | 12 +- tests/test_torchvision_fc_model.py | 4 +- tests/test_traceable_transform.py | 4 +- tests/test_transformer.py | 36 ++ tests/test_transformerblock.py | 29 +- tests/test_vqvaetransformer_inferer.py | 11 + tests/test_warp.py | 2 +- tests/testing_data/data_config.json | 15 + 90 files changed, 1327 insertions(+), 921 deletions(-) create mode 100644 monai/networks/blocks/crossattention.py create mode 100644 monai/networks/blocks/spatialattention.py create mode 100644 tests/test_crossattention.py create mode 100644 tests/test_spatialattention.py diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml index bbe7579774..dffae10558 100644 --- a/.github/workflows/pythonapp-min.yml +++ b/.github/workflows/pythonapp-min.yml @@ -9,6 +9,8 @@ on: - main - releasing/* pull_request: + head_ref-ignore: + - dev concurrency: # automatically cancel the previously triggered workflows when there's a newer version diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index b7f2cfb9db..b8b73907d4 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -9,6 +9,8 @@ on: - main - releasing/* pull_request: + head_ref-ignore: + - dev concurrency: # automatically cancel the previously triggered workflows when there's a newer version @@ -68,10 +70,10 @@ jobs: maximum-size: 16GB disk-root: "D:" - uses: actions/checkout@v4 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v5 with: - python-version: '3.8' + python-version: '3.9' - name: Prepare pip wheel run: | which python diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b71a2bac43..b9debaf08f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -69,7 +69,7 @@ repos: )$ - repo: https://github.com/hadialqattan/pycln - rev: v2.1.3 + rev: v2.4.0 hooks: - id: pycln args: [--config=pyproject.toml] diff --git a/Dockerfile b/Dockerfile index fc97227351..10931222dd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,11 +18,10 @@ LABEL maintainer="monai.contact@gmail.com" # TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431) RUN if [[ $(uname -m) =~ "aarch64" ]]; then \ - cd /opt && \ - git clone --branch v0.12.1 --recursive https://github.com/zarr-developers/numcodecs && \ - pip wheel numcodecs && \ - rm -r /opt/*.whl && \ - rm -rf /opt/numcodecs; \ + export CFLAGS="-O3" && \ + export DISABLE_NUMCODECS_SSE2=true && \ + export DISABLE_NUMCODECS_AVX2=true && \ + pip install numcodecs; \ fi WORKDIR /opt/monai diff --git a/docs/source/networks.rst b/docs/source/networks.rst index c51f5c88b1..8321fed1a4 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -426,6 +426,7 @@ Layers .. autoclass:: monai.networks.layers.vector_quantizer.VectorQuantizer :members: +======= `ConjugateGradient` ~~~~~~~~~~~~~~~~~~~ .. autoclass:: ConjugateGradient diff --git a/monai/apps/utils.py b/monai/apps/utils.py index db541923b5..0c998146a3 100644 --- a/monai/apps/utils.py +++ b/monai/apps/utils.py @@ -135,7 +135,12 @@ def check_hash(filepath: PathLike, val: str | None = None, hash_type: str = "md5 logger.info(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.") return True actual_hash_func = look_up_option(hash_type.lower(), SUPPORTED_HASH_TYPES) - actual_hash = actual_hash_func() + + if sys.version_info >= (3, 9): + actual_hash = actual_hash_func(usedforsecurity=False) # allows checks on FIPS enabled machines + else: + actual_hash = actual_hash_func() + try: with open(filepath, "rb") as f: for chunk in iter(lambda: f.read(1024 * 1024), b""): diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 471088994b..b42852cb0f 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -239,6 +239,7 @@ class ConfigWorkflow(BundleWorkflow): logging_file: config file for `logging` module in the program. for more details: https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. If None, default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo. + If False, the logging logic for the bundle will not be modified. init_id: ID name of the expected config expression to initialize before running, default to "initialize". allow a config to have no `initialize` logic and the ID. run_id: ID name of the expected config expression to run, default to "run". @@ -278,7 +279,7 @@ def __init__( self, config_file: str | Sequence[str], meta_file: str | Sequence[str] | None = None, - logging_file: str | None = None, + logging_file: str | bool | None = None, init_id: str = "initialize", run_id: str = "run", final_id: str = "finalize", @@ -307,7 +308,10 @@ def __init__( super().__init__(workflow_type=workflow_type, meta_file=meta_file, properties_path=properties_path) self.config_root_path = config_root_path logging_file = str(self.config_root_path / "logging.conf") if logging_file is None else logging_file - if logging_file is not None: + + if logging_file is False: + logger.warn(f"Logging file is set to {logging_file}, skipping logging.") + else: if not os.path.isfile(logging_file): if logging_file == str(self.config_root_path / "logging.conf"): logger.warn(f"Default logging file in {logging_file} does not exist, skipping logging.") @@ -315,7 +319,7 @@ def __init__( raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.") else: logger.info(f"Setting logging properties based on config: {logging_file}.") - fileConfig(logging_file, disable_existing_loggers=False) + fileConfig(str(logging_file), disable_existing_loggers=False) self.parser = ConfigParser() self.parser.read_config(f=config_file) diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py index 9acf131bd9..a3ac58c221 100644 --- a/monai/fl/client/monai_algo.py +++ b/monai/fl/client/monai_algo.py @@ -134,12 +134,14 @@ def initialize(self, extra=None): Args: extra: Dict with additional information that should be provided by FL system, - i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`. + i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`. + You can diable the logging logic in the monai bundle by setting {ExtraItems.LOGGING_FILE} to False. """ if extra is None: extra = {} self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname") + logging_file = extra.get(ExtraItems.LOGGING_FILE, None) self.logger.info(f"Initializing {self.client_name} ...") # FL platform needs to provide filepath to configuration files @@ -149,7 +151,7 @@ def initialize(self, extra=None): if self.workflow is None: config_train_files = self._add_config_files(self.config_train_filename) self.workflow = ConfigWorkflow( - config_file=config_train_files, meta_file=None, logging_file=None, workflow_type="train" + config_file=config_train_files, meta_file=None, logging_file=logging_file, workflow_type="train" ) self.workflow.initialize() self.workflow.bundle_root = self.bundle_root @@ -412,13 +414,15 @@ def initialize(self, extra=None): Args: extra: Dict with additional information that should be provided by FL system, - i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`. + i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`. + You can diable the logging logic in the monai bundle by setting {ExtraItems.LOGGING_FILE} to False. """ self._set_cuda_device() if extra is None: extra = {} self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname") + logging_file = extra.get(ExtraItems.LOGGING_FILE, None) timestamp = time.strftime("%Y%m%d_%H%M%S") self.logger.info(f"Initializing {self.client_name} ...") # FL platform needs to provide filepath to configuration files @@ -434,7 +438,7 @@ def initialize(self, extra=None): self.train_workflow = ConfigWorkflow( config_file=config_train_files, meta_file=None, - logging_file=None, + logging_file=logging_file, workflow_type="train", **self.train_kwargs, ) @@ -459,7 +463,7 @@ def initialize(self, extra=None): self.eval_workflow = ConfigWorkflow( config_file=config_eval_files, meta_file=None, - logging_file=None, + logging_file=logging_file, workflow_type=self.eval_workflow_name, **self.eval_kwargs, ) diff --git a/monai/fl/utils/constants.py b/monai/fl/utils/constants.py index eda1a6b4f9..18beceeaee 100644 --- a/monai/fl/utils/constants.py +++ b/monai/fl/utils/constants.py @@ -30,6 +30,7 @@ class ExtraItems(StrEnum): CLIENT_NAME = "fl_client_name" APP_ROOT = "fl_app_root" STATS_SENDER = "fl_stats_sender" + LOGGING_FILE = "logging_file" class FlPhase(StrEnum): diff --git a/monai/losses/ds_loss.py b/monai/losses/ds_loss.py index 57fcff6b87..aacc16874d 100644 --- a/monai/losses/ds_loss.py +++ b/monai/losses/ds_loss.py @@ -33,7 +33,7 @@ def __init__(self, loss: _Loss, weight_mode: str = "exp", weights: list[float] | weight_mode: {``"same"``, ``"exp"``, ``"two"``} Specifies the weights calculation for each image level. Defaults to ``"exp"``. - ``"same"``: all weights are equal to 1. - - ``"exp"``: exponentially decreasing weights by a power of 2: 0, 0.5, 0.25, 0.125, etc . + - ``"exp"``: exponentially decreasing weights by a power of 2: 1, 0.5, 0.25, 0.125, etc . - ``"two"``: equal smaller weights for lower levels: 1, 0.5, 0.5, 0.5, 0.5, etc weights: a list of weights to apply to each deeply supervised sub-loss, if provided, this will be used regardless of the weight_mode diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index afb6664bd9..47abc4a1c4 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -17,6 +17,7 @@ from .backbone_fpn_utils import BackboneWithFPN from .convolutions import Convolution, ResidualUnit from .crf import CRF +from .crossattention import CrossAttentionBlock from .denseblock import ConvDenseBlock, DenseBlock from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock from .downsample import MaxAvgPool @@ -31,6 +32,7 @@ from .segresnet_block import ResBlock from .selfattention import SABlock from .spade_norm import SPADE +from .spatialattention import SpatialAttentionBlock from .squeeze_and_excitation import ( ChannelSELayer, ResidualSELayer, diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py new file mode 100644 index 0000000000..dc1d5d388e --- /dev/null +++ b/monai/networks/blocks/crossattention.py @@ -0,0 +1,166 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from monai.networks.layers.utils import get_rel_pos_embedding_layer +from monai.utils import optional_import + +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") + + +class CrossAttentionBlock(nn.Module): + """ + A cross-attention block, based on: "Dosovitskiy et al., + An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " + One can setup relative positional embedding as described in + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + dropout_rate: float = 0.0, + hidden_input_size: int | None = None, + context_input_size: int | None = None, + dim_head: int | None = None, + qkv_bias: bool = False, + save_attn: bool = False, + causal: bool = False, + sequence_length: int | None = None, + rel_pos_embedding: Optional[str] = None, + input_size: Optional[Tuple] = None, + attention_dtype: Optional[torch.dtype] = None, + ) -> None: + """ + Args: + hidden_size (int): dimension of hidden layer. + num_heads (int): number of attention heads. + dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. + hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size. + context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size. + dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. + qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. + save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. + causal: whether to use causal attention. + sequence_length: if causal is True, it is necessary to specify the sequence length. + rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. + For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative + positional parameter size. + attention_dtype: cast attention operations to this dtype. + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise ValueError("dropout_rate should be between 0 and 1.") + + if dim_head: + inner_size = num_heads * dim_head + self.head_dim = dim_head + else: + if hidden_size % num_heads != 0: + raise ValueError("hidden size should be divisible by num_heads.") + inner_size = hidden_size + self.head_dim = hidden_size // num_heads + + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") + + self.num_heads = num_heads + self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size + self.context_input_size = context_input_size if context_input_size else hidden_size + self.out_proj = nn.Linear(inner_size, self.hidden_input_size) + # key, query, value projections + self.to_q = nn.Linear(self.hidden_input_size, inner_size, bias=qkv_bias) + self.to_k = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias) + self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias) + self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) + + self.out_rearrange = Rearrange("b h l d -> b l (h d)") + self.drop_output = nn.Dropout(dropout_rate) + self.drop_weights = nn.Dropout(dropout_rate) + + self.scale = self.head_dim**-0.5 + self.save_attn = save_attn + self.attention_dtype = attention_dtype + + self.causal = causal + self.sequence_length = sequence_length + + if causal and sequence_length is not None: + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), + ) + self.causal_mask: torch.Tensor + + self.att_mat = torch.Tensor() + self.rel_positional_embedding = ( + get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads) + if rel_pos_embedding is not None + else None + ) + self.input_size = input_size + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None): + """ + Args: + x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C + context (torch.Tensor, optional): context tensor. B x (s_dim_1 * ... * s_dim_n) x C + + Return: + torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C + """ + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) + + q = self.to_q(x) + kv = context if context is not None else x + _, kv_t, _ = kv.size() + k = self.to_k(kv) + v = self.to_v(kv) + + if self.attention_dtype is not None: + q = q.to(self.attention_dtype) + k = k.to(self.attention_dtype) + + q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) + k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) + v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + + # apply relative positional embedding if defined + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + + if self.causal: + att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) + + att_mat = att_mat.softmax(dim=-1) + + if self.save_attn: + # no gradients and new tensor; + # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html + self.att_mat = att_mat.detach() + + att_mat = self.drop_weights(att_mat) + x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.out_rearrange(x) + x = self.out_proj(x) + x = self.drop_output(x) + return x diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 3bef24b4e8..370ad38595 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -34,22 +34,32 @@ def __init__( hidden_size: int, num_heads: int, dropout_rate: float = 0.0, + hidden_input_size: int | None = None, + dim_head: int | None = None, qkv_bias: bool = False, save_attn: bool = False, + causal: bool = False, + sequence_length: int | None = None, rel_pos_embedding: Optional[str] = None, input_size: Optional[Tuple] = None, + attention_dtype: Optional[torch.dtype] = None, ) -> None: """ Args: hidden_size (int): dimension of hidden layer. num_heads (int): number of attention heads. dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. + hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size. + dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. + save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. + causal: whether to use causal attention (see https://arxiv.org/abs/1706.03762). + sequence_length: if causal is True, it is necessary to specify the sequence length. rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. - save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. + attention_dtype: cast attention operations to this dtype. """ @@ -58,22 +68,43 @@ def __init__( if not (0 <= dropout_rate <= 1): raise ValueError("dropout_rate should be between 0 and 1.") - if hidden_size % num_heads != 0: - raise ValueError("hidden size should be divisible by num_heads.") + if dim_head: + inner_dim = num_heads * dim_head + self.dim_head = dim_head + else: + if hidden_size % num_heads != 0: + raise ValueError("hidden size should be divisible by num_heads.") + inner_dim = hidden_size + self.dim_head = hidden_size // num_heads + + if causal and sequence_length is None: + raise ValueError("sequence_length is necessary for causal attention.") self.num_heads = num_heads - self.out_proj = nn.Linear(hidden_size, hidden_size) - self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) + self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size + self.out_proj = nn.Linear(inner_dim, self.hidden_input_size) + self.qkv = nn.Linear(self.hidden_input_size, inner_dim * 3, bias=qkv_bias) self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) self.out_rearrange = Rearrange("b h l d -> b l (h d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) - self.head_dim = hidden_size // num_heads - self.scale = self.head_dim**-0.5 + self.scale = self.dim_head**-0.5 self.save_attn = save_attn + self.attention_dtype = attention_dtype + self.causal = causal + self.sequence_length = sequence_length + + if causal and sequence_length is not None: + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "causal_mask", + torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), + ) + self.causal_mask: torch.Tensor + self.att_mat = torch.Tensor() self.rel_positional_embedding = ( - get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads) + get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.dim_head, self.num_heads) if rel_pos_embedding is not None else None ) @@ -89,11 +120,17 @@ def forward(self, x: torch.Tensor): """ output = self.input_rearrange(self.qkv(x)) q, k, v = output[0], output[1], output[2] + if self.attention_dtype is not None: + q = q.to(self.attention_dtype) + k = k.to(self.attention_dtype) att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + if self.causal: + att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf")) + att_mat = att_mat.softmax(dim=-1) if self.save_attn: diff --git a/monai/networks/blocks/spade_norm.py b/monai/networks/blocks/spade_norm.py index b1046f3154..8e082defe0 100644 --- a/monai/networks/blocks/spade_norm.py +++ b/monai/networks/blocks/spade_norm.py @@ -85,7 +85,7 @@ def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor: """ # Part 1. generate parameter-free normalized activations - normalized = self.param_free_norm(x) + normalized = self.param_free_norm(x.contiguous()) # Part 2. produce scaling and bias conditioned on semantic map segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest") diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py new file mode 100644 index 0000000000..020d8d23fd --- /dev/null +++ b/monai/networks/blocks/spatialattention.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn + +from monai.networks.blocks import SABlock +from monai.utils import optional_import + +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") + + +class SpatialAttentionBlock(nn.Module): + """Perform spatial self-attention on the input tensor. + + The input tensor is reshaped to B x (x_dim * y_dim [ * z_dim]) x C, where C is the number of channels, and then + self-attention is performed on the reshaped tensor. The output tensor is reshaped back to the original shape. + + Args: + spatial_dims: number of spatial dimensions, could be 1, 2, or 3. + num_channels: number of input channels. Must be divisible by num_head_channels. + num_head_channels: number of channels per head. + attention_dtype: cast attention operations to this dtype. + + """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + attention_dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + + self.spatial_dims = spatial_dims + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + # check num_head_channels is divisible by num_channels + if num_head_channels is not None and num_channels % num_head_channels != 0: + raise ValueError("num_channels must be divisible by num_head_channels") + num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.attn = SABlock( + hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype + ) + + def forward(self, x: torch.Tensor): + residual = x + + if self.spatial_dims == 1: + h = x.shape[2] + rearrange_input = Rearrange("b c h -> b h c") + rearrange_output = Rearrange("b h c -> b c h", h=h) + if self.spatial_dims == 2: + h, w = x.shape[2], x.shape[3] + rearrange_input = Rearrange("b c h w -> b (h w) c") + rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w) + if self.spatial_dims == 3: + h, w, d = x.shape[2], x.shape[3], x.shape[4] + rearrange_input = Rearrange("b c h w d -> b (h w d) c") + rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d) + + x = self.norm(x) + x = rearrange_input(x) # B x (x_dim * y_dim [ * z_dim]) x C + + x = self.attn(x) + x = rearrange_output(x) # B x x C x x_dim * y_dim * [z_dim] + x = x + residual + return x diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index ddf959dad2..2458902cba 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -11,10 +11,10 @@ from __future__ import annotations +import torch import torch.nn as nn -from monai.networks.blocks.mlp import MLPBlock -from monai.networks.blocks.selfattention import SABlock +from monai.networks.blocks import CrossAttentionBlock, MLPBlock, SABlock class TransformerBlock(nn.Module): @@ -31,6 +31,9 @@ def __init__( dropout_rate: float = 0.0, qkv_bias: bool = False, save_attn: bool = False, + causal: bool = False, + sequence_length: int | None = None, + with_cross_attention: bool = False, ) -> None: """ Args: @@ -53,10 +56,27 @@ def __init__( self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) self.norm1 = nn.LayerNorm(hidden_size) - self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn) + self.attn = SABlock( + hidden_size, + num_heads, + dropout_rate, + qkv_bias=qkv_bias, + save_attn=save_attn, + causal=causal, + sequence_length=sequence_length, + ) self.norm2 = nn.LayerNorm(hidden_size) + self.with_cross_attention = with_cross_attention - def forward(self, x): + if self.with_cross_attention: + self.norm_cross_attn = nn.LayerNorm(hidden_size) + self.cross_attn = CrossAttentionBlock( + hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: x = x + self.attn(self.norm1(x)) + if self.with_cross_attention: + x = x + self.cross_attn(self.norm_cross_attn(x), context=context) x = x + self.mlp(self.norm2(x)) return x diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py index 5689cf1071..fdf31d9701 100644 --- a/monai/networks/nets/attentionunet.py +++ b/monai/networks/nets/attentionunet.py @@ -29,7 +29,7 @@ def __init__( spatial_dims: int, in_channels: int, out_channels: int, - kernel_size: int = 3, + kernel_size: Sequence[int] | int = 3, strides: int = 1, dropout=0.0, ): @@ -219,7 +219,13 @@ def __init__( self.kernel_size = kernel_size self.dropout = dropout - head = ConvBlock(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=channels[0], dropout=dropout) + head = ConvBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels[0], + dropout=dropout, + kernel_size=self.kernel_size, + ) reduce_channels = Convolution( spatial_dims=spatial_dims, in_channels=channels[0], @@ -245,6 +251,7 @@ def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module: out_channels=channels[1], strides=strides[0], dropout=self.dropout, + kernel_size=self.kernel_size, ), subblock, ), @@ -271,6 +278,7 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) - out_channels=out_channels, strides=strides, dropout=self.dropout, + kernel_size=self.kernel_size, ), up_kernel_size=self.up_kernel_size, strides=strides, diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 372e704d53..17bb90d6f6 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -18,8 +18,7 @@ import torch.nn as nn import torch.nn.functional as F -from monai.networks.blocks import Convolution, Upsample -from monai.networks.blocks.selfattention import SABlock +from monai.networks.blocks import Convolution, SpatialAttentionBlock, Upsample from monai.utils import ensure_tuple_rep, optional_import Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -144,61 +143,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x + h -class AttentionBlock(nn.Module): - """Perform spatial self-attention on the input tensor. - - The input tensor is reshaped to B x (x_dim * y_dim [ * z_dim]) x C, where C is the number of channels. - - Args: - spatial_dims: number of spatial dimensions, could be 1, 2, or 3. - num_channels: number of input channels. Must be divisible by num_head_channels. - num_head_channels: number of channels per head. - """ - - def __init__( - self, - spatial_dims: int, - num_channels: int, - num_head_channels: int | None = None, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - ) -> None: - super().__init__() - - self.spatial_dims = spatial_dims - self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) - # check num_head_channels is divisible by num_channels - if num_head_channels is not None and num_channels % num_head_channels != 0: - raise ValueError("num_channels must be divisible by num_head_channels") - num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 - - self.attn = SABlock(hidden_size=num_channels, num_heads=num_heads, qkv_bias=True) - - def forward(self, x: torch.Tensor): - residual = x - - if self.spatial_dims == 1: - h = x.shape[2] - rearrange_input = Rearrange("b c h -> b h c") - rearrange_output = Rearrange("b h c -> b c h", h=h) - if self.spatial_dims == 2: - h, w = x.shape[2], x.shape[3] - rearrange_input = Rearrange("b c h w -> b (h w) c") - rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w) - if self.spatial_dims == 3: - h, w, d = x.shape[2], x.shape[3], x.shape[4] - rearrange_input = Rearrange("b c h w d -> b (h w d) c") - rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d) - - x = self.norm(x) - x = rearrange_input(x) # B x (x_dim * y_dim [ * z_dim]) x C - - x = self.attn(x) - x = rearrange_output(x) # B x x C x x_dim * y_dim * [z_dim] - x = x + residual - return x - - class Encoder(nn.Module): """ Convolutional cascade that downsamples the image into a spatial latent space. @@ -271,7 +215,7 @@ def __init__( input_channel = output_channel if attention_levels[i]: blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=input_channel, norm_num_groups=norm_num_groups, @@ -294,7 +238,7 @@ def __init__( ) blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=channels[-1], norm_num_groups=norm_num_groups, @@ -401,7 +345,7 @@ def __init__( ) ) blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, @@ -440,7 +384,7 @@ def __init__( if reversed_attention_levels[i]: blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=block_in_ch, norm_num_groups=norm_num_groups, diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index d98755f401..7450c87314 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -141,7 +141,6 @@ class ControlNet(nn.Module): num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. conditioning_embedding_in_channels: number of input channels for the conditioning embedding. conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. """ @@ -162,7 +161,6 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, conditioning_embedding_in_channels: int = 1, conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), ) -> None: @@ -209,11 +207,6 @@ def __init__( f"`num_channels`, but got num_res_blocks={num_res_blocks} and channels={channels}." ) - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." - ) - self.in_channels = in_channels self.block_out_channels = channels self.num_res_blocks = num_res_blocks @@ -289,7 +282,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -334,7 +326,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, ) controlnet_block = Convolution( diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 0441cc9cfe..38d7f816a9 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -35,17 +35,13 @@ from collections.abc import Sequence import torch -import torch.nn.functional as F from torch import nn -from monai.networks.blocks import Convolution, MLPBlock +from monai.networks.blocks import Convolution, CrossAttentionBlock, MLPBlock, SABlock, SpatialAttentionBlock, Upsample from monai.networks.layers.factories import Pool from monai.utils import ensure_tuple_rep, optional_import -# To install xformers, use pip install xformers==0.0.16rc401 - -xops, has_xformers = optional_import("xformers.ops") - +Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") __all__ = ["DiffusionModelUNet"] @@ -59,122 +55,9 @@ def zero_module(module: nn.Module) -> nn.Module: return module -class _CrossAttention(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - A cross attention layer. - - Args: - query_dim: number of channels in the query. - cross_attention_dim: number of channels in the context. - num_attention_heads: number of heads to use for multi-head attention. - num_head_channels: number of channels in each head. - dropout: dropout probability to use. - upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - query_dim: int, - cross_attention_dim: int | None = None, - num_attention_heads: int = 8, - num_head_channels: int = 64, - dropout: float = 0.0, - upcast_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.use_flash_attention = use_flash_attention - inner_dim = num_head_channels * num_attention_heads - cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - - self.scale = 1 / math.sqrt(num_head_channels) - self.num_heads = num_attention_heads - - self.upcast_attention = upcast_attention - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) - self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) - - def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: - """ - Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. - """ - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) - x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) - return x - - def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: - """Combine the output of the attention heads back into the hidden state dimension.""" - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) - return x - - def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - x: torch.Tensor = xops.memory_efficient_attention(query, key, value, attn_bias=None) - return x - - def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - dtype = query.dtype - if self.upcast_attention: - query = query.float() - key = key.float() - - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - attention_probs = attention_probs.to(dtype=dtype) - - x = torch.bmm(attention_probs, value) - return x - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - query = self.to_q(x) - context = context if context is not None else x - key = self.to_k(context) - value = self.to_v(context) - - # Multi-Head Attention - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - if self.use_flash_attention: - x = self._memory_efficient_attention_xformers(query, key, value) - else: - x = self._attention(query, key, value) - - x = self.reshape_batch_dim_to_heads(x) - x = x.to(query.dtype) - output: torch.Tensor = self.to_out(x) - return output - - -class _BasicTransformerBlock(nn.Module): +class DiffusionUNetTransformerBlock(nn.Module): """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - A basic Transformer block. + A Transformer block that allows for the input dimension to differ from the hidden dimension. Args: num_channels: number of channels in the input and output. @@ -183,7 +66,7 @@ class _BasicTransformerBlock(nn.Module): dropout: dropout probability to use. cross_attention_dim: size of the context vector for cross attention. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ def __init__( @@ -194,27 +77,26 @@ def __init__( dropout: float = 0.0, cross_attention_dim: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, ) -> None: super().__init__() - self.attn1 = _CrossAttention( - query_dim=num_channels, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) # is a self-attention + self.attn1 = SABlock( + hidden_size=num_attention_heads * num_head_channels, + hidden_input_size=num_channels, + num_heads=num_attention_heads, + dim_head=num_head_channels, + dropout_rate=dropout, + attention_dtype=torch.float if upcast_attention else None, + ) self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) - self.attn2 = _CrossAttention( - query_dim=num_channels, - cross_attention_dim=cross_attention_dim, - num_attention_heads=num_attention_heads, - num_head_channels=num_head_channels, - dropout=dropout, - upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, - ) # is a self-attention if context is None + self.attn2 = CrossAttentionBlock( + hidden_size=num_attention_heads * num_head_channels, + num_heads=num_attention_heads, + hidden_input_size=num_channels, + context_input_size=cross_attention_dim, + dim_head=num_head_channels, + dropout_rate=dropout, + attention_dtype=torch.float if upcast_attention else None, + ) self.norm1 = nn.LayerNorm(num_channels) self.norm2 = nn.LayerNorm(num_channels) self.norm3 = nn.LayerNorm(num_channels) @@ -231,7 +113,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch return x -class _SpatialTransformer(nn.Module): +class SpatialTransformer(nn.Module): """ NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make use of this block as support is not guaranteed. For more information see: @@ -251,7 +133,6 @@ class _SpatialTransformer(nn.Module): norm_eps: epsilon for the normalization. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -266,7 +147,6 @@ def __init__( norm_eps: float = 1e-6, cross_attention_dim: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -287,14 +167,13 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ - _BasicTransformerBlock( + DiffusionUNetTransformerBlock( num_channels=inner_dim, num_attention_heads=num_attention_heads, num_head_channels=num_head_channels, dropout=dropout, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, ) for _ in range(num_layers) ] @@ -343,126 +222,6 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch return x + residual -class _AttentionBlock(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to - compute attention. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - num_head_channels: number of channels in each attention head. - norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of - channels is divisible by this number. - norm_eps: epsilon value to use for the normalisation. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - spatial_dims: int, - num_channels: int, - num_head_channels: int | None = None, - norm_num_groups: int = 32, - norm_eps: float = 1e-6, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.use_flash_attention = use_flash_attention - self.spatial_dims = spatial_dims - self.num_channels = num_channels - - self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 - self.scale = 1 / math.sqrt(num_channels / self.num_heads) - - self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) - - self.to_q = nn.Linear(num_channels, num_channels) - self.to_k = nn.Linear(num_channels, num_channels) - self.to_v = nn.Linear(num_channels, num_channels) - - self.proj_attn = nn.Linear(num_channels, num_channels) - - def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) - x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) - return x - - def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) - return x - - def _memory_efficient_attention_xformers( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> torch.Tensor: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - x: torch.Tensor = xops.memory_efficient_attention(query, key, value, attn_bias=None) - return x - - def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - attention_scores = torch.baddbmm( - torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), - query, - key.transpose(-1, -2), - beta=0, - alpha=self.scale, - ) - attention_probs = attention_scores.softmax(dim=-1) - x = torch.bmm(attention_probs, value) - return x - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - - batch = channel = height = width = depth = -1 - if self.spatial_dims == 2: - batch, channel, height, width = x.shape - if self.spatial_dims == 3: - batch, channel, height, width, depth = x.shape - - # norm - x = self.norm(x.contiguous()) - - if self.spatial_dims == 2: - x = x.view(batch, channel, height * width).transpose(1, 2) - if self.spatial_dims == 3: - x = x.view(batch, channel, height * width * depth).transpose(1, 2) - - # proj to q, k, v - query = self.to_q(x) - key = self.to_k(x) - value = self.to_v(x) - - # Multi-Head Attention - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if self.use_flash_attention: - x = self._memory_efficient_attention_xformers(query, key, value) - else: - x = self._attention(query, key, value) - - x = self.reshape_batch_dim_to_heads(x) - x = x.to(query.dtype) - - if self.spatial_dims == 2: - x = x.transpose(-1, -2).reshape(batch, channel, height, width) - if self.spatial_dims == 3: - x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) - - return x + residual - - def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: """ Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic @@ -490,12 +249,8 @@ def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_peri return embedding -class _Downsample(nn.Module): +class DiffusionUnetDownsample(nn.Module): """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - Downsampling layer. Args: @@ -541,68 +296,19 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Ten return output -class _Upsample(nn.Module): +class WrappedUpsample(Upsample): """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - Upsampling layer with an optional convolution. - - Args: - spatial_dims: number of spatial dimensions. - num_channels: number of input channels. - use_conv: if True uses Convolution instead of Pool average to perform downsampling. - out_channels: number of output channels. - padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each - dimension. + Wraps MONAI upsample block to allow for calling with timestep embeddings. """ - def __init__( - self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 - ) -> None: - super().__init__() - self.num_channels = num_channels - self.out_channels = out_channels or num_channels - self.use_conv = use_conv - if use_conv: - self.conv = Convolution( - spatial_dims=spatial_dims, - in_channels=self.num_channels, - out_channels=self.out_channels, - strides=1, - kernel_size=3, - padding=padding, - conv_only=True, - ) - def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: del emb - if x.shape[1] != self.num_channels: - raise ValueError("Input channels should be equal to num_channels") - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # https://github.com/pytorch/pytorch/issues/86679 - dtype = x.dtype - if dtype == torch.bfloat16: - x = x.to(torch.float32) - - x = F.interpolate(x, scale_factor=2.0, mode="nearest") + upsampled: torch.Tensor = super().forward(x) + return upsampled - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - x = x.to(dtype) - if self.use_conv: - x = self.conv(x) - return x - - -class _ResnetBlock(nn.Module): +class DiffusionUNetResnetBlock(nn.Module): """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 Residual block with timestep conditioning. Args: @@ -649,9 +355,17 @@ def __init__( self.upsample = self.downsample = None if self.up: - self.upsample = _Upsample(spatial_dims, in_channels, use_conv=False) + self.upsample = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=in_channels, + out_channels=in_channels, + interp_mode="nearest", + scale_factor=2.0, + align_corners=None, + ) elif down: - self.downsample = _Downsample(spatial_dims, in_channels, use_conv=False) + self.downsample = DiffusionUnetDownsample(spatial_dims, in_channels, use_conv=False) self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) @@ -749,7 +463,7 @@ def __init__( for i in range(num_res_blocks): in_channels = in_channels if i == 0 else out_channels resnets.append( - _ResnetBlock( + DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, @@ -764,7 +478,7 @@ def __init__( if add_downsample: self.downsampler: nn.Module | None if resblock_updown: - self.downsampler = _ResnetBlock( + self.downsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -774,7 +488,7 @@ def __init__( down=True, ) else: - self.downsampler = _Downsample( + self.downsampler = DiffusionUnetDownsample( spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, @@ -817,7 +531,6 @@ class AttnDownBlock(nn.Module): resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -833,7 +546,6 @@ def __init__( resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, - use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -844,7 +556,7 @@ def __init__( for i in range(num_res_blocks): in_channels = in_channels if i == 0 else out_channels resnets.append( - _ResnetBlock( + DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, @@ -854,13 +566,12 @@ def __init__( ) ) attentions.append( - _AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=out_channels, num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, - use_flash_attention=use_flash_attention, ) ) @@ -870,7 +581,7 @@ def __init__( self.downsampler: nn.Module | None if add_downsample: if resblock_updown: - self.downsampler = _ResnetBlock( + self.downsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -880,7 +591,7 @@ def __init__( down=True, ) else: - self.downsampler = _Downsample( + self.downsampler = DiffusionUnetDownsample( spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, @@ -927,7 +638,6 @@ class CrossAttnDownBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers """ @@ -947,7 +657,6 @@ def __init__( transformer_num_layers: int = 1, cross_attention_dim: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, dropout_cattn: float = 0.0, ) -> None: super().__init__() @@ -959,7 +668,7 @@ def __init__( for i in range(num_res_blocks): in_channels = in_channels if i == 0 else out_channels resnets.append( - _ResnetBlock( + DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels, @@ -970,7 +679,7 @@ def __init__( ) attentions.append( - _SpatialTransformer( + SpatialTransformer( spatial_dims=spatial_dims, in_channels=out_channels, num_attention_heads=out_channels // num_head_channels, @@ -980,7 +689,6 @@ def __init__( norm_eps=norm_eps, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout=dropout_cattn, ) ) @@ -991,7 +699,7 @@ def __init__( self.downsampler: nn.Module | None if add_downsample: if resblock_updown: - self.downsampler = _ResnetBlock( + self.downsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -1001,7 +709,7 @@ def __init__( down=True, ) else: - self.downsampler = _Downsample( + self.downsampler = DiffusionUnetDownsample( spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, @@ -1039,7 +747,6 @@ class AttnMidBlock(nn.Module): norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -1050,11 +757,10 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, num_head_channels: int = 1, - use_flash_attention: bool = False, ) -> None: super().__init__() - self.resnet_1 = _ResnetBlock( + self.resnet_1 = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, @@ -1062,16 +768,15 @@ def __init__( norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) - self.attention = _AttentionBlock( + self.attention = SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=in_channels, num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, - use_flash_attention=use_flash_attention, ) - self.resnet_2 = _ResnetBlock( + self.resnet_2 = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, @@ -1105,7 +810,6 @@ class CrossAttnMidBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -1119,12 +823,11 @@ def __init__( transformer_num_layers: int = 1, cross_attention_dim: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, dropout_cattn: float = 0.0, ) -> None: super().__init__() - self.resnet_1 = _ResnetBlock( + self.resnet_1 = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, @@ -1132,7 +835,7 @@ def __init__( norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) - self.attention = _SpatialTransformer( + self.attention = SpatialTransformer( spatial_dims=spatial_dims, in_channels=in_channels, num_attention_heads=in_channels // num_head_channels, @@ -1142,10 +845,9 @@ def __init__( norm_eps=norm_eps, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout=dropout_cattn, ) - self.resnet_2 = _ResnetBlock( + self.resnet_2 = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=in_channels, @@ -1203,7 +905,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - _ResnetBlock( + DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, @@ -1218,7 +920,7 @@ def __init__( self.upsampler: nn.Module | None if add_upsample: if resblock_updown: - self.upsampler = _ResnetBlock( + self.upsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -1228,9 +930,26 @@ def __init__( up=True, ) else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, ) + else: self.upsampler = None @@ -1272,7 +991,6 @@ class AttnUpBlock(nn.Module): add_upsample: if True add downsample block. resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -1288,7 +1006,6 @@ def __init__( add_upsample: bool = True, resblock_updown: bool = False, num_head_channels: int = 1, - use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1301,7 +1018,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - _ResnetBlock( + DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, @@ -1311,13 +1028,12 @@ def __init__( ) ) attentions.append( - _AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=out_channels, num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, - use_flash_attention=use_flash_attention, ) ) @@ -1327,7 +1043,7 @@ def __init__( self.upsampler: nn.Module | None if add_upsample: if resblock_updown: - self.upsampler = _ResnetBlock( + self.upsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -1337,8 +1053,25 @@ def __init__( up=True, ) else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, ) else: self.upsampler = None @@ -1385,7 +1118,6 @@ class CrossAttnUpBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers """ @@ -1405,7 +1137,6 @@ def __init__( transformer_num_layers: int = 1, cross_attention_dim: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, dropout_cattn: float = 0.0, ) -> None: super().__init__() @@ -1419,7 +1150,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - _ResnetBlock( + DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, @@ -1429,7 +1160,7 @@ def __init__( ) ) attentions.append( - _SpatialTransformer( + SpatialTransformer( spatial_dims=spatial_dims, in_channels=out_channels, num_attention_heads=out_channels // num_head_channels, @@ -1439,7 +1170,6 @@ def __init__( num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout=dropout_cattn, ) ) @@ -1450,7 +1180,7 @@ def __init__( self.upsampler: nn.Module | None if add_upsample: if resblock_updown: - self.upsampler = _ResnetBlock( + self.upsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -1460,8 +1190,25 @@ def __init__( up=True, ) else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, ) else: self.upsampler = None @@ -1504,7 +1251,6 @@ def get_down_block( transformer_num_layers: int, cross_attention_dim: int | None, upcast_attention: bool = False, - use_flash_attention: bool = False, dropout_cattn: float = 0.0, ) -> nn.Module: if with_attn: @@ -1519,7 +1265,6 @@ def get_down_block( add_downsample=add_downsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnDownBlock( @@ -1536,7 +1281,6 @@ def get_down_block( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) else: @@ -1564,7 +1308,6 @@ def get_mid_block( transformer_num_layers: int, cross_attention_dim: int | None, upcast_attention: bool = False, - use_flash_attention: bool = False, dropout_cattn: float = 0.0, ) -> nn.Module: if with_conditioning: @@ -1578,7 +1321,6 @@ def get_mid_block( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) else: @@ -1589,7 +1331,6 @@ def get_mid_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, ) @@ -1610,7 +1351,6 @@ def get_up_block( transformer_num_layers: int, cross_attention_dim: int | None, upcast_attention: bool = False, - use_flash_attention: bool = False, dropout_cattn: float = 0.0, ) -> nn.Module: if with_attn: @@ -1626,7 +1366,6 @@ def get_up_block( add_upsample=add_upsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnUpBlock( @@ -1644,7 +1383,6 @@ def get_up_block( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) else: @@ -1685,7 +1423,6 @@ class DiffusionModelUNet(nn.Module): num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers """ @@ -1706,7 +1443,6 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, dropout_cattn: float = 0.0, ) -> None: super().__init__() @@ -1747,14 +1483,6 @@ def __init__( "`num_channels`." ) - if use_flash_attention and not has_xformers: - raise ValueError("use_flash_attention is True but xformers is not installed.") - - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." - ) - self.in_channels = in_channels self.block_out_channels = channels self.out_channels = out_channels @@ -1809,7 +1537,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) @@ -1827,7 +1554,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) @@ -1862,7 +1588,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, dropout_cattn=dropout_cattn, ) @@ -1944,7 +1669,7 @@ def forward( down_block_res_samples = new_down_block_res_samples # 5. mid - h = self.middle_block(hidden_states=h, temb=emb, context=context) + h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context) # Additional residual conections for Controlnets if mid_block_additional_residual is not None: @@ -1961,6 +1686,63 @@ def forward( return output + def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: + """ + Load a state dict from a DiffusionModelUNet trained with + [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels). + + Args: + old_state_dict: state dict from the old DecoderOnlyTransformer model. + """ + + new_state_dict = self.state_dict() + # if all keys match, just load the state dict + if all(k in new_state_dict for k in old_state_dict): + print("All keys match, loading state dict.") + self.load_state_dict(old_state_dict) + return + + if verbose: + # print all new_state_dict keys that are not in old_state_dict + for k in new_state_dict: + if k not in old_state_dict: + print(f"key {k} not found in old state dict") + # and vice versa + print("----------------------------------------------") + for k in old_state_dict: + if k not in new_state_dict: + print(f"key {k} not found in new state dict") + + # copy over all matching keys + for k in new_state_dict: + if k in old_state_dict: + new_state_dict[k] = old_state_dict[k] + + # fix the attention blocks + attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + for block in attention_blocks: + new_state_dict[f"{block}.attn1.qkv.weight"] = torch.concat( + [ + old_state_dict[f"{block}.attn1.to_q.weight"], + old_state_dict[f"{block}.attn1.to_k.weight"], + old_state_dict[f"{block}.attn1.to_v.weight"], + ], + dim=0, + ) + + # projection + new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] + new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] + + new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] + new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] + # fix the upsample conv blocks which were renamed postconv + for k in new_state_dict: + if "postconv" in k: + old_name = k.replace("postconv", "conv") + new_state_dict[k] = old_state_dict[old_name] + self.load_state_dict(new_state_dict) + class DiffusionModelEncoder(nn.Module): """ diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py index 99975271da..74d15bc6bf 100644 --- a/monai/networks/nets/resnet.py +++ b/monai/networks/nets/resnet.py @@ -46,7 +46,6 @@ "resnet200", ] - resnet_params = { # model_name: (block, layers, shortcut_type, bias_downsample, datasets23) "resnet10": ("basic", [1, 1, 1, 1], "B", False, True), diff --git a/monai/networks/nets/spade_autoencoderkl.py b/monai/networks/nets/spade_autoencoderkl.py index 0949e307b9..294b121c94 100644 --- a/monai/networks/nets/spade_autoencoderkl.py +++ b/monai/networks/nets/spade_autoencoderkl.py @@ -17,9 +17,9 @@ import torch.nn as nn import torch.nn.functional as F -from monai.networks.blocks import Convolution, Upsample +from monai.networks.blocks import Convolution, SpatialAttentionBlock, Upsample from monai.networks.blocks.spade_norm import SPADE -from monai.networks.nets.autoencoderkl import AttentionBlock, Encoder +from monai.networks.nets.autoencoderkl import Encoder from monai.utils import ensure_tuple_rep __all__ = ["SPADEAutoencoderKL"] @@ -195,7 +195,7 @@ def __init__( ) ) blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, @@ -238,7 +238,7 @@ def __init__( if reversed_attention_levels[i]: blocks.append( - AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=block_in_ch, norm_num_groups=norm_num_groups, diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py index bffc9c5465..e019d21c11 100644 --- a/monai/networks/nets/spade_diffusion_model_unet.py +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -36,24 +36,19 @@ import torch from torch import nn -from monai.networks.blocks import Convolution +from monai.networks.blocks import Convolution, SpatialAttentionBlock from monai.networks.blocks.spade_norm import SPADE from monai.networks.nets.diffusion_model_unet import ( - _AttentionBlock, - _Downsample, - _ResnetBlock, - _SpatialTransformer, - _Upsample, + DiffusionUnetDownsample, + DiffusionUNetResnetBlock, + SpatialTransformer, + WrappedUpsample, get_down_block, get_mid_block, get_timestep_embedding, zero_module, ) -from monai.utils import ensure_tuple_rep, optional_import - -# To install xformers, use pip install xformers==0.0.16rc401 -xops, has_xformers = optional_import("xformers.ops") - +from monai.utils import ensure_tuple_rep __all__ = ["SPADEDiffusionModelUNet"] @@ -120,9 +115,17 @@ def __init__( self.upsample = self.downsample = None if self.up: - self.upsample = _Upsample(spatial_dims, in_channels, use_conv=False) + self.upsample = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=in_channels, + out_channels=in_channels, + interp_mode="nearest", + scale_factor=2.0, + align_corners=None, + ) elif down: - self.downsample = _Downsample(spatial_dims, in_channels, use_conv=False) + self.downsample = DiffusionUnetDownsample(spatial_dims, in_channels, use_conv=False) self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) @@ -252,7 +255,7 @@ def __init__( self.upsampler: nn.Module | None if add_upsample: if resblock_updown: - self.upsampler = _ResnetBlock( + self.upsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -262,8 +265,24 @@ def __init__( up=True, ) else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, ) else: self.upsampler = None @@ -308,7 +327,6 @@ class SPADEAttnUpBlock(nn.Module): add_upsample: if True add downsample block. resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. spade_intermediate_channels: number of intermediate channels for SPADE block layer """ @@ -326,7 +344,6 @@ def __init__( add_upsample: bool = True, resblock_updown: bool = False, num_head_channels: int = 1, - use_flash_attention: bool = False, spade_intermediate_channels: int = 128, ) -> None: super().__init__() @@ -351,13 +368,12 @@ def __init__( ) ) attentions.append( - _AttentionBlock( + SpatialAttentionBlock( spatial_dims=spatial_dims, num_channels=out_channels, num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, - use_flash_attention=use_flash_attention, ) ) @@ -367,7 +383,7 @@ def __init__( self.upsampler: nn.Module | None if add_upsample: if resblock_updown: - self.upsampler = _ResnetBlock( + self.upsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -377,8 +393,24 @@ def __init__( up=True, ) else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, ) else: self.upsampler = None @@ -427,7 +459,6 @@ class SPADECrossAttnUpBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. spade_intermediate_channels: number of intermediate channels for SPADE block layer. """ @@ -448,7 +479,6 @@ def __init__( transformer_num_layers: int = 1, cross_attention_dim: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, spade_intermediate_channels: int = 128, ) -> None: super().__init__() @@ -473,7 +503,7 @@ def __init__( ) ) attentions.append( - _SpatialTransformer( + SpatialTransformer( spatial_dims=spatial_dims, in_channels=out_channels, num_attention_heads=out_channels // num_head_channels, @@ -483,7 +513,6 @@ def __init__( num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, ) ) @@ -493,7 +522,7 @@ def __init__( self.upsampler: nn.Module | None if add_upsample: if resblock_updown: - self.upsampler = _ResnetBlock( + self.upsampler = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, in_channels=out_channels, out_channels=out_channels, @@ -503,8 +532,24 @@ def __init__( up=True, ) else: - self.upsampler = _Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + post_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.upsampler = WrappedUpsample( + spatial_dims=spatial_dims, + mode="nontrainable", + in_channels=out_channels, + out_channels=out_channels, + interp_mode="nearest", + scale_factor=2.0, + post_conv=post_conv, + align_corners=None, ) else: self.upsampler = None @@ -549,7 +594,6 @@ def get_spade_up_block( label_nc: int, cross_attention_dim: int | None, upcast_attention: bool = False, - use_flash_attention: bool = False, spade_intermediate_channels: int = 128, ) -> nn.Module: if with_attn: @@ -566,7 +610,6 @@ def get_spade_up_block( add_upsample=add_upsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, - use_flash_attention=use_flash_attention, spade_intermediate_channels=spade_intermediate_channels, ) elif with_cross_attn: @@ -586,7 +629,6 @@ def get_spade_up_block( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, spade_intermediate_channels=spade_intermediate_channels, ) else: @@ -630,7 +672,6 @@ class SPADEDiffusionModelUNet(nn.Module): num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. spade_intermediate_channels: number of intermediate channels for SPADE block layer """ @@ -652,7 +693,6 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, - use_flash_attention: bool = False, spade_intermediate_channels: int = 128, ) -> None: super().__init__() @@ -691,14 +731,6 @@ def __init__( "`num_channels`." ) - if use_flash_attention and not has_xformers: - raise ValueError("use_flash_attention is True but xformers is not installed.") - - if use_flash_attention is True and not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." - ) - self.in_channels = in_channels self.block_out_channels = channels self.out_channels = out_channels @@ -754,7 +786,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -771,7 +802,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, ) # up @@ -805,7 +835,6 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, - use_flash_attention=use_flash_attention, label_nc=label_nc, spade_intermediate_channels=spade_intermediate_channels, ) @@ -890,7 +919,7 @@ def forward( down_block_res_samples = new_down_block_res_samples # 5. mid - h = self.middle_block(hidden_states=h, temb=emb, context=context) + h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context) # Additional residual conections for Controlnets if mid_block_additional_residual is not None: diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py index b742c12205..215e8d11a9 100644 --- a/monai/networks/nets/transformer.py +++ b/monai/networks/nets/transformer.py @@ -11,221 +11,14 @@ from __future__ import annotations -import math - import torch import torch.nn as nn -import torch.nn.functional as F -from monai.networks.blocks.mlp import MLPBlock -from monai.utils import optional_import +from monai.networks.blocks import TransformerBlock -xops, has_xformers = optional_import("xformers.ops") __all__ = ["DecoderOnlyTransformer"] -class _SABlock(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - A self-attention block, based on: "Dosovitskiy et al., - An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " - - Args: - hidden_size: dimension of hidden layer. - num_heads: number of attention heads. - dropout_rate: dropout ratio. Defaults to no dropout. - qkv_bias: bias term for the qkv linear layer. - causal: whether to use causal attention. - sequence_length: if causal is True, it is necessary to specify the sequence length. - with_cross_attention: Whether to use cross attention for conditioning. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - hidden_size: int, - num_heads: int, - dropout_rate: float = 0.0, - qkv_bias: bool = False, - causal: bool = False, - sequence_length: int | None = None, - with_cross_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - super().__init__() - self.hidden_size = hidden_size - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads - self.scale = 1.0 / math.sqrt(self.head_dim) - self.causal = causal - self.sequence_length = sequence_length - self.with_cross_attention = with_cross_attention - self.use_flash_attention = use_flash_attention - - if not (0 <= dropout_rate <= 1): - raise ValueError("dropout_rate should be between 0 and 1.") - self.dropout_rate = dropout_rate - - if hidden_size % num_heads != 0: - raise ValueError("hidden size should be divisible by num_heads.") - - if causal and sequence_length is None: - raise ValueError("sequence_length is necessary for causal attention.") - - if use_flash_attention and not has_xformers: - raise ValueError("use_flash_attention is True but xformers is not installed.") - - # key, query, value projections - self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - self.to_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - - # regularization - self.drop_weights = nn.Dropout(dropout_rate) - self.drop_output = nn.Dropout(dropout_rate) - - # output projection - self.out_proj = nn.Linear(hidden_size, hidden_size) - - if causal and sequence_length is not None: - # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer( - "causal_mask", - torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length), - ) - self.causal_mask: torch.Tensor - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) - - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - query = self.to_q(x) - - kv = context if context is not None else x - _, kv_t, _ = kv.size() - key = self.to_k(kv) - value = self.to_v(kv) - - query = query.view(b, t, self.num_heads, c // self.num_heads) # (b, t, nh, hs) - key = key.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) - value = value.view(b, kv_t, self.num_heads, c // self.num_heads) # (b, kv_t, nh, hs) - y: torch.Tensor - if self.use_flash_attention: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - y = xops.memory_efficient_attention( - query=query, - key=key, - value=value, - scale=self.scale, - p=self.dropout_rate, - attn_bias=xops.LowerTriangularMask() if self.causal else None, - ) - - else: - query = query.transpose(1, 2) # (b, nh, t, hs) - key = key.transpose(1, 2) # (b, nh, kv_t, hs) - value = value.transpose(1, 2) # (b, nh, kv_t, hs) - - # manual implementation of attention - query = query * self.scale - attention_scores = query @ key.transpose(-2, -1) - - if self.causal: - attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) - - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = self.drop_weights(attention_probs) - y = attention_probs @ value # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs) - - y = y.transpose(1, 2) # (b, nh, t, hs) -> (b, t, nh, hs) - - y = y.contiguous().view(b, t, c) # re-assemble all head outputs side by side - - y = self.out_proj(y) - y = self.drop_output(y) - return y - - -class _TransformerBlock(nn.Module): - """ - NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make - use of this block as support is not guaranteed. For more information see: - https://github.com/Project-MONAI/MONAI/issues/7227 - - A transformer block, based on: "Dosovitskiy et al., - An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " - - Args: - hidden_size: dimension of hidden layer. - mlp_dim: dimension of feedforward layer. - num_heads: number of attention heads. - dropout_rate: faction of the input units to drop. - qkv_bias: apply bias term for the qkv linear layer - causal: whether to use causal attention. - sequence_length: if causal is True, it is necessary to specify the sequence length. - with_cross_attention: Whether to use cross attention for conditioning. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. - """ - - def __init__( - self, - hidden_size: int, - mlp_dim: int, - num_heads: int, - dropout_rate: float = 0.0, - qkv_bias: bool = False, - causal: bool = False, - sequence_length: int | None = None, - with_cross_attention: bool = False, - use_flash_attention: bool = False, - ) -> None: - self.with_cross_attention = with_cross_attention - super().__init__() - - if not (0 <= dropout_rate <= 1): - raise ValueError("dropout_rate should be between 0 and 1.") - - if hidden_size % num_heads != 0: - raise ValueError("hidden_size should be divisible by num_heads.") - - self.norm1 = nn.LayerNorm(hidden_size) - self.attn = _SABlock( - hidden_size=hidden_size, - num_heads=num_heads, - dropout_rate=dropout_rate, - qkv_bias=qkv_bias, - causal=causal, - sequence_length=sequence_length, - use_flash_attention=use_flash_attention, - ) - - if self.with_cross_attention: - self.norm2 = nn.LayerNorm(hidden_size) - self.cross_attn = _SABlock( - hidden_size=hidden_size, - num_heads=num_heads, - dropout_rate=dropout_rate, - qkv_bias=qkv_bias, - with_cross_attention=with_cross_attention, - causal=False, - use_flash_attention=use_flash_attention, - ) - self.norm3 = nn.LayerNorm(hidden_size) - self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) - - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: - x = x + self.attn(self.norm1(x)) - if self.with_cross_attention: - x = x + self.cross_attn(self.norm2(x), context=context) - x = x + self.mlp(self.norm3(x)) - return x - - class AbsolutePositionalEmbedding(nn.Module): """Absolute positional embedding. @@ -258,7 +51,6 @@ class DecoderOnlyTransformer(nn.Module): attn_layers_heads: Number of attention heads. with_cross_attention: Whether to use cross attention for conditioning. embedding_dropout_rate: Dropout rate for the embedding. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ def __init__( @@ -270,7 +62,6 @@ def __init__( attn_layers_heads: int, with_cross_attention: bool = False, embedding_dropout_rate: float = 0.0, - use_flash_attention: bool = False, ) -> None: super().__init__() self.num_tokens = num_tokens @@ -286,7 +77,7 @@ def __init__( self.blocks = nn.ModuleList( [ - _TransformerBlock( + TransformerBlock( hidden_size=attn_layers_dim, mlp_dim=attn_layers_dim * 4, num_heads=attn_layers_heads, @@ -295,7 +86,6 @@ def __init__( causal=True, sequence_length=max_seq_len, with_cross_attention=with_cross_attention, - use_flash_attention=use_flash_attention, ) for _ in range(attn_layers_depth) ] @@ -312,3 +102,56 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch x = block(x, context=context) logits: torch.Tensor = self.to_logits(x) return logits + + def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: + """ + Load a state dict from a DecoderOnlyTransformer trained with + [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels). + + Args: + old_state_dict: state dict from the old DecoderOnlyTransformer model. + """ + + new_state_dict = self.state_dict() + # if all keys match, just load the state dict + if all(k in new_state_dict for k in old_state_dict): + print("All keys match, loading state dict.") + self.load_state_dict(old_state_dict) + return + + if verbose: + # print all new_state_dict keys that are not in old_state_dict + for k in new_state_dict: + if k not in old_state_dict: + print(f"key {k} not found in old state dict") + # and vice versa + print("----------------------------------------------") + for k in old_state_dict: + if k not in new_state_dict: + print(f"key {k} not found in new state dict") + + # copy over all matching keys + for k in new_state_dict: + if k in old_state_dict: + new_state_dict[k] = old_state_dict[k] + + # fix the attention blocks + attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] + for block in attention_blocks: + new_state_dict[f"{block}.attn.qkv.weight"] = torch.concat( + [ + old_state_dict[f"{block}.attn.to_q.weight"], + old_state_dict[f"{block}.attn.to_k.weight"], + old_state_dict[f"{block}.attn.to_v.weight"], + ], + dim=0, + ) + + # fix the renamed norm blocks first norm2 -> norm_cross_attention , norm3 -> norm2 + for k in old_state_dict: + if "norm2" in k: + new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict[k] + if "norm3" in k: + new_state_dict[k.replace("norm3", "norm2")] = old_state_dict[k] + + self.load_state_dict(new_state_dict) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index ecf237a2ff..6a97434215 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -987,6 +987,7 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int): inputs=input_placeholder, enabled_precisions=convert_precision, device=target_device, + ir="torchscript", **kwargs, ) diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 886103a0ab..84dd3ad1f6 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -527,7 +527,7 @@ def doc_images() -> str | None: @staticmethod def algo_hash() -> str | None: - return os.environ.get("MONAI_ALGO_HASH", "4403f94") + return os.environ.get("MONAI_ALGO_HASH", "e4cf5a1") @staticmethod def trace_transform() -> str | None: diff --git a/requirements-dev.txt b/requirements-dev.txt index b207b56b19..ce28d3ebe2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -46,13 +46,13 @@ pynrrd pre-commit pydicom h5py -nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine +nni==2.10.1; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine optuna git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded onnx>=1.13.0 onnxruntime; python_version <= '3.10' typeguard<3 # https://github.com/microsoft/nni/issues/5457 -filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523 +filelock<3.12.0 # https://github.com/microsoft/nni/issues/5523 zarr lpips==0.1.4 nvidia-ml-py diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py index 83f6cabc5e..6a577f763f 100644 --- a/tests/test_attentionunet.py +++ b/tests/test_attentionunet.py @@ -14,11 +14,17 @@ import unittest import torch +import torch.nn as nn import monai.networks.nets.attentionunet as att from tests.utils import skip_if_no_cuda, skip_if_quick +def get_net_parameters(net: nn.Module) -> int: + """Returns the total number of parameters in a Module.""" + return sum(param.numel() for param in net.parameters()) + + class TestAttentionUnet(unittest.TestCase): def test_attention_block(self): @@ -50,6 +56,20 @@ def test_attentionunet(self): self.assertEqual(output.shape[0], input.shape[0]) self.assertEqual(output.shape[1], 2) + def test_attentionunet_kernel_size(self): + args_dict = { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 2, + "channels": (3, 4, 5), + "up_kernel_size": 5, + "strides": (1, 2), + } + model_a = att.AttentionUnet(**args_dict, kernel_size=5) + model_b = att.AttentionUnet(**args_dict, kernel_size=7) + self.assertEqual(get_net_parameters(model_a), 3534) + self.assertEqual(get_net_parameters(model_b), 5574) + @skip_if_no_cuda def test_attentionunet_gpu(self): for dims in [2, 3]: diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py index 3cc671a1d0..d15cb79084 100644 --- a/tests/test_autoencoderkl.py +++ b/tests/test_autoencoderkl.py @@ -11,20 +11,26 @@ from __future__ import annotations +import os +import tempfile import unittest +from unittest import skipUnless import torch from parameterized import parameterized +from monai.apps import download_url from monai.networks import eval_mode from monai.networks.nets import AutoencoderKL from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, testing_data_config tqdm, has_tqdm = optional_import("tqdm", name="tqdm") -einops, has_einops = optional_import("einops") +_, has_einops = optional_import("einops") + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + CASES_NO_ATTENTION = [ [ { @@ -299,6 +305,33 @@ def test_shape_decode_convtranspose_and_checkpointing(self): result = net.decode(torch.randn(latent_shape).to(device)) self.assertEqual(result.shape, expected_input_shape) + @skipUnless(has_einops, "Requires einops") + def test_compatibility_with_monai_generative(self): + # test loading weights from a model saved in MONAI Generative, version 0.2.3 + with skip_if_downloading_fails(): + net = AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(4, 4, 4), + latent_channels=4, + attention_levels=(False, False, True), + num_res_blocks=1, + norm_num_groups=4, + ).to(device) + + tmpdir = tempfile.mkdtemp() + key = "autoencoderkl_monai_generative_weights" + url = testing_data_config("models", key, "url") + hash_type = testing_data_config("models", key, "hash_type") + hash_val = testing_data_config("models", key, "hash_val") + filename = "autoencoderkl_monai_generative_weights.pt" + + weight_path = os.path.join(tmpdir, filename) + download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) + + net.load_old_state_dict(torch.load(weight_path), verbose=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_bundle_ckpt_export.py b/tests/test_bundle_ckpt_export.py index 8f376a06d5..cfcadcfc4c 100644 --- a/tests/test_bundle_ckpt_export.py +++ b/tests/test_bundle_ckpt_export.py @@ -72,9 +72,9 @@ def test_export(self, key_in_ckpt, use_trace): _, metadata, extra_files = load_net_with_metadata( ts_file, more_extra_files=["inference.json", "def_args.json"] ) - self.assertTrue("schema" in metadata) - self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"])) - self.assertTrue("network_def" in json.loads(extra_files["inference.json"])) + self.assertIn("schema", metadata) + self.assertIn("meta_file", json.loads(extra_files["def_args.json"])) + self.assertIn("network_def", json.loads(extra_files["inference.json"])) @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_default_value(self, key_in_ckpt, use_trace): diff --git a/tests/test_bundle_get_data.py b/tests/test_bundle_get_data.py index 605b3945bb..f84713fbe3 100644 --- a/tests/test_bundle_get_data.py +++ b/tests/test_bundle_get_data.py @@ -51,8 +51,8 @@ class TestGetBundleData(unittest.TestCase): def test_get_all_bundles_list(self, params): with skip_if_downloading_fails(): output = get_all_bundles_list(**params) - self.assertTrue(isinstance(output, list)) - self.assertTrue(isinstance(output[0], tuple)) + self.assertIsInstance(output, list) + self.assertIsInstance(output[0], tuple) self.assertTrue(len(output[0]) == 2) @parameterized.expand([TEST_CASE_1, TEST_CASE_5]) @@ -60,16 +60,17 @@ def test_get_all_bundles_list(self, params): def test_get_bundle_versions(self, params): with skip_if_downloading_fails(): output = get_bundle_versions(**params) - self.assertTrue(isinstance(output, dict)) - self.assertTrue("latest_version" in output and "all_versions" in output) - self.assertTrue("0.1.0" in output["all_versions"]) + self.assertIsInstance(output, dict) + self.assertIn("latest_version", output) + self.assertIn("all_versions", output) + self.assertIn("0.1.0", output["all_versions"]) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) @skip_if_quick def test_get_bundle_info(self, params): with skip_if_downloading_fails(): output = get_bundle_info(**params) - self.assertTrue(isinstance(output, dict)) + self.assertIsInstance(output, dict) for key in ["id", "name", "size", "download_count", "browser_download_url"]: self.assertTrue(key in output) @@ -78,7 +79,7 @@ def test_get_bundle_info(self, params): def test_get_bundle_info_monaihosting(self, params): with skip_if_downloading_fails(): output = get_bundle_info(**params) - self.assertTrue(isinstance(output, dict)) + self.assertIsInstance(output, dict) for key in ["name", "browser_download_url"]: self.assertTrue(key in output) diff --git a/tests/test_bundle_trt_export.py b/tests/test_bundle_trt_export.py index 47034852ef..833a0ca1dc 100644 --- a/tests/test_bundle_trt_export.py +++ b/tests/test_bundle_trt_export.py @@ -91,9 +91,9 @@ def test_trt_export(self, convert_precision, input_shape, dynamic_batch): _, metadata, extra_files = load_net_with_metadata( ts_file, more_extra_files=["inference.json", "def_args.json"] ) - self.assertTrue("schema" in metadata) - self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"])) - self.assertTrue("network_def" in json.loads(extra_files["inference.json"])) + self.assertIn("schema", metadata) + self.assertIn("meta_file", json.loads(extra_files["def_args.json"])) + self.assertIn("network_def", json.loads(extra_files["inference.json"])) @parameterized.expand([TEST_CASE_3, TEST_CASE_4]) @unittest.skipUnless( @@ -129,9 +129,9 @@ def test_onnx_trt_export(self, convert_precision, input_shape, dynamic_batch): _, metadata, extra_files = load_net_with_metadata( ts_file, more_extra_files=["inference.json", "def_args.json"] ) - self.assertTrue("schema" in metadata) - self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"])) - self.assertTrue("network_def" in json.loads(extra_files["inference.json"])) + self.assertIn("schema", metadata) + self.assertIn("meta_file", json.loads(extra_files["def_args.json"])) + self.assertIn("network_def", json.loads(extra_files["inference.json"])) if __name__ == "__main__": diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py index 9a276b577f..1727fcdf53 100644 --- a/tests/test_bundle_workflow.py +++ b/tests/test_bundle_workflow.py @@ -138,11 +138,11 @@ def test_train_config(self, config_file): self.assertListEqual(trainer.check_properties(), []) # test read / write the properties dataset = trainer.train_dataset - self.assertTrue(isinstance(dataset, Dataset)) + self.assertIsInstance(dataset, Dataset) inferer = trainer.train_inferer - self.assertTrue(isinstance(inferer, SimpleInferer)) + self.assertIsInstance(inferer, SimpleInferer) # test optional properties get - self.assertTrue(trainer.train_key_metric is None) + self.assertIsNone(trainer.train_key_metric) trainer.train_dataset = deepcopy(dataset) trainer.train_inferer = deepcopy(inferer) # test optional properties set diff --git a/tests/test_clip_intensity_percentilesd.py b/tests/test_clip_intensity_percentilesd.py index fa727b6adb..ed4fc588cb 100644 --- a/tests/test_clip_intensity_percentilesd.py +++ b/tests/test_clip_intensity_percentilesd.py @@ -96,7 +96,7 @@ def test_channel_wise(self, p): for i, c in enumerate(im): lower, upper = percentile(c, (5, 95)) expected = clip(c, lower, upper) - assert_allclose(result[key][i], p(expected), type_test="tensor", rtol=1e-4, atol=0) + assert_allclose(result[key][i], p(expected), type_test="tensor", rtol=1e-3, atol=0) def test_ill_sharpness_factor(self): key = "img" diff --git a/tests/test_component_store.py b/tests/test_component_store.py index 424eceb3d1..7e7c6dd19d 100644 --- a/tests/test_component_store.py +++ b/tests/test_component_store.py @@ -48,17 +48,17 @@ def test_add2(self): self.cs.add("test_obj2", "Test object", test_obj2) self.assertEqual(len(self.cs), 2) - self.assertTrue("test_obj1" in self.cs) - self.assertTrue("test_obj2" in self.cs) + self.assertIn("test_obj1", self.cs) + self.assertIn("test_obj2", self.cs) def test_add_def(self): - self.assertFalse("test_func" in self.cs) + self.assertNotIn("test_func", self.cs) @self.cs.add_def("test_func", "Test function") def test_func(): return 123 - self.assertTrue("test_func" in self.cs) + self.assertIn("test_func", self.cs) self.assertEqual(len(self.cs), 1) self.assertEqual(list(self.cs), [("test_func", test_func)]) diff --git a/tests/test_compute_ho_ver_maps.py b/tests/test_compute_ho_ver_maps.py index bbd5230f04..6e46cf2b1e 100644 --- a/tests/test_compute_ho_ver_maps.py +++ b/tests/test_compute_ho_ver_maps.py @@ -67,8 +67,8 @@ class ComputeHoVerMapsTests(unittest.TestCase): def test_horizontal_certical_maps(self, in_type, arguments, mask, hv_mask): input_image = in_type(mask) result = ComputeHoVerMaps(**arguments)(input_image) - self.assertTrue(isinstance(result, torch.Tensor)) - self.assertTrue(str(result.dtype).split(".")[1] == arguments.get("dtype", "float32")) + self.assertIsInstance(result, torch.Tensor) + self.assertEqual(str(result.dtype).split(".")[1], arguments.get("dtype", "float32")) assert_allclose(result, hv_mask, type_test="tensor") diff --git a/tests/test_compute_ho_ver_maps_d.py b/tests/test_compute_ho_ver_maps_d.py index 7b5ac0d9d7..0734e2e731 100644 --- a/tests/test_compute_ho_ver_maps_d.py +++ b/tests/test_compute_ho_ver_maps_d.py @@ -71,8 +71,8 @@ def test_horizontal_certical_maps(self, in_type, arguments, mask, hv_mask): for k in mask.keys(): input_image[k] = in_type(mask[k]) result = ComputeHoVerMapsd(keys="mask", **arguments)(input_image)[hv_key] - self.assertTrue(isinstance(result, torch.Tensor)) - self.assertTrue(str(result.dtype).split(".")[1] == arguments.get("dtype", "float32")) + self.assertIsInstance(result, torch.Tensor) + self.assertEqual(str(result.dtype).split(".")[1], arguments.get("dtype", "float32")) assert_allclose(result, hv_mask[hv_key], type_test="tensor") diff --git a/tests/test_compute_regression_metrics.py b/tests/test_compute_regression_metrics.py index a8b7f03e47..c407ab6ba6 100644 --- a/tests/test_compute_regression_metrics.py +++ b/tests/test_compute_regression_metrics.py @@ -70,22 +70,24 @@ def test_shape_reduction(self): mt = mt_fn(reduction="mean") mt(in_tensor, in_tensor) out_tensor = mt.aggregate() - self.assertTrue(len(out_tensor.shape) == 1) + self.assertEqual(len(out_tensor.shape), 1) mt = mt_fn(reduction="sum") mt(in_tensor, in_tensor) out_tensor = mt.aggregate() - self.assertTrue(len(out_tensor.shape) == 0) + self.assertEqual(len(out_tensor.shape), 0) mt = mt_fn(reduction="sum") # test reduction arg overriding mt(in_tensor, in_tensor) out_tensor = mt.aggregate(reduction="mean_channel") - self.assertTrue(len(out_tensor.shape) == 1 and out_tensor.shape[0] == batch) + self.assertEqual(len(out_tensor.shape), 1) + self.assertEqual(out_tensor.shape[0], batch) mt = mt_fn(reduction="sum_channel") mt(in_tensor, in_tensor) out_tensor = mt.aggregate() - self.assertTrue(len(out_tensor.shape) == 1 and out_tensor.shape[0] == batch) + self.assertEqual(len(out_tensor.shape), 1) + self.assertEqual(out_tensor.shape[0], batch) def test_compare_numpy(self): set_determinism(seed=123) diff --git a/tests/test_concat_itemsd.py b/tests/test_concat_itemsd.py index 64c5d6e255..564ddf5c1f 100644 --- a/tests/test_concat_itemsd.py +++ b/tests/test_concat_itemsd.py @@ -30,7 +30,7 @@ def test_tensor_values(self): "img2": torch.tensor([[0, 1], [1, 2]], device=device), } result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data) - self.assertTrue("cat_img" in result) + self.assertIn("cat_img", result) result["cat_img"] += 1 assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device)) assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3], [1, 2], [2, 3]], device=device)) @@ -42,8 +42,8 @@ def test_metatensor_values(self): "img2": MetaTensor([[0, 1], [1, 2]], device=device), } result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data) - self.assertTrue("cat_img" in result) - self.assertTrue(isinstance(result["cat_img"], MetaTensor)) + self.assertIn("cat_img", result) + self.assertIsInstance(result["cat_img"], MetaTensor) self.assertEqual(result["img1"].meta, result["cat_img"].meta) result["cat_img"] += 1 assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device)) @@ -52,7 +52,7 @@ def test_metatensor_values(self): def test_numpy_values(self): input_data = {"img1": np.array([[0, 1], [1, 2]]), "img2": np.array([[0, 1], [1, 2]])} result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data) - self.assertTrue("cat_img" in result) + self.assertIn("cat_img", result) result["cat_img"] += 1 np.testing.assert_allclose(result["img1"], np.array([[0, 1], [1, 2]])) np.testing.assert_allclose(result["cat_img"], np.array([[1, 2], [2, 3], [1, 2], [2, 3]])) diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index cc890a0522..cf1edc8f08 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -185,7 +185,7 @@ def test_function(self, config): if id in ("compute", "cls_compute"): parser[f"{id}#_mode_"] = "callable" func = parser.get_parsed_content(id=id) - self.assertTrue(id in parser.ref_resolver.resolved_content) + self.assertIn(id, parser.ref_resolver.resolved_content) if id == "error_func": with self.assertRaises(TypeError): func(1, 2) diff --git a/tests/test_controlnet.py b/tests/test_controlnet.py index 07dfa2e49b..05ceb69fa3 100644 --- a/tests/test_controlnet.py +++ b/tests/test_controlnet.py @@ -12,13 +12,16 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized from monai.networks import eval_mode from monai.networks.nets.controlnet import ControlNet +from monai.utils import optional_import +_, has_einops = optional_import("einops") UNCOND_CASES_2D = [ [ { @@ -147,6 +150,7 @@ class TestControlNet(unittest.TestCase): @parameterized.expand(UNCOND_CASES_2D + UNCOND_CASES_3D) + @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param, expected_output_shape): input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) @@ -160,6 +164,7 @@ def test_shape_unconditioned_models(self, input_param, expected_output_shape): self.assertEqual(result[1].shape, expected_output_shape) @parameterized.expand(COND_CASES_2D) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self, input_param, expected_output_shape): input_param["conditioning_embedding_in_channels"] = input_param["in_channels"] input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],) diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py index 1f675537dc..96e707acb5 100644 --- a/tests/test_controlnet_inferers.py +++ b/tests/test_controlnet_inferers.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized @@ -29,6 +30,8 @@ from monai.utils import optional_import _, has_scipy = optional_import("scipy") +_, has_einops = optional_import("einops") + CNDM_TEST_CASES = [ [ @@ -443,6 +446,7 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_call(self, model_params, controlnet_params, input_shape): model = DiffusionModelUNet(**model_params) controlnet = ControlNet(**controlnet_params) @@ -464,6 +468,7 @@ def test_call(self, model_params, controlnet_params, input_shape): self.assertEqual(sample.shape, input_shape) @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_intermediates(self, model_params, controlnet_params, input_shape): model = DiffusionModelUNet(**model_params) controlnet = ControlNet(**controlnet_params) @@ -489,6 +494,7 @@ def test_sample_intermediates(self, model_params, controlnet_params, input_shape self.assertEqual(len(intermediates), 10) @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): model = DiffusionModelUNet(**model_params) controlnet = ControlNet(**controlnet_params) @@ -514,6 +520,7 @@ def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): self.assertEqual(len(intermediates), 10) @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_ddim_sampler(self, model_params, controlnet_params, input_shape): model = DiffusionModelUNet(**model_params) controlnet = ControlNet(**controlnet_params) @@ -539,6 +546,7 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape): self.assertEqual(len(intermediates), 10) @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): model_params["with_conditioning"] = True model_params["cross_attention_dim"] = 3 @@ -568,6 +576,7 @@ def test_sampler_conditioned(self, model_params, controlnet_params, input_shape) self.assertEqual(len(intermediates), 10) @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_get_likelihood(self, model_params, controlnet_params, input_shape): model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -604,6 +613,7 @@ def test_normal_cdf(self): torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) @parameterized.expand(CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape): # copy the model_params dict to prevent from modifying test cases model_params = model_params.copy() @@ -642,6 +652,7 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_prediction_shape( self, ae_model_type, @@ -708,6 +719,7 @@ def test_prediction_shape( self.assertEqual(prediction.shape, latent_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_shape( self, ae_model_type, @@ -770,6 +782,7 @@ def test_sample_shape( self.assertEqual(sample.shape, input_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_intermediates( self, ae_model_type, @@ -837,6 +850,7 @@ def test_sample_intermediates( self.assertEqual(intermediates[0].shape, input_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_get_likelihoods( self, ae_model_type, @@ -904,6 +918,7 @@ def test_get_likelihoods( self.assertEqual(intermediates[0].shape, latent_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_resample_likelihoods( self, ae_model_type, @@ -973,6 +988,7 @@ def test_resample_likelihoods( self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_prediction_shape_conditioned_concat( self, ae_model_type, @@ -1053,6 +1069,7 @@ def test_prediction_shape_conditioned_concat( self.assertEqual(prediction.shape, latent_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_shape_conditioned_concat( self, ae_model_type, @@ -1128,6 +1145,7 @@ def test_sample_shape_conditioned_concat( self.assertEqual(sample.shape, input_shape) @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES) + @skipUnless(has_einops, "Requires einops") def test_sample_shape_different_latents( self, ae_model_type, @@ -1203,6 +1221,7 @@ def test_sample_shape_different_latents( ) self.assertEqual(prediction.shape, latent_shape) + @skipUnless(has_einops, "Requires einops") def test_incompatible_spade_setup(self): stage_1 = SPADEAutoencoderKL( spatial_dims=2, diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py new file mode 100644 index 0000000000..4ab0ab1823 --- /dev/null +++ b/tests/test_crossattention.py @@ -0,0 +1,131 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.crossattention import CrossAttentionBlock +from monai.networks.layers.factories import RelPosEmbedding +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +TEST_CASE_CABLOCK = [] +for dropout_rate in np.linspace(0, 1, 4): + for hidden_size in [360, 480, 600, 768]: + for num_heads in [4, 6, 8, 12]: + for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: + for input_size in [(16, 32), (8, 8, 8)]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_CABLOCK.append(test_case) + + +class TestResBlock(unittest.TestCase): + + @parameterized.expand(TEST_CASE_CABLOCK) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = CrossAttentionBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"])) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=128, num_heads=12, dropout_rate=6.0) + + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + + @skipUnless(has_einops, "Requires einops") + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=128, num_heads=3, dropout_rate=0.1) + + @skipUnless(has_einops, "Requires einops") + def test_inner_dim_different(self): + CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30) + + def test_causal_no_sequence_length(self): + with self.assertRaises(ValueError): + CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + + @skipUnless(has_einops, "Requires einops") + def test_causal(self): + block = CrossAttentionBlock( + hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True + ) + input_shape = (1, 16, 128) + block(torch.randn(input_shape)) + # check upper triangular part of the attention matrix is zero + assert torch.triu(block.att_mat, diagonal=1).sum() == 0 + + @skipUnless(has_einops, "Requires einops") + def test_context_input(self): + block = CrossAttentionBlock( + hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12 + ) + input_shape = (1, 16, 128) + block(torch.randn(input_shape), context=torch.randn(1, 3, 12)) + + @skipUnless(has_einops, "Requires einops") + def test_context_wrong_input_size(self): + block = CrossAttentionBlock( + hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12 + ) + input_shape = (1, 16, 128) + with self.assertRaises(RuntimeError): + block(torch.randn(input_shape), context=torch.randn(1, 3, 24)) + + @skipUnless(has_einops, "Requires einops") + def test_access_attn_matrix(self): + # input format + hidden_size = 128 + num_heads = 2 + dropout_rate = 0 + input_shape = (2, 256, hidden_size) + + # be not able to access the matrix + no_matrix_acess_blk = CrossAttentionBlock( + hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate + ) + no_matrix_acess_blk(torch.randn(input_shape)) + assert isinstance(no_matrix_acess_blk.att_mat, torch.Tensor) + # no of elements is zero + assert no_matrix_acess_blk.att_mat.nelement() == 0 + + # be able to acess the attention matrix + matrix_acess_blk = CrossAttentionBlock( + hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True + ) + matrix_acess_blk(torch.randn(input_shape)) + assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cucim_dict_transform.py b/tests/test_cucim_dict_transform.py index d2dcc6aa5f..3c5703a34c 100644 --- a/tests/test_cucim_dict_transform.py +++ b/tests/test_cucim_dict_transform.py @@ -80,8 +80,8 @@ class TestCuCIMDict(unittest.TestCase): def test_tramsforms_numpy_single(self, params, input, expected): input = {"image": input} output = CuCIMd(keys="image", **params)(input)["image"] - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, np.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, np.ndarray) cp.testing.assert_allclose(output, expected) @parameterized.expand( @@ -98,8 +98,8 @@ def test_tramsforms_numpy_batch(self, params, input, expected): input = {"image": input[cp.newaxis, ...]} expected = expected[cp.newaxis, ...] output = CuCIMd(keys="image", **params)(input)["image"] - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, np.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, np.ndarray) cp.testing.assert_allclose(output, expected) @parameterized.expand( @@ -116,8 +116,8 @@ def test_tramsforms_cupy_single(self, params, input, expected): input = {"image": cp.asarray(input)} expected = cp.asarray(expected) output = CuCIMd(keys="image", **params)(input)["image"] - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, cp.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, cp.ndarray) cp.testing.assert_allclose(output, expected) @parameterized.expand( @@ -134,8 +134,8 @@ def test_tramsforms_cupy_batch(self, params, input, expected): input = {"image": cp.asarray(input)[cp.newaxis, ...]} expected = cp.asarray(expected)[cp.newaxis, ...] output = CuCIMd(keys="image", **params)(input)["image"] - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, cp.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, cp.ndarray) cp.testing.assert_allclose(output, expected) diff --git a/tests/test_cucim_transform.py b/tests/test_cucim_transform.py index 5f16c11589..162e16b52a 100644 --- a/tests/test_cucim_transform.py +++ b/tests/test_cucim_transform.py @@ -79,8 +79,8 @@ class TestCuCIM(unittest.TestCase): ) def test_tramsforms_numpy_single(self, params, input, expected): output = CuCIM(**params)(input) - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, np.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, np.ndarray) cp.testing.assert_allclose(output, expected) @parameterized.expand( @@ -97,8 +97,8 @@ def test_tramsforms_numpy_batch(self, params, input, expected): input = input[cp.newaxis, ...] expected = expected[cp.newaxis, ...] output = CuCIM(**params)(input) - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, np.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, np.ndarray) cp.testing.assert_allclose(output, expected) @parameterized.expand( @@ -115,8 +115,8 @@ def test_tramsforms_cupy_single(self, params, input, expected): input = cp.asarray(input) expected = cp.asarray(expected) output = CuCIM(**params)(input) - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, cp.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, cp.ndarray) cp.testing.assert_allclose(output, expected) @parameterized.expand( @@ -133,8 +133,8 @@ def test_tramsforms_cupy_batch(self, params, input, expected): input = cp.asarray(input)[cp.newaxis, ...] expected = cp.asarray(expected)[cp.newaxis, ...] output = CuCIM(**params)(input) - self.assertTrue(output.dtype == expected.dtype) - self.assertTrue(isinstance(output, cp.ndarray)) + self.assertEqual(output.dtype, expected.dtype) + self.assertIsInstance(output, cp.ndarray) cp.testing.assert_allclose(output, expected) diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py index e2efefeb77..f9c2b5ac53 100644 --- a/tests/test_detect_envelope.py +++ b/tests/test_detect_envelope.py @@ -147,7 +147,7 @@ def test_value_error(self, arguments, image, method): elif method == "__call__": self.assertRaises(ValueError, DetectEnvelope(**arguments), image) else: - raise ValueError("Expected raising method invalid. Should be __init__ or __call__.") + self.fail("Expected raising method invalid. Should be __init__ or __call__.") @SkipIfModule("torch.fft") diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py index ecd4855385..7f37025d3c 100644 --- a/tests/test_diffusion_inferer.py +++ b/tests/test_diffusion_inferer.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized @@ -22,6 +23,7 @@ from monai.utils import optional_import _, has_scipy = optional_import("scipy") +_, has_einops = optional_import("einops") TEST_CASES = [ [ @@ -55,6 +57,7 @@ class TestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_call(self, model_params, input_shape): model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -70,6 +73,7 @@ def test_call(self, model_params, input_shape): self.assertEqual(sample.shape, input_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_intermediates(self, model_params, input_shape): model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -85,6 +89,7 @@ def test_sample_intermediates(self, model_params, input_shape): self.assertEqual(len(intermediates), 10) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_ddpm_sampler(self, model_params, input_shape): model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -100,6 +105,7 @@ def test_ddpm_sampler(self, model_params, input_shape): self.assertEqual(len(intermediates), 10) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_ddim_sampler(self, model_params, input_shape): model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -115,6 +121,7 @@ def test_ddim_sampler(self, model_params, input_shape): self.assertEqual(len(intermediates), 10) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sampler_conditioned(self, model_params, input_shape): model_params["with_conditioning"] = True model_params["cross_attention_dim"] = 3 @@ -138,6 +145,7 @@ def test_sampler_conditioned(self, model_params, input_shape): self.assertEqual(len(intermediates), 10) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_get_likelihood(self, model_params, input_shape): model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -166,6 +174,7 @@ def test_normal_cdf(self): torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sampler_conditioned_concat(self, model_params, input_shape): # copy the model_params dict to prevent from modifying test cases model_params = model_params.copy() @@ -196,6 +205,7 @@ def test_sampler_conditioned_concat(self, model_params, input_shape): self.assertEqual(len(intermediates), 10) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_call_conditioned_concat(self, model_params, input_shape): # copy the model_params dict to prevent from modifying test cases model_params = model_params.copy() diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index d40a31a1da..7f764d85de 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -11,13 +11,21 @@ from __future__ import annotations +import os +import tempfile import unittest +from unittest import skipUnless import torch from parameterized import parameterized +from monai.apps import download_url from monai.networks import eval_mode from monai.networks.nets import DiffusionModelUNet +from monai.utils import optional_import +from tests.utils import skip_if_downloading_fails, testing_data_config + +_, has_einops = optional_import("einops") UNCOND_CASES_2D = [ [ @@ -286,12 +294,14 @@ class TestDiffusionModelUNet2D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_2D) + @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param): net = DiffusionModelUNet(**input_param) with eval_mode(net): result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long()) self.assertEqual(result.shape, (1, 1, 16, 16)) + @skipUnless(has_einops, "Requires einops") def test_timestep_with_wrong_shape(self): net = DiffusionModelUNet( spatial_dims=2, @@ -306,6 +316,7 @@ def test_timestep_with_wrong_shape(self): with eval_mode(net): net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long()) + @skipUnless(has_einops, "Requires einops") def test_shape_with_different_in_channel_out_channel(self): in_channels = 6 out_channels = 3 @@ -359,6 +370,7 @@ def test_num_res_blocks_with_different_length_channels(self): norm_num_groups=8, ) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self): net = DiffusionModelUNet( spatial_dims=2, @@ -396,6 +408,7 @@ def test_with_conditioning_cross_attention_dim_none(self): norm_num_groups=8, ) + @skipUnless(has_einops, "Requires einops") def test_context_with_conditioning_none(self): net = DiffusionModelUNet( spatial_dims=2, @@ -417,6 +430,7 @@ def test_context_with_conditioning_none(self): context=torch.rand((1, 1, 3)), ) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models_class_conditioning(self): net = DiffusionModelUNet( spatial_dims=2, @@ -437,6 +451,7 @@ def test_shape_conditioned_models_class_conditioning(self): ) self.assertEqual(result.shape, (1, 1, 16, 32)) + @skipUnless(has_einops, "Requires einops") def test_conditioned_models_no_class_labels(self): net = DiffusionModelUNet( spatial_dims=2, @@ -453,6 +468,7 @@ def test_conditioned_models_no_class_labels(self): with self.assertRaises(ValueError): net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long()) + @skipUnless(has_einops, "Requires einops") def test_model_channels_not_same_size_of_attention_levels(self): with self.assertRaises(ValueError): DiffusionModelUNet( @@ -468,6 +484,7 @@ def test_model_channels_not_same_size_of_attention_levels(self): ) @parameterized.expand(COND_CASES_2D) + @skipUnless(has_einops, "Requires einops") def test_conditioned_2d_models_shape(self, input_param): net = DiffusionModelUNet(**input_param) with eval_mode(net): @@ -477,12 +494,14 @@ def test_conditioned_2d_models_shape(self, input_param): class TestDiffusionModelUNet3D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_3D) + @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param): net = DiffusionModelUNet(**input_param) with eval_mode(net): result = net.forward(torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + @skipUnless(has_einops, "Requires einops") def test_shape_with_different_in_channel_out_channel(self): in_channels = 6 out_channels = 3 @@ -499,6 +518,7 @@ def test_shape_with_different_in_channel_out_channel(self): result = net.forward(torch.rand((1, in_channels, 16, 16, 16)), torch.randint(0, 1000, (1,)).long()) self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self): net = DiffusionModelUNet( spatial_dims=3, @@ -527,9 +547,39 @@ def test_wrong_dropout(self, input_param): _ = DiffusionModelUNet(**input_param) @parameterized.expand(DROPOUT_OK) + @skipUnless(has_einops, "Requires einops") def test_right_dropout(self, input_param): _ = DiffusionModelUNet(**input_param) + @skipUnless(has_einops, "Requires einops") + def test_compatibility_with_monai_generative(self): + # test loading weights from a model saved in MONAI Generative, version 0.2.3 + with skip_if_downloading_fails(): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + channels=(8, 8, 8), + attention_levels=(False, False, True), + with_conditioning=True, + cross_attention_dim=3, + transformer_num_layers=1, + norm_num_groups=8, + ) + + tmpdir = tempfile.mkdtemp() + key = "diffusion_model_unet_monai_generative_weights" + url = testing_data_config("models", key, "url") + hash_type = testing_data_config("models", key, "hash_type") + hash_val = testing_data_config("models", key, "hash_val") + filename = "diffusion_model_unet_monai_generative_weights.pt" + + weight_path = os.path.join(tmpdir, filename) + download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) + + net.load_old_state_dict(torch.load(weight_path), verbose=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index 09aa1f04b5..fe543347de 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -33,8 +33,8 @@ def test_array_input(self): keys="data", data_type=dtype, dtype=np.float32 if dtype == "NUMPY" else None, device="cpu" )({"data": test_data})["data"] if dtype == "NUMPY": - self.assertTrue(result.dtype == np.float32) - self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertEqual(result.dtype, np.float32) + self.assertIsInstance(result, torch.Tensor if dtype == "tensor" else np.ndarray) assert_allclose(result, test_data, type_test=False) self.assertTupleEqual(result.shape, (2, 2)) @@ -45,7 +45,7 @@ def test_single_input(self): for test_data in test_datas: for dtype in ("tensor", "numpy"): result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"] - self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertIsInstance(result, torch.Tensor if dtype == "tensor" else np.ndarray) if isinstance(test_data, bool): self.assertFalse(result) else: @@ -56,11 +56,11 @@ def test_string(self): for dtype in ("tensor", "numpy"): # string input result = EnsureTyped(keys="data", data_type=dtype)({"data": "test_string"})["data"] - self.assertTrue(isinstance(result, str)) + self.assertIsInstance(result, str) self.assertEqual(result, "test_string") # numpy array of string result = EnsureTyped(keys="data", data_type=dtype)({"data": np.array(["test_string"])})["data"] - self.assertTrue(isinstance(result, np.ndarray)) + self.assertIsInstance(result, np.ndarray) self.assertEqual(result[0], "test_string") def test_list_tuple(self): @@ -68,15 +68,15 @@ def test_list_tuple(self): result = EnsureTyped(keys="data", data_type=dtype, wrap_sequence=False, track_meta=True)( {"data": [[1, 2], [3, 4]]} )["data"] - self.assertTrue(isinstance(result, list)) - self.assertTrue(isinstance(result[0][1], MetaTensor if dtype == "tensor" else np.ndarray)) + self.assertIsInstance(result, list) + self.assertIsInstance(result[0][1], MetaTensor if dtype == "tensor" else np.ndarray) assert_allclose(result[1][0], torch.as_tensor(3), type_test=False) # tuple of numpy arrays result = EnsureTyped(keys="data", data_type=dtype, wrap_sequence=False)( {"data": (np.array([1, 2]), np.array([3, 4]))} )["data"] - self.assertTrue(isinstance(result, tuple)) - self.assertTrue(isinstance(result[0], torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertIsInstance(result, tuple) + self.assertIsInstance(result[0], torch.Tensor if dtype == "tensor" else np.ndarray) assert_allclose(result[1], torch.as_tensor([3, 4]), type_test=False) def test_dict(self): @@ -92,19 +92,19 @@ def test_dict(self): ) for key in ("data", "label"): result = trans[key] - self.assertTrue(isinstance(result, dict)) - self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray)) - self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray)) + self.assertIsInstance(result, dict) + self.assertIsInstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray) + self.assertIsInstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray) self.assertEqual(result["meta"]["path"], "temp/test") self.assertEqual(result["extra"], None) assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False) assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False) if dtype == "numpy": - self.assertTrue(trans["data"]["img"].dtype == np.float32) - self.assertTrue(trans["label"]["img"].dtype == np.int8) + self.assertEqual(trans["data"]["img"].dtype, np.float32) + self.assertEqual(trans["label"]["img"].dtype, np.int8) else: - self.assertTrue(trans["data"]["img"].dtype == torch.float32) - self.assertTrue(trans["label"]["img"].dtype == torch.int8) + self.assertEqual(trans["data"]["img"].dtype, torch.float32) + self.assertEqual(trans["label"]["img"].dtype, torch.int8) if __name__ == "__main__": diff --git a/tests/test_flipd.py b/tests/test_flipd.py index 277f387051..1df6d34056 100644 --- a/tests/test_flipd.py +++ b/tests/test_flipd.py @@ -78,7 +78,7 @@ def test_torch(self, spatial_axis, img: torch.Tensor, track_meta: bool, device): def test_meta_dict(self): xform = Flipd("image", [0, 1]) res = xform({"image": torch.zeros(1, 3, 4)}) - self.assertTrue(res["image"].applied_operations == res["image_transforms"]) + self.assertEqual(res["image"].applied_operations, res["image_transforms"]) if __name__ == "__main__": diff --git a/tests/test_freeze_layers.py b/tests/test_freeze_layers.py index 1bea4ed1b5..7be8e576bf 100644 --- a/tests/test_freeze_layers.py +++ b/tests/test_freeze_layers.py @@ -40,9 +40,9 @@ def test_freeze_vars(self, device): for name, param in model.named_parameters(): if "class_layer" in name: - self.assertEqual(param.requires_grad, False) + self.assertFalse(param.requires_grad) else: - self.assertEqual(param.requires_grad, True) + self.assertTrue(param.requires_grad) @parameterized.expand(TEST_CASES) def test_exclude_vars(self, device): @@ -53,9 +53,9 @@ def test_exclude_vars(self, device): for name, param in model.named_parameters(): if "class_layer" in name: - self.assertEqual(param.requires_grad, True) + self.assertTrue(param.requires_grad) else: - self.assertEqual(param.requires_grad, False) + self.assertFalse(param.requires_grad) if __name__ == "__main__": diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index 7499507129..5738f4a089 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -184,7 +184,7 @@ def test_differentiability(self): generalized_dice_loss = GeneralizedDiceLoss() loss = generalized_dice_loss(prediction, target) - self.assertNotEqual(loss.grad_fn, None) + self.assertIsNotNone(loss.grad_fn) def test_batch(self): prediction = torch.zeros(2, 3, 3, 3) @@ -194,7 +194,7 @@ def test_batch(self): generalized_dice_loss = GeneralizedDiceLoss(batch=True) loss = generalized_dice_loss(prediction, target) - self.assertNotEqual(loss.grad_fn, None) + self.assertIsNotNone(loss.grad_fn) def test_script(self): loss = GeneralizedDiceLoss() diff --git a/tests/test_get_package_version.py b/tests/test_get_package_version.py index ab9e69cd31..e9e1d8eca6 100644 --- a/tests/test_get_package_version.py +++ b/tests/test_get_package_version.py @@ -20,14 +20,14 @@ class TestGetVersion(unittest.TestCase): def test_default(self): output = get_package_version("42foobarnoexist") - self.assertTrue("UNKNOWN" in output) + self.assertIn("UNKNOWN", output) output = get_package_version("numpy") - self.assertFalse("UNKNOWN" in output) + self.assertNotIn("UNKNOWN", output) def test_msg(self): output = get_package_version("42foobarnoexist", "test") - self.assertTrue("test" in output) + self.assertIn("test", output) if __name__ == "__main__": diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py index 4b324eda1a..56af123548 100644 --- a/tests/test_grid_patch.py +++ b/tests/test_grid_patch.py @@ -124,11 +124,11 @@ def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta) self.assertTrue(output.meta["path"] == expected_meta[0]["path"]) for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta): assert_allclose(output_patch, expected_patch, type_test=False) - self.assertTrue(isinstance(output_patch, MetaTensor)) - self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"]) + self.assertIsInstance(output_patch, MetaTensor) + self.assertEqual(output_patch.meta["location"], expected_patch_meta["location"]) self.assertTrue(output_patch.meta["spatial_shape"], list(output_patch.shape[1:])) if "path" in expected_meta[0]: - self.assertTrue(output_patch.meta["path"] == expected_patch_meta["path"]) + self.assertEqual(output_patch.meta["path"], expected_patch_meta["path"]) if __name__ == "__main__": diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index f876cff2a3..52da5c179b 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py @@ -76,9 +76,9 @@ def _update_metric(engine): if has_key_word.match(line): content_count += 1 if epoch_log is True: - self.assertTrue(content_count == max_epochs) + self.assertEqual(content_count, max_epochs) else: - self.assertTrue(content_count == 2) # 2 = len([1, 2]) from event_filter + self.assertEqual(content_count, 2) # 2 = len([1, 2]) from event_filter @parameterized.expand([[True], [get_event_filter([1, 3])]]) def test_loss_print(self, iteration_log): @@ -116,9 +116,9 @@ def _train_func(engine, batch): if has_key_word.match(line): content_count += 1 if iteration_log is True: - self.assertTrue(content_count == num_iters * max_epochs) + self.assertEqual(content_count, num_iters * max_epochs) else: - self.assertTrue(content_count == 2) # 2 = len([1, 3]) from event_filter + self.assertEqual(content_count, 2) # 2 = len([1, 3]) from event_filter def test_loss_dict(self): log_stream = StringIO() @@ -150,7 +150,7 @@ def _train_func(engine, batch): for line in output_str.split("\n"): if has_key_word.match(line): content_count += 1 - self.assertTrue(content_count > 0) + self.assertGreater(content_count, 0) def test_loss_file(self): key_to_handler = "test_logging" @@ -184,7 +184,7 @@ def _train_func(engine, batch): for line in output_str.split("\n"): if has_key_word.match(line): content_count += 1 - self.assertTrue(content_count > 0) + self.assertGreater(content_count, 0) def test_exception(self): # set up engine @@ -239,7 +239,7 @@ def _update_metric(engine): for line in output_str.split("\n"): if has_key_word.match(line): content_count += 1 - self.assertTrue(content_count > 0) + self.assertGreater(content_count, 0) def test_default_logger(self): log_stream = StringIO() @@ -274,7 +274,7 @@ def _train_func(engine, batch): for line in output_str.split("\n"): if has_key_word.match(line): content_count += 1 - self.assertTrue(content_count > 0) + self.assertGreater(content_count, 0) if __name__ == "__main__": diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index c2e0fb55b7..60aaef05bf 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -135,9 +135,8 @@ def test_scripts_fold(self): command_run = cmd + ["run", "training", "--config_file", config_file, "--meta_file", meta_file] completed_process = subprocess.run(command_run, check=True, capture_output=True, text=True) output = repr(completed_process.stdout).replace("\\n", "\n").replace("\\t", "\t") # Get the captured output - print(output) - self.assertTrue(expected_condition in output) + self.assertIn(expected_condition, output) command_run_workflow = cmd + [ "run_workflow", "--run_id", @@ -149,8 +148,7 @@ def test_scripts_fold(self): ] completed_process = subprocess.run(command_run_workflow, check=True, capture_output=True, text=True) output = repr(completed_process.stdout).replace("\\n", "\n").replace("\\t", "\t") # Get the captured output - print(output) - self.assertTrue(expected_condition in output) + self.assertIn(expected_condition, output) # test missing meta file self.assertIn("ERROR", command_line_tests(cmd + ["run", "training", "--config_file", config_file])) diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py index f33b5c67eb..bf3972e6bd 100644 --- a/tests/test_inverse_collation.py +++ b/tests/test_inverse_collation.py @@ -133,7 +133,7 @@ def test_collation(self, _, transform, collate_fn, ndim): d = decollate_batch(item) self.assertTrue(len(d) <= self.batch_size) for b in d: - self.assertTrue(isinstance(b["image"], MetaTensor)) + self.assertIsInstance(b["image"], MetaTensor) np.testing.assert_array_equal( b["image"].applied_operations[-1]["orig_size"], b["label"].applied_operations[-1]["orig_size"] ) diff --git a/tests/test_invertd.py b/tests/test_invertd.py index c32a3af643..f6e8fc40e7 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -134,7 +134,7 @@ def test_invert(self): # 25300: 2 workers (cpu, non-macos) # 1812: 0 workers (gpu or macos) # 1821: windows torch 1.10.0 - self.assertTrue((reverted.size - n_good) < 40000, f"diff. {reverted.size - n_good}") + self.assertLess((reverted.size - n_good), 40000, f"diff. {reverted.size - n_good}") set_determinism(seed=None) diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index 4ab803bb6f..065ebafd95 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized @@ -19,7 +20,9 @@ from monai.inferers import LatentDiffusionInferer from monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet from monai.networks.schedulers import DDPMScheduler +from monai.utils import optional_import +_, has_einops = optional_import("einops") TEST_CASES = [ [ "AutoencoderKL", @@ -313,6 +316,7 @@ class TestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_prediction_shape( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -360,6 +364,7 @@ def test_prediction_shape( self.assertEqual(prediction.shape, latent_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_shape( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -404,6 +409,7 @@ def test_sample_shape( self.assertEqual(sample.shape, input_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_intermediates( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -458,6 +464,7 @@ def test_sample_intermediates( self.assertEqual(intermediates[0].shape, input_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_get_likelihoods( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -510,6 +517,7 @@ def test_get_likelihoods( self.assertEqual(intermediates[0].shape, latent_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_resample_likelihoods( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -564,6 +572,7 @@ def test_resample_likelihoods( self.assertEqual(intermediates[0].shape[2:], input_shape[2:]) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_prediction_shape_conditioned_concat( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -629,6 +638,7 @@ def test_prediction_shape_conditioned_concat( self.assertEqual(prediction.shape, latent_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_sample_shape_conditioned_concat( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -689,6 +699,7 @@ def test_sample_shape_conditioned_concat( self.assertEqual(sample.shape, input_shape) @parameterized.expand(TEST_CASES_DIFF_SHAPES) + @skipUnless(has_einops, "Requires einops") def test_sample_shape_different_latents( self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape ): @@ -745,6 +756,7 @@ def test_sample_shape_different_latents( ) self.assertEqual(prediction.shape, latent_shape) + @skipUnless(has_einops, "Requires einops") def test_incompatible_spade_setup(self): stage_1 = SPADEAutoencoderKL( spatial_dims=2, diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index 699ed70059..914240c705 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -190,7 +190,7 @@ def test_correct(self, input_p, expected_shape, track_meta): self.assertTrue(hasattr(r, "affine")) self.assertIsInstance(r.affine, torch.Tensor) self.assertEqual(r.meta["space"], "RAS") - self.assertTrue("qform_code" not in r.meta) + self.assertNotIn("qform_code", r.meta) else: self.assertIsInstance(r, torch.Tensor) self.assertNotIsInstance(r, MetaTensor) diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py index 63422761ca..cbc730e1bb 100644 --- a/tests/test_load_spacing_orientation.py +++ b/tests/test_load_spacing_orientation.py @@ -48,7 +48,7 @@ def test_load_spacingd(self, filename): ref = resample_to_output(anat, (1, 0.2, 1), order=1) t2 = time.time() print(f"time scipy: {t2 - t1}") - self.assertTrue(t2 >= t1) + self.assertGreaterEqual(t2, t1) np.testing.assert_allclose(res_dict["image"].affine, ref.affine) np.testing.assert_allclose(res_dict["image"].shape[1:], ref.shape) np.testing.assert_allclose(ref.get_fdata(), res_dict["image"][0], atol=0.05) @@ -68,7 +68,7 @@ def test_load_spacingd_rotate(self, filename): ref = resample_to_output(anat, (1, 2, 3), order=1) t2 = time.time() print(f"time scipy: {t2 - t1}") - self.assertTrue(t2 >= t1) + self.assertGreaterEqual(t2, t1) np.testing.assert_allclose(res_dict["image"].affine, ref.affine) if "anatomical" not in filename: np.testing.assert_allclose(res_dict["image"].shape[1:], ref.shape) diff --git a/tests/test_look_up_option.py b/tests/test_look_up_option.py index d40b7eaa8c..75560b4ac4 100644 --- a/tests/test_look_up_option.py +++ b/tests/test_look_up_option.py @@ -56,7 +56,7 @@ def test_default(self): def test_str_enum(self): output = look_up_option("C", {"A", "B"}, default=None) - self.assertEqual(output, None) + self.assertIsNone(output) self.assertEqual(list(_CaseStrEnum), ["A", "B"]) self.assertEqual(_CaseStrEnum.MODE_A, "A") self.assertEqual(str(_CaseStrEnum.MODE_A), "A") diff --git a/tests/test_matshow3d.py b/tests/test_matshow3d.py index e513025e69..e54bb523e4 100644 --- a/tests/test_matshow3d.py +++ b/tests/test_matshow3d.py @@ -78,7 +78,7 @@ def test_samples(self): fig, mat = matshow3d( [im[keys] for im in ims], title=f"testing {keys}", figsize=(2, 2), frames_per_row=5, every_n=2, show=False ) - self.assertTrue(mat.dtype == np.float32) + self.assertEqual(mat.dtype, np.float32) with tempfile.TemporaryDirectory() as tempdir: tempimg = f"{tempdir}/matshow3d_patch_test.png" diff --git a/tests/test_median_filter.py b/tests/test_median_filter.py index 516388afce..bdfdf24f9f 100644 --- a/tests/test_median_filter.py +++ b/tests/test_median_filter.py @@ -21,13 +21,13 @@ class MedianFilterTestCase(unittest.TestCase): + @parameterized.expand([(torch.ones(1, 1, 2, 3, 5), [1, 2, 4]), (torch.ones(1, 1, 4, 3, 4), 1)]) # 3d_big # 3d def test_3d(self, input_tensor, radius): filter = MedianFilter(radius).to(torch.device("cpu:0")) expected = input_tensor.numpy() output = filter(input_tensor).cpu().numpy() - np.testing.assert_allclose(output, expected, rtol=1e-5) def test_3d_radii(self): diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index 1db632c144..c1b21e9373 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -41,7 +41,7 @@ def _test_dataset(dataset): self.assertEqual(len(dataset), int(MEDNIST_FULL_DATASET_LENGTH * dataset.test_frac)) self.assertTrue("image" in dataset[0]) self.assertTrue("label" in dataset[0]) - self.assertTrue(isinstance(dataset[0]["image"], MetaTensor)) + self.assertIsInstance(dataset[0]["image"], MetaTensor) self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64)) with skip_if_downloading_fails(): diff --git a/tests/test_meta_affine.py b/tests/test_meta_affine.py index 95764a0c89..890734391f 100644 --- a/tests/test_meta_affine.py +++ b/tests/test_meta_affine.py @@ -160,7 +160,7 @@ def test_linear_consistent(self, xform_cls, input_dict, atol): diff = np.abs(itk.GetArrayFromImage(ref_2) - itk.GetArrayFromImage(expected)) avg_diff = np.mean(diff) - self.assertTrue(avg_diff < atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}") + self.assertLess(avg_diff, atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}") @parameterized.expand(TEST_CASES_DICT) def test_linear_consistent_dict(self, xform_cls, input_dict, atol): @@ -175,7 +175,7 @@ def test_linear_consistent_dict(self, xform_cls, input_dict, atol): diff = {k: np.abs(itk.GetArrayFromImage(ref_2[k]) - itk.GetArrayFromImage(expected[k])) for k in keys} avg_diff = {k: np.mean(diff[k]) for k in keys} for k in keys: - self.assertTrue(avg_diff[k] < atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}") + self.assertLess(avg_diff[k], atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}") if __name__ == "__main__": diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 1e0f188b63..f31a07eba4 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -222,9 +222,9 @@ def test_stack(self, device, dtype): def test_get_set_meta_fns(self): set_track_meta(False) - self.assertEqual(get_track_meta(), False) + self.assertFalse(get_track_meta()) set_track_meta(True) - self.assertEqual(get_track_meta(), True) + self.assertTrue(get_track_meta()) @parameterized.expand(TEST_DEVICES) def test_torchscript(self, device): diff --git a/tests/test_mmar_download.py b/tests/test_mmar_download.py index 6af3d09fb2..2ac73a8149 100644 --- a/tests/test_mmar_download.py +++ b/tests/test_mmar_download.py @@ -142,7 +142,7 @@ def test_load_ckpt(self, input_args, expected_name, expected_val): def test_unique(self): # model ids are unique keys = sorted(m["id"] for m in MODEL_DESC) - self.assertTrue(keys == sorted(set(keys))) + self.assertEqual(keys, sorted(set(keys))) def test_search(self): self.assertEqual(_get_val({"a": 1, "b": 2}, key="b"), 2) diff --git a/tests/test_persistentdataset.py b/tests/test_persistentdataset.py index b7bf2fbb11..7c4969e283 100644 --- a/tests/test_persistentdataset.py +++ b/tests/test_persistentdataset.py @@ -165,7 +165,7 @@ def test_different_transforms(self): im1 = PersistentDataset([im], Identity(), cache_dir=path, hash_transform=json_hashing)[0] im2 = PersistentDataset([im], Flip(1), cache_dir=path, hash_transform=json_hashing)[0] l2 = ((im1 - im2) ** 2).sum() ** 0.5 - self.assertTrue(l2 > 1) + self.assertGreater(l2, 1) if __name__ == "__main__": diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py index 950058a9e9..eb8ebd06c5 100644 --- a/tests/test_rand_affined.py +++ b/tests/test_rand_affined.py @@ -240,7 +240,7 @@ def test_rand_affined(self, input_param, input_data, expected_val, track_meta): resampler.lazy = False if input_param.get("cache_grid", False): - self.assertTrue(g.rand_affine._cached_grid is not None) + self.assertIsNotNone(g.rand_affine._cached_grid) for key in res: if isinstance(key, str) and key.endswith("_transforms"): continue diff --git a/tests/test_rand_bias_field.py b/tests/test_rand_bias_field.py index 333a9ecba5..328f46b7ee 100644 --- a/tests/test_rand_bias_field.py +++ b/tests/test_rand_bias_field.py @@ -39,7 +39,7 @@ def test_output_shape(self, class_args, img_shape): img = p(np.random.rand(*img_shape)) output = bias_field(img) np.testing.assert_equal(output.shape, img_shape) - self.assertTrue(output.dtype in (np.float32, torch.float32)) + self.assertIn(output.dtype, (np.float32, torch.float32)) img_zero = np.zeros([*img_shape]) output_zero = bias_field(img_zero) diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py index 1524442f61..a1414df0ac 100644 --- a/tests/test_rand_weighted_cropd.py +++ b/tests/test_rand_weighted_cropd.py @@ -154,7 +154,7 @@ def test_rand_weighted_cropd(self, _, init_params, input_data, expected_shape, e crop = RandWeightedCropd(**init_params) crop.set_random_state(10) result = crop(input_data) - self.assertTrue(len(result) == init_params["num_samples"]) + self.assertEqual(len(result), init_params["num_samples"]) _len = len(tuple(input_data.keys())) self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys())) diff --git a/tests/test_recon_net_utils.py b/tests/test_recon_net_utils.py index 1815000777..48d3b59a17 100644 --- a/tests/test_recon_net_utils.py +++ b/tests/test_recon_net_utils.py @@ -64,7 +64,7 @@ def test_reshape_channel_complex(self, test_data): def test_complex_normalize(self, test_data): result, mean, std = complex_normalize(test_data) result = result * std + mean - self.assertTrue((((result - test_data) ** 2).mean() ** 0.5).item() < 1e-5) + self.assertLess((((result - test_data) ** 2).mean() ** 0.5).item(), 1e-5) @parameterized.expand(TEST_PAD) def test_pad(self, test_data): diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index e8f82eb0c2..1fb81689e6 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -99,7 +99,7 @@ def forward(self, x): # backward pass loss_val.backward() optimizer.step() - self.assertTrue(init_loss > loss_val, "loss did not decrease") + self.assertGreater(init_loss, loss_val, "loss did not decrease") if __name__ == "__main__": diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 449edba4bf..5d34a32d8d 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -198,6 +198,14 @@ [model, *TEST_CASE_1] for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200] ] +CASE_EXTRACT_FEATURES = [ + ( + {"model_name": "resnet10", "pretrained": True, "spatial_dims": 3, "in_channels": 1}, + [1, 1, 64, 64, 64], + ([1, 64, 32, 32, 32], [1, 64, 16, 16, 16], [1, 128, 8, 8, 8], [1, 256, 4, 4, 4], [1, 512, 2, 2, 2]), + ) +] + CASE_EXTRACT_FEATURES = [ ( @@ -228,7 +236,7 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape): if input_param.get("feed_forward", True): self.assertEqual(result.shape, expected_shape) else: - self.assertTrue(result.shape in expected_shape) + self.assertIn(result.shape, expected_shape) @parameterized.expand(PRETRAINED_TEST_CASES) @skip_if_quick diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index d52cc71e55..d069d6aa30 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -62,6 +62,27 @@ def test_ill_arg(self): with self.assertRaises(ValueError): SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1) + + @skipUnless(has_einops, "Requires einops") + def test_inner_dim_different(self): + SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30) + + def test_causal_no_sequence_length(self): + with self.assertRaises(ValueError): + SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + + @skipUnless(has_einops, "Requires einops") + def test_causal(self): + block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True) + input_shape = (1, 16, 128) + block(torch.randn(input_shape)) + # check upper triangular part of the attention matrix is zero + assert torch.triu(block.att_mat, diagonal=1).sum() == 0 + + @skipUnless(has_einops, "Requires einops") def test_access_attn_matrix(self): # input format hidden_size = 128 @@ -83,6 +104,40 @@ def test_access_attn_matrix(self): matrix_acess_blk(torch.randn(input_shape)) assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + def test_number_of_parameters(self): + + def count_sablock_params(*args, **kwargs): + """Count the number of parameters in a SABlock.""" + sablock = SABlock(*args, **kwargs) + return sum([x.numel() for x in sablock.parameters() if x.requires_grad]) + + hidden_size = 128 + num_heads = 8 + default_dim_head = hidden_size // num_heads + + # Default dim_head is hidden_size // num_heads + nparams_default = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads) + nparams_like_default = count_sablock_params( + hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head + ) + self.assertEqual(nparams_default, nparams_like_default) + + # Increasing dim_head should increase the number of parameters + nparams_custom_large = count_sablock_params( + hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2 + ) + self.assertGreater(nparams_custom_large, nparams_default) + + # Decreasing dim_head should decrease the number of parameters + nparams_custom_small = count_sablock_params( + hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2 + ) + self.assertGreater(nparams_default, nparams_custom_small) + + # Increasing the number of heads with the default behaviour should not change the number of params. + nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2) + self.assertEqual(nparams_default, nparams_default_more_heads) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_sobel_gradient.py b/tests/test_sobel_gradient.py index 3d995a60c9..a0d7cf5a8b 100644 --- a/tests/test_sobel_gradient.py +++ b/tests/test_sobel_gradient.py @@ -164,8 +164,8 @@ def test_sobel_gradients(self, image, arguments, expected_grad): ) def test_sobel_kernels(self, arguments, expected_kernels): sobel = SobelGradients(**arguments) - self.assertTrue(sobel.kernel_diff.dtype == expected_kernels[0].dtype) - self.assertTrue(sobel.kernel_smooth.dtype == expected_kernels[0].dtype) + self.assertEqual(sobel.kernel_diff.dtype, expected_kernels[0].dtype) + self.assertEqual(sobel.kernel_smooth.dtype, expected_kernels[0].dtype) assert_allclose(sobel.kernel_diff, expected_kernels[0]) assert_allclose(sobel.kernel_smooth, expected_kernels[1]) diff --git a/tests/test_sobel_gradientd.py b/tests/test_sobel_gradientd.py index 7499a0410b..03524823a5 100644 --- a/tests/test_sobel_gradientd.py +++ b/tests/test_sobel_gradientd.py @@ -187,8 +187,8 @@ def test_sobel_gradients(self, image_dict, arguments, expected_grad): ) def test_sobel_kernels(self, arguments, expected_kernels): sobel = SobelGradientsd(**arguments) - self.assertTrue(sobel.kernel_diff.dtype == expected_kernels[0].dtype) - self.assertTrue(sobel.kernel_smooth.dtype == expected_kernels[0].dtype) + self.assertEqual(sobel.kernel_diff.dtype, expected_kernels[0].dtype) + self.assertEqual(sobel.kernel_smooth.dtype, expected_kernels[0].dtype) assert_allclose(sobel.kernel_diff, expected_kernels[0]) assert_allclose(sobel.kernel_smooth, expected_kernels[1]) diff --git a/tests/test_spade_diffusion_model_unet.py b/tests/test_spade_diffusion_model_unet.py index 113e58ed89..481705f56f 100644 --- a/tests/test_spade_diffusion_model_unet.py +++ b/tests/test_spade_diffusion_model_unet.py @@ -12,13 +12,16 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized from monai.networks import eval_mode from monai.networks.nets import SPADEDiffusionModelUNet +from monai.utils import optional_import +einops, has_einops = optional_import("einops") UNCOND_CASES_2D = [ [ { @@ -262,6 +265,7 @@ class TestSPADEDiffusionModelUNet2D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_2D) + @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param): net = SPADEDiffusionModelUNet(**input_param) with eval_mode(net): @@ -272,6 +276,7 @@ def test_shape_unconditioned_models(self, input_param): ) self.assertEqual(result.shape, (1, 1, 16, 16)) + @skipUnless(has_einops, "Requires einops") def test_timestep_with_wrong_shape(self): net = SPADEDiffusionModelUNet( spatial_dims=2, @@ -289,6 +294,7 @@ def test_timestep_with_wrong_shape(self): torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long(), torch.rand((1, 3, 16, 16)) ) + @skipUnless(has_einops, "Requires einops") def test_label_with_wrong_shape(self): net = SPADEDiffusionModelUNet( spatial_dims=2, @@ -304,6 +310,7 @@ def test_label_with_wrong_shape(self): with eval_mode(net): net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 6, 16, 16))) + @skipUnless(has_einops, "Requires einops") def test_shape_with_different_in_channel_out_channel(self): in_channels = 6 out_channels = 3 @@ -363,6 +370,7 @@ def test_num_res_blocks_with_different_length_channels(self): norm_num_groups=8, ) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self): net = SPADEDiffusionModelUNet( spatial_dims=2, @@ -387,6 +395,7 @@ def test_shape_conditioned_models(self): ) self.assertEqual(result.shape, (1, 1, 16, 32)) + @skipUnless(has_einops, "Requires einops") def test_with_conditioning_cross_attention_dim_none(self): with self.assertRaises(ValueError): SPADEDiffusionModelUNet( @@ -403,6 +412,7 @@ def test_with_conditioning_cross_attention_dim_none(self): norm_num_groups=8, ) + @skipUnless(has_einops, "Requires einops") def test_context_with_conditioning_none(self): net = SPADEDiffusionModelUNet( spatial_dims=2, @@ -426,6 +436,7 @@ def test_context_with_conditioning_none(self): context=torch.rand((1, 1, 3)), ) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models_class_conditioning(self): net = SPADEDiffusionModelUNet( spatial_dims=2, @@ -448,6 +459,7 @@ def test_shape_conditioned_models_class_conditioning(self): ) self.assertEqual(result.shape, (1, 1, 16, 32)) + @skipUnless(has_einops, "Requires einops") def test_conditioned_models_no_class_labels(self): net = SPADEDiffusionModelUNet( spatial_dims=2, @@ -485,6 +497,7 @@ def test_model_channels_not_same_size_of_attention_levels(self): ) @parameterized.expand(COND_CASES_2D) + @skipUnless(has_einops, "Requires einops") def test_conditioned_2d_models_shape(self, input_param): net = SPADEDiffusionModelUNet(**input_param) with eval_mode(net): @@ -499,6 +512,7 @@ def test_conditioned_2d_models_shape(self, input_param): class TestDiffusionModelUNet3D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_3D) + @skipUnless(has_einops, "Requires einops") def test_shape_unconditioned_models(self, input_param): net = SPADEDiffusionModelUNet(**input_param) with eval_mode(net): @@ -509,6 +523,7 @@ def test_shape_unconditioned_models(self, input_param): ) self.assertEqual(result.shape, (1, 1, 16, 16, 16)) + @skipUnless(has_einops, "Requires einops") def test_shape_with_different_in_channel_out_channel(self): in_channels = 6 out_channels = 3 @@ -530,6 +545,7 @@ def test_shape_with_different_in_channel_out_channel(self): ) self.assertEqual(result.shape, (1, out_channels, 16, 16, 16)) + @skipUnless(has_einops, "Requires einops") def test_shape_conditioned_models(self): net = SPADEDiffusionModelUNet( spatial_dims=3, diff --git a/tests/test_spatialattention.py b/tests/test_spatialattention.py new file mode 100644 index 0000000000..70b78263c5 --- /dev/null +++ b/tests/test_spatialattention.py @@ -0,0 +1,55 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.spatialattention import SpatialAttentionBlock +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +TEST_CASES = [ + [ + {"spatial_dims": 2, "num_channels": 128, "num_head_channels": 32, "norm_num_groups": 32, "norm_eps": 1e-6}, + (1, 128, 32, 32), + (1, 128, 32, 32), + ], + [ + {"spatial_dims": 3, "num_channels": 16, "num_head_channels": 8, "norm_num_groups": 8, "norm_eps": 1e-6}, + (1, 16, 8, 8, 8), + (1, 16, 8, 8, 8), + ], +] + + +class TestBlock(unittest.TestCase): + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = SpatialAttentionBlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_attention_dim_not_multiple_of_heads(self): + with self.assertRaises(ValueError): + SpatialAttentionBlock(spatial_dims=2, num_channels=128, num_head_channels=33) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py index 9551dec703..568461748b 100644 --- a/tests/test_threadcontainer.py +++ b/tests/test_threadcontainer.py @@ -62,7 +62,7 @@ def test_container(self): self.assertTrue(con.is_alive) self.assertIsNotNone(con.status()) - self.assertTrue(len(con.status_dict) > 0) + self.assertGreater(len(con.status_dict), 0) con.join() diff --git a/tests/test_to_cupy.py b/tests/test_to_cupy.py index 5a1754e7c5..38400f0d3f 100644 --- a/tests/test_to_cupy.py +++ b/tests/test_to_cupy.py @@ -62,8 +62,8 @@ def test_numpy_input_dtype(self): test_data = np.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) result = ToCupy(np.uint8)(test_data) - self.assertTrue(result.dtype == cp.uint8) - self.assertTrue(isinstance(result, cp.ndarray)) + self.assertEqual(result.dtype, cp.uint8) + self.assertIsInstance(result, cp.ndarray) self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) @@ -72,8 +72,8 @@ def test_tensor_input(self): test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) result = ToCupy()(test_data) - self.assertTrue(result.dtype == cp.float32) - self.assertTrue(isinstance(result, cp.ndarray)) + self.assertEqual(result.dtype, cp.float32) + self.assertIsInstance(result, cp.ndarray) self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) @@ -83,8 +83,8 @@ def test_tensor_cuda_input(self): test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) result = ToCupy()(test_data) - self.assertTrue(result.dtype == cp.float32) - self.assertTrue(isinstance(result, cp.ndarray)) + self.assertEqual(result.dtype, cp.float32) + self.assertIsInstance(result, cp.ndarray) self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) @@ -95,8 +95,8 @@ def test_tensor_cuda_input_dtype(self): self.assertFalse(test_data.is_contiguous()) result = ToCupy(dtype="float32")(test_data) - self.assertTrue(result.dtype == cp.float32) - self.assertTrue(isinstance(result, cp.ndarray)) + self.assertEqual(result.dtype, cp.float32) + self.assertIsInstance(result, cp.ndarray) self.assertTrue(result.flags["C_CONTIGUOUS"]) cp.testing.assert_allclose(result, test_data) diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py index f92b7c0075..f4e5f80a29 100644 --- a/tests/test_to_numpy.py +++ b/tests/test_to_numpy.py @@ -32,7 +32,7 @@ def test_cupy_input(self): test_data = cp.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) result = ToNumpy()(test_data) - self.assertTrue(isinstance(result, np.ndarray)) + self.assertIsInstance(result, np.ndarray) self.assertTrue(result.flags["C_CONTIGUOUS"]) assert_allclose(result, test_data.get(), type_test=False) @@ -41,8 +41,8 @@ def test_numpy_input(self): test_data = np.rot90(test_data) self.assertFalse(test_data.flags["C_CONTIGUOUS"]) result = ToNumpy(dtype="float32")(test_data) - self.assertTrue(isinstance(result, np.ndarray)) - self.assertTrue(result.dtype == np.float32) + self.assertIsInstance(result, np.ndarray) + self.assertEqual(result.dtype, np.float32) self.assertTrue(result.flags["C_CONTIGUOUS"]) assert_allclose(result, test_data, type_test=False) @@ -51,7 +51,7 @@ def test_tensor_input(self): test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) result = ToNumpy(dtype=torch.uint8)(test_data) - self.assertTrue(isinstance(result, np.ndarray)) + self.assertIsInstance(result, np.ndarray) self.assertTrue(result.flags["C_CONTIGUOUS"]) assert_allclose(result, test_data, type_test=False) @@ -61,7 +61,7 @@ def test_tensor_cuda_input(self): test_data = test_data.rot90() self.assertFalse(test_data.is_contiguous()) result = ToNumpy()(test_data) - self.assertTrue(isinstance(result, np.ndarray)) + self.assertIsInstance(result, np.ndarray) self.assertTrue(result.flags["C_CONTIGUOUS"]) assert_allclose(result, test_data, type_test=False) @@ -77,7 +77,7 @@ def test_list_tuple(self): def test_single_value(self): for test_data in [5, np.array(5), torch.tensor(5)]: result = ToNumpy(dtype=np.uint8)(test_data) - self.assertTrue(isinstance(result, np.ndarray)) + self.assertIsInstance(result, np.ndarray) assert_allclose(result, np.asarray(test_data), type_test=False) self.assertEqual(result.ndim, 0) diff --git a/tests/test_torchvision_fc_model.py b/tests/test_torchvision_fc_model.py index 322cce1161..9cc19db62c 100644 --- a/tests/test_torchvision_fc_model.py +++ b/tests/test_torchvision_fc_model.py @@ -195,8 +195,8 @@ def test_get_module(self): mod = look_up_named_module("model.1.submodule.1.submodule.1.submodule.0.conv", net) self.assertTrue(str(mod).startswith("Conv2d")) self.assertIsInstance(set_named_module(net, "model", torch.nn.Identity()).model, torch.nn.Identity) - self.assertEqual(look_up_named_module("model.1.submodule.1.submodule.1.submodule.conv", net), None) - self.assertEqual(look_up_named_module("test attribute", net), None) + self.assertIsNone(look_up_named_module("model.1.submodule.1.submodule.1.submodule.conv", net)) + self.assertIsNone(look_up_named_module("test attribute", net)) if __name__ == "__main__": diff --git a/tests/test_traceable_transform.py b/tests/test_traceable_transform.py index dd139053e3..6a499b2dd9 100644 --- a/tests/test_traceable_transform.py +++ b/tests/test_traceable_transform.py @@ -33,12 +33,12 @@ def test_default(self): expected_key = "_transforms" a = _TraceTest() for x in a.transform_info_keys(): - self.assertTrue(x in a.get_transform_info()) + self.assertIn(x, a.get_transform_info()) self.assertEqual(a.trace_key(), expected_key) data = {"image": "test"} data = a(data) # adds to the stack - self.assertTrue(isinstance(data[expected_key], list)) + self.assertIsInstance(data[expected_key], list) self.assertEqual(data[expected_key][0]["class"], "_TraceTest") data = a(data) # adds to the stack diff --git a/tests/test_transformer.py b/tests/test_transformer.py index ea6ebdf50f..b371809d47 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -11,15 +11,22 @@ from __future__ import annotations +import os +import tempfile import unittest +from unittest import skipUnless import numpy as np import torch from parameterized import parameterized +from monai.apps import download_url from monai.networks import eval_mode from monai.networks.nets import DecoderOnlyTransformer +from monai.utils import optional_import +from tests.utils import skip_if_downloading_fails, testing_data_config +_, has_einops = optional_import("einops") TEST_CASES = [] for dropout_rate in np.linspace(0, 1, 2): for attention_layer_dim in [360, 480, 600, 768]: @@ -40,12 +47,14 @@ class TestDecoderOnlyTransformer(unittest.TestCase): @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_unconditioned_models(self, input_param): net = DecoderOnlyTransformer(**input_param) with eval_mode(net): net.forward(torch.randint(0, 10, (1, 16))) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_conditioned_models(self, input_param): net = DecoderOnlyTransformer(**input_param, with_cross_attention=True) with eval_mode(net): @@ -57,7 +66,9 @@ def test_attention_dim_not_multiple_of_heads(self): num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=3 ) + @skipUnless(has_einops, "Requires einops") def test_dropout_rate_negative(self): + with self.assertRaises(ValueError): DecoderOnlyTransformer( num_tokens=10, @@ -68,6 +79,31 @@ def test_dropout_rate_negative(self): embedding_dropout_rate=-1, ) + @skipUnless(has_einops, "Requires einops") + def test_compatibility_with_monai_generative(self): + # test loading weights from a model saved in MONAI Generative, version 0.2.3 + with skip_if_downloading_fails(): + net = DecoderOnlyTransformer( + num_tokens=10, + max_seq_len=16, + attn_layers_dim=8, + attn_layers_depth=2, + attn_layers_heads=2, + with_cross_attention=True, + embedding_dropout_rate=0, + ) + + tmpdir = tempfile.mkdtemp() + key = "decoder_only_transformer_monai_generative_weights" + url = testing_data_config("models", key, "url") + hash_type = testing_data_config("models", key, "hash_type") + hash_val = testing_data_config("models", key, "hash_val") + filename = "decoder_only_transformer_monai_generative_weights.pt" + weight_path = os.path.join(tmpdir, filename) + download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type) + + net.load_old_state_dict(torch.load(weight_path), verbose=False) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py index 5a8dbba83c..a850cc6f74 100644 --- a/tests/test_transformerblock.py +++ b/tests/test_transformerblock.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest +from unittest import skipUnless import numpy as np import torch @@ -19,28 +20,33 @@ from monai.networks import eval_mode from monai.networks.blocks.transformerblock import TransformerBlock +from monai.utils import optional_import +einops, has_einops = optional_import("einops") TEST_CASE_TRANSFORMERBLOCK = [] for dropout_rate in np.linspace(0, 1, 4): for hidden_size in [360, 480, 600, 768]: for num_heads in [4, 8, 12]: for mlp_dim in [1024, 3072]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "mlp_dim": mlp_dim, - "dropout_rate": dropout_rate, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_TRANSFORMERBLOCK.append(test_case) + for cross_attention in [False, True]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "mlp_dim": mlp_dim, + "dropout_rate": dropout_rate, + "with_cross_attention": cross_attention, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_TRANSFORMERBLOCK.append(test_case) class TestTransformerBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK) + @skipUnless(has_einops, "Requires einops") def test_shape(self, input_param, input_shape, expected_shape): net = TransformerBlock(**input_param) with eval_mode(net): @@ -54,6 +60,7 @@ def test_ill_arg(self): with self.assertRaises(ValueError): TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4) + @skipUnless(has_einops, "Requires einops") def test_access_attn_matrix(self): # input format hidden_size = 128 diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py index 1a511d287b..36b715f588 100644 --- a/tests/test_vqvaetransformer_inferer.py +++ b/tests/test_vqvaetransformer_inferer.py @@ -12,14 +12,17 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized from monai.inferers import VQVAETransformerInferer from monai.networks.nets import VQVAE, DecoderOnlyTransformer +from monai.utils import optional_import from monai.utils.ordering import Ordering, OrderingType +einops, has_einops = optional_import("einops") TEST_CASES = [ [ { @@ -78,6 +81,7 @@ class TestVQVAETransformerInferer(unittest.TestCase): @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_prediction_shape( self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape ): @@ -98,6 +102,7 @@ def test_prediction_shape( self.assertEqual(prediction.shape, logits_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_prediction_shape_shorter_sequence( self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape ): @@ -121,7 +126,9 @@ def test_prediction_shape_shorter_sequence( cropped_logits_shape = (logits_shape[0], max_seq_len, logits_shape[2]) self.assertEqual(prediction.shape, cropped_logits_shape) + @skipUnless(has_einops, "Requires einops") def test_sample(self): + stage_1 = VQVAE( spatial_dims=2, in_channels=1, @@ -163,6 +170,7 @@ def test_sample(self): ) self.assertEqual(sample.shape, (2, 1, 8, 8)) + @skipUnless(has_einops, "Requires einops") def test_sample_shorter_sequence(self): stage_1 = VQVAE( spatial_dims=2, @@ -206,6 +214,7 @@ def test_sample_shorter_sequence(self): self.assertEqual(sample.shape, (2, 1, 8, 8)) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_get_likelihood( self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape ): @@ -228,6 +237,7 @@ def test_get_likelihood( self.assertEqual(likelihood.shape, latent_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_get_likelihood_shorter_sequence( self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape ): @@ -253,6 +263,7 @@ def test_get_likelihood_shorter_sequence( self.assertEqual(likelihood.shape, latent_shape) @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") def test_get_likelihood_resampling( self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape ): diff --git a/tests/test_warp.py b/tests/test_warp.py index bac595224f..55f40764c3 100644 --- a/tests/test_warp.py +++ b/tests/test_warp.py @@ -124,7 +124,7 @@ def test_itk_benchmark(self): relative_diff = np.mean( np.divide(monai_result - itk_result, itk_result, out=np.zeros_like(itk_result), where=(itk_result != 0)) ) - self.assertTrue(relative_diff < 0.01) + self.assertLess(relative_diff, 0.01) @parameterized.expand(TEST_CASES, skip_on_empty=True) def test_resample(self, input_param, input_data, expected_val): diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json index a570c787ba..318331e5f7 100644 --- a/tests/testing_data/data_config.json +++ b/tests/testing_data/data_config.json @@ -138,6 +138,21 @@ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth", "hash_type": "sha256", "hash_val": "c3564f40a6a051d3753a6d8fae5cc8eaf21ce8d82a9a3baf80748d15664055e8" + }, + "decoder_only_transformer_monai_generative_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/decoder_only_transformer.pth", + "hash_type": "sha256", + "hash_val": "f93de37d64d77cf91f3bde95cdf93d161aee800074c89a92aff9d5699120ec0d" + }, + "diffusion_model_unet_monai_generative_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/diffusion_model_unet.pth", + "hash_type": "sha256", + "hash_val": "0d2171b386902f5b4fd3e967b4024f63e353694ca45091b114970019d045beee" + }, + "autoencoderkl_monai_generative_weights": { + "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/autoencoderkl.pth", + "hash_type": "sha256", + "hash_val": "6e02c9540c51b16b9ba98b5c0c75d6b84b430afe9a3237df1d67a520f8d34184" } }, "configs": {