7227 refactor transformer and diffusion model unet #7715

Merged: 144 commits, May 10, 2024
Commits
2fc0126
wholeBody_ct_segmentation failed to be download (#7280)
KumoLiu Dec 4, 2023
88f8dd2
update the Python version requirements for transformers (#7275)
KumoLiu Dec 5, 2023
d5585c3
7263 add diffusion loss (#7272)
kvttt Dec 5, 2023
8212458
Fix swinunetrv2 2D bug (#7302)
heyufan1995 Dec 8, 2023
b1c8b42
Fix `RuntimeError` in `DataAnalyzer` (#7310)
KumoLiu Dec 12, 2023
6126a6e
Support specified filenames in `Saveimage` (#7318)
KumoLiu Dec 14, 2023
4c546b3
Fix typo (#7321)
KumoLiu Dec 15, 2023
01a0ee1
fix optimizer pararmeter issue (#7322)
binliunls Dec 15, 2023
90d2acb
Fix `lazy` ignored in `SpatialPadd` (#7316)
KumoLiu Dec 18, 2023
f5327ad
Update openslide-python version (#7344)
KumoLiu Dec 28, 2023
b3d7a48
Upgrade the version of `transformers` (#7343)
KumoLiu Dec 29, 2023
02973f0
Bump github/codeql-action from 2 to 3 (#7354)
dependabot[bot] Jan 2, 2024
38ac573
Bump actions/upload-artifact from 3 to 4 (#7350)
dependabot[bot] Jan 2, 2024
454f89c
Bump actions/setup-python from 4 to 5 (#7351)
dependabot[bot] Jan 2, 2024
db6e81d
Bump actions/download-artifact from 3 to 4 (#7352)
dependabot[bot] Jan 2, 2024
8fa6931
Bump peter-evans/slash-command-dispatch from 3.0.1 to 3.0.2 (#7353)
dependabot[bot] Jan 3, 2024
445d750
Give more useful exception when batch is considered during matrix mul…
KumoLiu Jan 8, 2024
d7137cf
Fix incorrectly size compute in auto3dseg analyzer (#7374)
KumoLiu Jan 9, 2024
e1ffa7e
7380 mention demo in bending energy and diffusion docstrings (#7381)
kvttt Jan 10, 2024
25f1901
Pin gdown version to v4.6.3 (#7384)
KumoLiu Jan 12, 2024
78295c7
Fix Premerge (#7397)
KumoLiu Jan 18, 2024
80be1c3
Track applied operations in image filter (#7395)
vlaminckaxel Jan 18, 2024
facf176
Add `compile` support in `SupervisedTrainer` and `SupervisedEvaluator…
KumoLiu Jan 19, 2024
abe2e31
Fix CUDA_VISIBLE_DEVICES setting ignored (#7408)
KumoLiu Jan 22, 2024
6c9e49e
Fix Incorrect updated affine in `NrrdReader` and update docstring in …
KumoLiu Jan 25, 2024
df4cf92
Ignore E704 after update black (#7422)
KumoLiu Jan 30, 2024
cd36c1c
update `rm -rf /opt/hostedtoolcache` avoid change the python version …
KumoLiu Feb 1, 2024
6def75e
Bump peter-evans/slash-command-dispatch from 3.0.2 to 4.0.0 (#7428)
dependabot[bot] Feb 1, 2024
6ea3e39
Bump peter-evans/create-or-update-comment from 3 to 4 (#7429)
dependabot[bot] Feb 2, 2024
5301443
Bump actions/cache from 3 to 4 (#7430)
dependabot[bot] Feb 2, 2024
aa4f578
Bump codecov/codecov-action from 3 to 4 (#7431)
dependabot[bot] Feb 2, 2024
33afaef
Update tensorboard version to fix deadlock (#7435)
KumoLiu Feb 2, 2024
718e4be
auto updates (#7439)
monai-bot Feb 5, 2024
ec2cc83
Instantiation mode `"partial"` to `"callable"`. Return the `_target_`…
ibro45 Feb 6, 2024
c27cc98
Add support for mlflow experiment name in auto3dseg (#7442)
bhashemian Feb 6, 2024
ff43028
Update gdown version (#7448)
KumoLiu Feb 7, 2024
a8080fc
Skip "test_gaussian_filter" as a workaround for blossom killed (#7474)
KumoLiu Feb 20, 2024
50f9aea
auto updates (#7463)
monai-bot Feb 20, 2024
df37a22
Skip "test_resize" as a workaround for blossom killed (#7484)
KumoLiu Feb 21, 2024
1b93988
Fix Python 3.12 import AttributeError (#7482)
KumoLiu Feb 21, 2024
f4103c5
Update test_nnunetv2runner (#7483)
KumoLiu Feb 22, 2024
9c77a04
Fix github resource issue when build latest docker (#7450)
KumoLiu Feb 23, 2024
f6f9e81
Use int16 instead of int8 in `LabelStats` (#7489)
KumoLiu Feb 23, 2024
20512d3
auto updates (#7495)
monai-bot Feb 26, 2024
7cfa2c9
Add sample_std parameter to RandGaussianNoise. (#7492)
bakert1 Feb 26, 2024
9830525
Add __repr__ and __str__ to Metrics baseclass (#7487)
MathijsdeBoer Feb 28, 2024
02c7f53
Bump al-cheb/configure-pagefile-action from 1.3 to 1.4 (#7510)
dependabot[bot] Mar 1, 2024
e9e2738
Add arm support (#7500)
KumoLiu Mar 3, 2024
95f69de
Fix error in "test_bundle_trt_export" (#7524)
KumoLiu Mar 10, 2024
6b7568d
Fix typo in the PerceptualNetworkType Enum (#7548)
SomeUserName1 Mar 15, 2024
ec63e06
Update to use `log_sigmoid` in `FocalLoss` (#7534)
KumoLiu Mar 18, 2024
35c93fd
Update integration_segmentation_3d result for PyTorch2403 (#7551)
KumoLiu Mar 22, 2024
c649934
Add Barlow Twins loss for representation learning (#7530)
Lucas-rbnt Mar 22, 2024
c3a7383
Stein's Unbiased Risk Estimator (SURE) loss and Conjugate Gradient (#…
cxlcl Mar 22, 2024
c86e790
auto updates (#7577)
monai-bot Mar 25, 2024
97678fa
Remove nested error propagation on `ConfigComponent` instantiate (#7569)
surajpaib Mar 26, 2024
e5bebfc
2872 implementation of mixup, cutmix and cutout (#7198)
juampatronics Mar 26, 2024
2716b6a
Remove device count cache when import monai (#7581)
KumoLiu Mar 27, 2024
c9fed96
Fixing gradient in sincos positional encoding in monai/networks/block…
Lucas-rbnt Mar 27, 2024
ba3c72c
Fix inconsistent alpha parameter/docs for RandGibbsNoise/RandGibbsNoi…
johnzielke Mar 27, 2024
7c0b10e
Fix bundle_root for NNIGen (#7586)
mingxin-zheng Mar 27, 2024
2d463a7
Auto3DSeg algo_template hash update (#7589)
monai-bot Mar 27, 2024
15d2abf
Utilizing subprocess for nnUNet training. (#7576)
KumoLiu Apr 1, 2024
ec4d946
typo fix (#7595)
scalyvladimir Apr 1, 2024
a7c2589
auto updates (#7599)
monai-bot Apr 1, 2024
c885100
7540 change bundle workflow args (#7549)
yiheng-wang-nv Apr 1, 2024
264b9e4
Add "properties_path" in BundleWorkflow (#7542)
KumoLiu Apr 1, 2024
bbaaf4c
Auto3DSeg algo_template hash update (#7603)
monai-bot Apr 1, 2024
5ec7305
ENH: generate_label_classes_crop_centers: warn only if ratio of missi…
lorinczszabolcs Apr 2, 2024
763347d
Update base image to 2403 (#7600)
KumoLiu Apr 3, 2024
195d7dd
simplification of the sincos positional encoding in patchembedding.py…
Lucas-rbnt Apr 4, 2024
625967c
harmonization and clarification of dice losses variants docs and asso…
Lucas-rbnt Apr 5, 2024
c0b9cc0
Implementation of intensity clipping transform: bot hard and soft app…
Lucas-rbnt Apr 5, 2024
87152d1
Fix typo in `SSIMMetric` (#7612)
KumoLiu Apr 8, 2024
e9a5bfe
auto updates (#7614)
monai-bot Apr 10, 2024
54a6991
Fix test error in `test_soft_clipping_one_sided_high` (#7624)
KumoLiu Apr 11, 2024
3856c45
Fix deprecated warning in ruff (#7625)
KumoLiu Apr 11, 2024
da3ecdd
7601 fix mlflow artifacts (#7604)
binliunls Apr 12, 2024
1268488
Uninstall opencv included in base image (#7626)
KumoLiu Apr 12, 2024
9e2904a
Add checks for num_fold and fail early if wrong (#7634)
mingxin-zheng Apr 12, 2024
0497448
Auto3DSeg algo_template hash update (#7642)
monai-bot Apr 15, 2024
605ffe1
Auto3DSeg algo_template hash update (#7643)
monai-bot Apr 15, 2024
bff4b15
Remove source code of numcodecs in the Dockerfile (#7644)
KumoLiu Apr 15, 2024
16d4e2f
Remove memory_pool_limit in trt config (#7647)
KumoLiu Apr 16, 2024
d6e6b24
Add version requirement for mlflow (#7659)
KumoLiu Apr 19, 2024
ffd4454
Auto3DSeg algo_template hash update (#7674)
monai-bot Apr 19, 2024
224c47a
Fixed four test issues within test code. (#7662)
freddiewanah Apr 19, 2024
7a6b69f
Adapt to use assert raises (#7670)
freddiewanah Apr 19, 2024
03a5fa6
MedicalNetPerceptualSimilarity: Add multi-channel (#7568)
SomeUserName1 Apr 19, 2024
c6bf8e9
Workaround for B909 in flake8-bugbear (#7691)
KumoLiu Apr 22, 2024
178ebc8
Fix AttributeError in 'PerceptualLoss' (#7693)
KumoLiu Apr 22, 2024
ac9b186
Always convert input to C-order in distance_transform_edt (#7675)
KumoLiu Apr 23, 2024
a59676f
Auto3DSeg algo_template hash update (#7695)
monai-bot Apr 23, 2024
ec6aa33
Merge similar test components with parameterized (#7663)
freddiewanah Apr 23, 2024
dc58e5c
Add ResNet backbones for FlexibleUNet (#7571)
k-sukharev Apr 23, 2024
1c07a17
Refactored test assertions that have suboptimal tests with numbers (#…
freddiewanah Apr 23, 2024
07a78d2
Auto3DSeg algo_template hash update (#7700)
monai-bot Apr 24, 2024
c3e4457
Update pycln version (#7704)
KumoLiu Apr 24, 2024
bfe09b8
Refactored others type of subotimal asserts (#7672)
freddiewanah Apr 24, 2024
8c709de
Fix download failing on FIPS machines (#7698)
MattTheCuber Apr 25, 2024
1fffe05
Moves to MONAI self attention block and adds CrossAttentionBlock to M…
marksgraham Apr 25, 2024
a686a03
Replaces diffusion_model_unet self attention with MONAI block, with s…
marksgraham Apr 25, 2024
ca1c6ae
Replaces diffusion_model_unet cross attention block with newly added …
marksgraham Apr 25, 2024
ab4aaa8
Create a spatial attention block
marksgraham Apr 25, 2024
6a130cc
7713 Update TRT parameter (#7714)
binliunls Apr 26, 2024
be74e2c
Positional args
marksgraham Apr 26, 2024
0869aa5
try tests with einops skip
marksgraham Apr 26, 2024
9eaa426
einops
marksgraham Apr 26, 2024
d4bdecd
Move SABlock back to selfattention.py and place new attention blocks …
marksgraham Apr 26, 2024
4c193ea
Fix itk install error when python=3.8 (#7719)
KumoLiu Apr 29, 2024
5650899
auto updates (#7723)
monai-bot Apr 29, 2024
edba500
Make use of shared transformer block, add upcast_attention option
marksgraham Apr 29, 2024
09ee954
Use MONAI internal upsample block
marksgraham Apr 29, 2024
799dabd
Fixed upsample/downsample
marksgraham Apr 29, 2024
b528338
Remove all traces of flash_attention
marksgraham Apr 29, 2024
7bda11b
Updates arguments and tests
marksgraham Apr 29, 2024
e1a69b0
Auto3DSeg algo_template hash update (#7728)
monai-bot Apr 30, 2024
eecd0d3
DCO Remediation Commit for Mark Graham <[email protected]>
marksgraham Apr 30, 2024
d47dfaf
Add einops skips to tests
marksgraham Apr 30, 2024
68d4021
More einops skips
marksgraham Apr 30, 2024
23e45d3
More einops skips
marksgraham Apr 30, 2024
22c2333
More einops skips
marksgraham Apr 30, 2024
a38ea3f
Fix itk install error when python=3.8 (#7719)
KumoLiu Apr 29, 2024
d29a545
Update pycln version (#7704)
KumoLiu Apr 24, 2024
bb2fbe9
Mypy
marksgraham Apr 30, 2024
a4088ec
Reduce memory usage in test
marksgraham Apr 30, 2024
6d3b4a7
Fix test args
marksgraham Apr 30, 2024
a3ca76d
DCO Remediation Commit for Mark Graham <[email protected]>
marksgraham Apr 30, 2024
fe733b0
Propagate kernel size through attention Attention-UNet (#7734)
Pkaps25 May 7, 2024
d4df446
Add compatability tests
marksgraham May 7, 2024
59f8373
Add compatability tests
marksgraham May 7, 2024
1d411b8
Add compatability tests
marksgraham May 7, 2024
6831413
Add compatability tests
marksgraham May 7, 2024
ecaf5a1
Fixed misguiding weight mode documentation (#7746)
simojens May 7, 2024
3a64373
Adds ref for causal attention and allows user to specify attention dtype
marksgraham May 7, 2024
ce4a5bd
Attention dtype
marksgraham May 7, 2024
7f6319e
Attention dtype
marksgraham May 7, 2024
32b7754
Enhance logging logic in `ConfigWorkflow` (#7745)
KumoLiu May 8, 2024
f278e51
Add version requirement for filelock and nni (#7744)
KumoLiu May 8, 2024
d83fa56
Add dimensionality of heads argument to SABlock (#7664)
NabJa May 8, 2024
98b38d8
Merge dev
marksgraham May 9, 2024
ae96458
Attempt to fix contigous issue
marksgraham May 10, 2024
defe4df
Attempt to fix contigous issue
marksgraham May 10, 2024
f0c59aa
Remove tests added unintentionally
marksgraham May 10, 2024
2 changes: 2 additions & 0 deletions .github/workflows/pythonapp-min.yml
@@ -9,6 +9,8 @@ on:
- main
- releasing/*
pull_request:
head_ref-ignore:
- dev

concurrency:
# automatically cancel the previously triggered workflows when there's a newer version
6 changes: 4 additions & 2 deletions .github/workflows/pythonapp.yml
@@ -9,6 +9,8 @@ on:
- main
- releasing/*
pull_request:
head_ref-ignore:
- dev

concurrency:
# automatically cancel the previously triggered workflows when there's a newer version
@@ -68,10 +70,10 @@ jobs:
maximum-size: 16GB
disk-root: "D:"
- uses: actions/checkout@v4
- name: Set up Python 3.8
- name: Set up Python 3.9
uses: actions/setup-python@v5
with:
python-version: '3.8'
python-version: '3.9'
- name: Prepare pip wheel
run: |
which python
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -69,7 +69,7 @@ repos:
)$

- repo: https://github.com/hadialqattan/pycln
rev: v2.1.3
rev: v2.4.0
hooks:
- id: pycln
args: [--config=pyproject.toml]
9 changes: 4 additions & 5 deletions Dockerfile
@@ -18,11 +18,10 @@ LABEL maintainer="[email protected]"

# TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431)
RUN if [[ $(uname -m) =~ "aarch64" ]]; then \
cd /opt && \
git clone --branch v0.12.1 --recursive https://github.com/zarr-developers/numcodecs && \
pip wheel numcodecs && \
rm -r /opt/*.whl && \
rm -rf /opt/numcodecs; \
export CFLAGS="-O3" && \
export DISABLE_NUMCODECS_SSE2=true && \
export DISABLE_NUMCODECS_AVX2=true && \
pip install numcodecs; \
fi

WORKDIR /opt/monai
1 change: 1 addition & 0 deletions docs/source/networks.rst
@@ -426,6 +426,7 @@ Layers
.. autoclass:: monai.networks.layers.vector_quantizer.VectorQuantizer
:members:

=======
`ConjugateGradient`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: ConjugateGradient
7 changes: 6 additions & 1 deletion monai/apps/utils.py
@@ -135,7 +135,12 @@ def check_hash(filepath: PathLike, val: str | None = None, hash_type: str = "md5"
logger.info(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.")
return True
actual_hash_func = look_up_option(hash_type.lower(), SUPPORTED_HASH_TYPES)
actual_hash = actual_hash_func()

if sys.version_info >= (3, 9):
actual_hash = actual_hash_func(usedforsecurity=False) # allows checks on FIPS enabled machines
else:
actual_hash = actual_hash_func()

try:
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(1024 * 1024), b""):
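For context on the FIPS fix above: `hashlib` constructors gained a `usedforsecurity` keyword in Python 3.9, and passing `False` marks the digest as non-cryptographic so FIPS-enabled builds do not reject MD5. A minimal sketch of the same gate outside MONAI (the helper name is illustrative, not part of the diff):

```python
import hashlib
import sys


def integrity_hasher(hash_type: str = "md5"):
    """Return a hasher usable on FIPS-enabled machines (illustrative sketch).

    Python >= 3.9 accepts ``usedforsecurity=False``, which flags the digest as
    being used only for integrity checks; older interpreters fall back to the
    plain constructor, mirroring the branch added to ``check_hash``.
    """
    hash_func = getattr(hashlib, hash_type)
    if sys.version_info >= (3, 9):
        return hash_func(usedforsecurity=False)
    return hash_func()
```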
10 changes: 7 additions & 3 deletions monai/bundle/workflows.py
@@ -239,6 +239,7 @@ class ConfigWorkflow(BundleWorkflow):
logging_file: config file for `logging` module in the program. for more details:
https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
If None, default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo.
If False, the logging logic for the bundle will not be modified.
init_id: ID name of the expected config expression to initialize before running, default to "initialize".
allow a config to have no `initialize` logic and the ID.
run_id: ID name of the expected config expression to run, default to "run".
@@ -278,7 +279,7 @@ def __init__(
self,
config_file: str | Sequence[str],
meta_file: str | Sequence[str] | None = None,
logging_file: str | None = None,
logging_file: str | bool | None = None,
init_id: str = "initialize",
run_id: str = "run",
final_id: str = "finalize",
@@ -307,15 +308,18 @@ def __init__(
super().__init__(workflow_type=workflow_type, meta_file=meta_file, properties_path=properties_path)
self.config_root_path = config_root_path
logging_file = str(self.config_root_path / "logging.conf") if logging_file is None else logging_file
if logging_file is not None:

if logging_file is False:
logger.warn(f"Logging file is set to {logging_file}, skipping logging.")
else:
if not os.path.isfile(logging_file):
if logging_file == str(self.config_root_path / "logging.conf"):
logger.warn(f"Default logging file in {logging_file} does not exist, skipping logging.")
else:
raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.")
else:
logger.info(f"Setting logging properties based on config: {logging_file}.")
fileConfig(logging_file, disable_existing_loggers=False)
fileConfig(str(logging_file), disable_existing_loggers=False)

self.parser = ConfigParser()
self.parser.read_config(f=config_file)
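A usage note on the widened `logging_file: str | bool | None` argument: `None` still falls back to `configs/logging.conf`, while `False` now skips `fileConfig` entirely. A hedged sketch (the bundle paths are hypothetical):

```python
from monai.bundle import ConfigWorkflow

# Hypothetical bundle layout; logging_file=False leaves the caller's logging
# configuration untouched instead of loading configs/logging.conf.
workflow = ConfigWorkflow(
    config_file="my_bundle/configs/train.json",
    meta_file="my_bundle/configs/metadata.json",
    logging_file=False,
    workflow_type="train",
)
workflow.initialize()
workflow.run()
workflow.finalize()
```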
14 changes: 9 additions & 5 deletions monai/fl/client/monai_algo.py
@@ -134,12 +134,14 @@ def initialize(self, extra=None):

Args:
extra: Dict with additional information that should be provided by FL system,
i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`.
i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`.
You can disable the logging logic in the monai bundle by setting {ExtraItems.LOGGING_FILE} to False.

"""
if extra is None:
extra = {}
self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname")
logging_file = extra.get(ExtraItems.LOGGING_FILE, None)
self.logger.info(f"Initializing {self.client_name} ...")

# FL platform needs to provide filepath to configuration files
@@ -149,7 +151,7 @@ def initialize(self, extra=None):
if self.workflow is None:
config_train_files = self._add_config_files(self.config_train_filename)
self.workflow = ConfigWorkflow(
config_file=config_train_files, meta_file=None, logging_file=None, workflow_type="train"
config_file=config_train_files, meta_file=None, logging_file=logging_file, workflow_type="train"
)
self.workflow.initialize()
self.workflow.bundle_root = self.bundle_root
@@ -412,13 +414,15 @@ def initialize(self, extra=None):

Args:
extra: Dict with additional information that should be provided by FL system,
i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`.
i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`.
You can disable the logging logic in the monai bundle by setting {ExtraItems.LOGGING_FILE} to False.

"""
self._set_cuda_device()
if extra is None:
extra = {}
self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname")
logging_file = extra.get(ExtraItems.LOGGING_FILE, None)
timestamp = time.strftime("%Y%m%d_%H%M%S")
self.logger.info(f"Initializing {self.client_name} ...")
# FL platform needs to provide filepath to configuration files
@@ -434,7 +438,7 @@ def initialize(self, extra=None):
self.train_workflow = ConfigWorkflow(
config_file=config_train_files,
meta_file=None,
logging_file=None,
logging_file=logging_file,
workflow_type="train",
**self.train_kwargs,
)
@@ -459,7 +463,7 @@ def initialize(self, extra=None):
self.eval_workflow = ConfigWorkflow(
config_file=config_eval_files,
meta_file=None,
logging_file=None,
logging_file=logging_file,
workflow_type=self.eval_workflow_name,
**self.eval_kwargs,
)
1 change: 1 addition & 0 deletions monai/fl/utils/constants.py
@@ -30,6 +30,7 @@ class ExtraItems(StrEnum):
CLIENT_NAME = "fl_client_name"
APP_ROOT = "fl_app_root"
STATS_SENDER = "fl_stats_sender"
LOGGING_FILE = "logging_file"


class FlPhase(StrEnum):
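Putting the `monai_algo.py` and `constants.py` changes together, an FL system can now pass the logging override through the `extra` dict; a sketch assuming a bundle prepared elsewhere (`bundle_root` is hypothetical):

```python
from monai.fl.client import MonaiAlgo
from monai.fl.utils.constants import ExtraItems

# The new ExtraItems.LOGGING_FILE entry is read in initialize() and forwarded
# to ConfigWorkflow(logging_file=...); False disables the bundle logging setup.
algo = MonaiAlgo(bundle_root="./spleen_ct_segmentation")
algo.initialize(extra={ExtraItems.CLIENT_NAME: "site-1", ExtraItems.LOGGING_FILE: False})
```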
2 changes: 1 addition & 1 deletion monai/losses/ds_loss.py
@@ -33,7 +33,7 @@ def __init__(self, loss: _Loss, weight_mode: str = "exp", weights: list[float] |
weight_mode: {``"same"``, ``"exp"``, ``"two"``}
Specifies the weights calculation for each image level. Defaults to ``"exp"``.
- ``"same"``: all weights are equal to 1.
- ``"exp"``: exponentially decreasing weights by a power of 2: 0, 0.5, 0.25, 0.125, etc .
- ``"exp"``: exponentially decreasing weights by a power of 2: 1, 0.5, 0.25, 0.125, etc .
- ``"two"``: equal smaller weights for lower levels: 1, 0.5, 0.5, 0.5, 0.5, etc
weights: a list of weights to apply to each deeply supervised sub-loss, if provided, this will be used
regardless of the weight_mode
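The corrected docstring now matches the actual `"exp"` schedule (1, 0.5, 0.25, ...). A small sketch of the wrapper in use, with shapes chosen only for illustration:

```python
import torch
from monai.losses import DeepSupervisionLoss, DiceLoss

# "exp" weights the multi-scale outputs by 1, 0.5, 0.25, ... starting from the
# highest-resolution head.
ds_loss = DeepSupervisionLoss(DiceLoss(sigmoid=True), weight_mode="exp")

preds = [torch.rand(2, 1, 64, 64), torch.rand(2, 1, 32, 32)]  # deep-supervision outputs
target = torch.randint(0, 2, (2, 1, 64, 64)).float()
loss = ds_loss(preds, target)
```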
2 changes: 2 additions & 0 deletions monai/networks/blocks/__init__.py
@@ -17,6 +17,7 @@
from .backbone_fpn_utils import BackboneWithFPN
from .convolutions import Convolution, ResidualUnit
from .crf import CRF
from .crossattention import CrossAttentionBlock
from .denseblock import ConvDenseBlock, DenseBlock
from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock
from .downsample import MaxAvgPool
@@ -31,6 +32,7 @@
from .segresnet_block import ResBlock
from .selfattention import SABlock
from .spade_norm import SPADE
from .spatialattention import SpatialAttentionBlock
from .squeeze_and_excitation import (
ChannelSELayer,
ResidualSELayer,
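With these re-exports in place the new blocks resolve from the package namespace, e.g.:

```python
from monai.networks.blocks import CrossAttentionBlock, SABlock, SpatialAttentionBlock
```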
166 changes: 166 additions & 0 deletions monai/networks/blocks/crossattention.py
@@ -0,0 +1,166 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")


class CrossAttentionBlock(nn.Module):
"""
A cross-attention block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
One can setup relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
"""

def __init__(
self,
hidden_size: int,
num_heads: int,
dropout_rate: float = 0.0,
hidden_input_size: int | None = None,
context_input_size: int | None = None,
dim_head: int | None = None,
qkv_bias: bool = False,
save_attn: bool = False,
causal: bool = False,
sequence_length: int | None = None,
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
) -> None:
"""
Args:
hidden_size (int): dimension of hidden layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.
context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size.
dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
causal: whether to use causal attention.
sequence_length: if causal is True, it is necessary to specify the sequence length.
rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
attention_dtype: cast attention operations to this dtype.
"""

super().__init__()

if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")

if dim_head:
inner_size = num_heads * dim_head
self.head_dim = dim_head
else:
if hidden_size % num_heads != 0:
raise ValueError("hidden size should be divisible by num_heads.")
inner_size = hidden_size
self.head_dim = hidden_size // num_heads

if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.context_input_size = context_input_size if context_input_size else hidden_size
self.out_proj = nn.Linear(inner_size, self.hidden_input_size)
# key, query, value projections
self.to_q = nn.Linear(self.hidden_input_size, inner_size, bias=qkv_bias)
self.to_k = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)

self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)

self.scale = self.head_dim**-0.5
self.save_attn = save_attn
self.attention_dtype = attention_dtype

self.causal = causal
self.sequence_length = sequence_length

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer(
"causal_mask",
torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
)
self.causal_mask: torch.Tensor

self.att_mat = torch.Tensor()
self.rel_positional_embedding = (
get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads)
if rel_pos_embedding is not None
else None
)
self.input_size = input_size

def forward(self, x: torch.Tensor, context: torch.Tensor | None = None):
"""
Args:
x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
context (torch.Tensor, optional): context tensor. B x (s_dim_1 * ... * s_dim_n) x C

Return:
torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
"""
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size)

q = self.to_q(x)
kv = context if context is not None else x
_, kv_t, _ = kv.size()
k = self.to_k(kv)
v = self.to_v(kv)

if self.attention_dtype is not None:
q = q.to(self.attention_dtype)
k = k.to(self.attention_dtype)

q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs)
k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat

if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))

att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
return x
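Based on the constructor and `forward` signature above, a minimal smoke test of the new block might look like the following (shapes are illustrative and `einops` must be installed):

```python
import torch
from monai.networks.blocks import CrossAttentionBlock

block = CrossAttentionBlock(hidden_size=128, num_heads=4, context_input_size=64)

x = torch.randn(2, 100, 128)      # B x (flattened spatial dims) x hidden_size
context = torch.randn(2, 77, 64)  # B x context length x context_input_size
out = block(x, context=context)   # cross-attends x over context -> 2 x 100 x 128
```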