Skip to content

Commit

Permalink
update run configs in tests to account for moved arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Feb 18, 2025
1 parent b07ecd1 commit 4da6f3d
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions tests/operations/test_mini.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
@pytest.mark.parametrize("upsample", [True, False])
@pytest.mark.parametrize("padding", ["valid", "same"])
@pytest.mark.parametrize("func", ["train", "validate"])
@pytest.mark.parametrize("multiprocessing", [False])
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_mini(
tmpdir,
data_dims,
Expand All @@ -47,7 +47,7 @@ def test_mini(
# TODO: maybe check that an appropriate warning is raised somewhere
return

trainer_config = build_test_train_config(multiprocessing)
trainer_config = build_test_train_config()

data_config = build_test_data_config(
tmpdir,
Expand All @@ -73,27 +73,30 @@ def test_mini(
datasplit_config=data_config,
repetition=0,
num_iterations=1,
snapshot_interval=1,
batch_size=2,
num_workers=int(multiprocessing),
)
run = Run(run_config)
compute_context = create_compute_context()
device = compute_context.device
run.model.to(device)

if func == "train":
train_run(run)
train_run(run, validate=False, save_snapshots=True)
array_store = create_array_store()
snapshot_container = array_store.snapshot_container(run.name).container
assert snapshot_container.exists()
assert all(
x in zarr.open(snapshot_container)
for x in [
"0/volumes/raw",
"0/volumes/gt",
"0/volumes/target",
"0/volumes/weight",
"0/volumes/prediction",
"0/volumes/gradients",
"0/volumes/mask",
"volumes/raw",
"volumes/gt",
"volumes/target",
"volumes/weight",
"volumes/prediction",
"volumes/gradients",
"volumes/mask",
]
)
elif func == "validate":
Expand Down

0 comments on commit 4da6f3d

Please sign in to comment.