Skip to content

Commit

Permalink
Zarr compression tests only with versions before 3.0 (#8319)
Browse files Browse the repository at this point in the history
Fixes #8298.

### Description

This includes the tests for the `compressor` argument when testing with
Zarr before version 3.0 when this argument was deprecated. A fix to
upgrade the version of `pycln` used is also included. The version of
PyTorch is also fixed to below 2.6 to avoid issues with misuse of
`torch.load` which must be addressed later.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Eric Kerfoot <[email protected]>
  • Loading branch information
ericspod authored Feb 3, 2025
1 parent 8ac8e0d commit 8dcb9dc
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ repos:
)$
- repo: https://github.com/hadialqattan/pycln
rev: v2.4.0
rev: v2.5.0
hooks:
- id: pycln
args: [--config=pyproject.toml]
5 changes: 5 additions & 0 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,3 +607,8 @@ def print_verbose(self) -> None:
print(self)
if self.meta is not None:
print(self.meta.__repr__())


# needed in later versions of Pytorch to indicate the class is safe for serialisation
if hasattr(torch.serialization, "add_safe_globals"):
torch.serialization.add_safe_globals([MetaTensor])
2 changes: 1 addition & 1 deletion monai/utils/jupyter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def plot_engine_status(


def _get_loss_from_output(
output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor
output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor,
) -> torch.Tensor:
"""Returns a single value from the network output, which is a dict or tensor."""

Expand Down
4 changes: 2 additions & 2 deletions monai/visualize/img2tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def _image3_animated_gif(
img_str = b""
for b_data in PIL.GifImagePlugin.getheader(ims[0])[0]:
img_str += b_data
img_str += b"\x21\xFF\x0B\x4E\x45\x54\x53\x43\x41\x50" b"\x45\x32\x2E\x30\x03\x01\x00\x00\x00"
img_str += b"\x21\xff\x0b\x4e\x45\x54\x53\x43\x41\x50" b"\x45\x32\x2e\x30\x03\x01\x00\x00\x00"
for i in ims:
for b_data in PIL.GifImagePlugin.getdata(i):
img_str += b_data
img_str += b"\x3B"
img_str += b"\x3b"

summary = SummaryX if has_tensorboardx and isinstance(writer, SummaryWriterX) else Summary
summary_image_str = summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str)
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pep8-naming
pycodestyle
pyflakes
black>=22.12
isort>=5.1
isort>=5.1, <6.0
ruff
pytype>=2020.6.1; platform_system != "Windows"
types-setuptools
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
torch>=1.9
torch>=1.9,<2.6
numpy>=1.24,<2.0
45 changes: 23 additions & 22 deletions tests/test_zarr_avg_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,32 +260,33 @@
TENSOR_4x4,
]

ALL_TESTS = [
TEST_CASE_0_DEFAULT_DTYPE,
TEST_CASE_1_DEFAULT_DTYPE,
TEST_CASE_2_DEFAULT_DTYPE,
TEST_CASE_3_DEFAULT_DTYPE,
TEST_CASE_4_DEFAULT_DTYPE,
TEST_CASE_5_VALUE_DTYPE,
TEST_CASE_6_COUNT_DTYPE,
TEST_CASE_7_COUNT_VALUE_DTYPE,
TEST_CASE_8_DTYPE,
TEST_CASE_9_LARGER_SHAPE,
TEST_CASE_10_DIRECTORY_STORE,
TEST_CASE_11_MEMORY_STORE,
TEST_CASE_12_CHUNKS,
TEST_CASE_16_WITH_LOCK,
TEST_CASE_17_WITHOUT_LOCK,
]

# add compression tests only when using Zarr version before 3.0
if not version_geq(get_package_version("zarr"), "3.0.0"):
ALL_TESTS += [TEST_CASE_13_COMPRESSOR_LZ4, TEST_CASE_14_COMPRESSOR_PICKLE, TEST_CASE_15_COMPRESSOR_LZMA]


@unittest.skipUnless(has_zarr and has_numcodecs, "Requires zarr (and numcodecs) packages.)")
class ZarrAvgMergerTests(unittest.TestCase):

@parameterized.expand(
[
TEST_CASE_0_DEFAULT_DTYPE,
TEST_CASE_1_DEFAULT_DTYPE,
TEST_CASE_2_DEFAULT_DTYPE,
TEST_CASE_3_DEFAULT_DTYPE,
TEST_CASE_4_DEFAULT_DTYPE,
TEST_CASE_5_VALUE_DTYPE,
TEST_CASE_6_COUNT_DTYPE,
TEST_CASE_7_COUNT_VALUE_DTYPE,
TEST_CASE_8_DTYPE,
TEST_CASE_9_LARGER_SHAPE,
TEST_CASE_10_DIRECTORY_STORE,
TEST_CASE_11_MEMORY_STORE,
TEST_CASE_12_CHUNKS,
TEST_CASE_13_COMPRESSOR_LZ4,
TEST_CASE_14_COMPRESSOR_PICKLE,
TEST_CASE_15_COMPRESSOR_LZMA,
TEST_CASE_16_WITH_LOCK,
TEST_CASE_17_WITHOUT_LOCK,
]
)
@parameterized.expand(ALL_TESTS)
def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected):
codec_reg = numcodecs.registry.codec_registry
if "compressor" in arguments:
Expand Down

0 comments on commit 8dcb9dc

Please sign in to comment.