From 3f5cd5bafd1c05bb6e742fd89a74b6a388afa077 Mon Sep 17 00:00:00 2001 From: William Patton Date: Mon, 17 Feb 2025 14:02:18 -0800 Subject: [PATCH] make 6-class datasplit actually 6 classes --- tests/fixtures/datasplits.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/fixtures/datasplits.py b/tests/fixtures/datasplits.py index e94aee0c..a36187d6 100644 --- a/tests/fixtures/datasplits.py +++ b/tests/fixtures/datasplits.py @@ -96,21 +96,21 @@ def six_class_datasplit(tmp_path): """ two crops for training, one for validation. Raw data is normally distributed around 0 with std 1. - gt is provided as distances. First, gt is generated as a 12 class problem: - gt has 12 classes where class i in [0, 11] is all voxels with raw intensity - between (raw.min() + i(raw.max()-raw.min())/12, raw.min() + (i+1)(raw.max()-raw.min())/12). - Then we pair up classes (i, i+1) for i in [0,2,4,6,8,10], and compute distances to + gt is provided as distances. First, gt is generated as a 6 class problem: + gt has 6 classes where class i in [0, 5] is all voxels with raw intensity + between (raw.min() + i(raw.max()-raw.min())/6, raw.min() + (i+1)(raw.max()-raw.min())/6). + Then we pair up classes (i, i+1) for i in [0,2,4], and compute distances to the nearest voxel in the pair. This leaves us with 6 distance channels. """ - twelve_class_zarr = zarr.open(tmp_path / "twelve_class.zarr", "w") + six_class_zarr = zarr.open(tmp_path / "six_class.zarr", "w") crop1_raw = ZarrArrayConfig( name="crop1_raw", - file_name=tmp_path / "twelve_class.zarr", + file_name=tmp_path / "six_class.zarr", dataset=f"volumes/crop1/raw", ) crop1_gt = ZarrArrayConfig( name="crop1_gt", - file_name=tmp_path / "twelve_class.zarr", + file_name=tmp_path / "six_class.zarr", dataset=f"volumes/crop1/gt", ) crop1_distances = BinarizeArrayConfig( @@ -127,12 +127,12 @@ def six_class_datasplit(tmp_path): ) crop2_raw = ZarrArrayConfig( name="crop2_raw", - file_name=tmp_path / "twelve_class.zarr", + file_name=tmp_path / "six_class.zarr", dataset=f"volumes/crop2/raw", ) crop2_gt = ZarrArrayConfig( name="crop2_gt", - file_name=tmp_path / "twelve_class.zarr", + file_name=tmp_path / "six_class.zarr", dataset=f"volumes/crop2/gt", ) crop2_distances = BinarizeArrayConfig( @@ -149,12 +149,12 @@ def six_class_datasplit(tmp_path): ) crop3_raw = ZarrArrayConfig( name="crop3_raw", - file_name=tmp_path / "twelve_class.zarr", + file_name=tmp_path / "six_class.zarr", dataset=f"volumes/crop3/raw", ) crop3_gt = ZarrArrayConfig( name="crop3_gt", - file_name=tmp_path / "twelve_class.zarr", + file_name=tmp_path / "six_class.zarr", dataset=f"volumes/crop3/gt", ) crop3_distances = BinarizeArrayConfig( @@ -172,15 +172,15 @@ def six_class_datasplit(tmp_path): for raw, gt in zip( [crop1_raw, crop2_raw, crop3_raw], [crop1_gt, crop2_gt, crop3_gt] ): - raw_dataset = twelve_class_zarr.create_dataset( + raw_dataset = six_class_zarr.create_dataset( raw.dataset, shape=(40, 20, 20), dtype=np.float32 ) - gt_dataset = twelve_class_zarr.create_dataset( + gt_dataset = six_class_zarr.create_dataset( gt.dataset, shape=(40, 20, 20), dtype=np.uint8 ) random_data = np.random.rand(40, 20, 20) # as intensities increase so does the class - for i in list(np.linspace(random_data.min(), random_data.max(), 13))[1:]: + for i in list(np.linspace(random_data.min(), random_data.max(), 7))[1:]: gt_dataset[:] += random_data > i raw_dataset[:] = random_data raw_dataset.attrs["offset"] = (0, 0, 0)