Skip to content

Commit

Permalink
make 6-class datasplit actually 6 classes
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Feb 18, 2025
1 parent 54c1478 commit 3f5cd5b
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions tests/fixtures/datasplits.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,21 +96,21 @@ def six_class_datasplit(tmp_path):
"""
two crops for training, one for validation. Raw data is normally distributed
around 0 with std 1.
gt is provided as distances. First, gt is generated as a 12 class problem:
gt has 12 classes where class i in [0, 11] is all voxels with raw intensity
between (raw.min() + i(raw.max()-raw.min())/12, raw.min() + (i+1)(raw.max()-raw.min())/12).
Then we pair up classes (i, i+1) for i in [0,2,4,6,8,10], and compute distances to
gt is provided as distances. First, gt is generated as a 6 class problem:
gt has 6 classes where class i in [0, 5] is all voxels with raw intensity
between (raw.min() + i(raw.max()-raw.min())/6, raw.min() + (i+1)(raw.max()-raw.min())/6).
Then we pair up classes (i, i+1) for i in [0,2,4], and compute distances to
the nearest voxel in the pair. This leaves us with 6 distance channels.
"""
twelve_class_zarr = zarr.open(tmp_path / "twelve_class.zarr", "w")
six_class_zarr = zarr.open(tmp_path / "six_class.zarr", "w")
crop1_raw = ZarrArrayConfig(
name="crop1_raw",
file_name=tmp_path / "twelve_class.zarr",
file_name=tmp_path / "six_class.zarr",
dataset=f"volumes/crop1/raw",
)
crop1_gt = ZarrArrayConfig(
name="crop1_gt",
file_name=tmp_path / "twelve_class.zarr",
file_name=tmp_path / "six_class.zarr",
dataset=f"volumes/crop1/gt",
)
crop1_distances = BinarizeArrayConfig(
Expand All @@ -127,12 +127,12 @@ def six_class_datasplit(tmp_path):
)
crop2_raw = ZarrArrayConfig(
name="crop2_raw",
file_name=tmp_path / "twelve_class.zarr",
file_name=tmp_path / "six_class.zarr",
dataset=f"volumes/crop2/raw",
)
crop2_gt = ZarrArrayConfig(
name="crop2_gt",
file_name=tmp_path / "twelve_class.zarr",
file_name=tmp_path / "six_class.zarr",
dataset=f"volumes/crop2/gt",
)
crop2_distances = BinarizeArrayConfig(
Expand All @@ -149,12 +149,12 @@ def six_class_datasplit(tmp_path):
)
crop3_raw = ZarrArrayConfig(
name="crop3_raw",
file_name=tmp_path / "twelve_class.zarr",
file_name=tmp_path / "six_class.zarr",
dataset=f"volumes/crop3/raw",
)
crop3_gt = ZarrArrayConfig(
name="crop3_gt",
file_name=tmp_path / "twelve_class.zarr",
file_name=tmp_path / "six_class.zarr",
dataset=f"volumes/crop3/gt",
)
crop3_distances = BinarizeArrayConfig(
Expand All @@ -172,15 +172,15 @@ def six_class_datasplit(tmp_path):
for raw, gt in zip(
[crop1_raw, crop2_raw, crop3_raw], [crop1_gt, crop2_gt, crop3_gt]
):
raw_dataset = twelve_class_zarr.create_dataset(
raw_dataset = six_class_zarr.create_dataset(
raw.dataset, shape=(40, 20, 20), dtype=np.float32
)
gt_dataset = twelve_class_zarr.create_dataset(
gt_dataset = six_class_zarr.create_dataset(
gt.dataset, shape=(40, 20, 20), dtype=np.uint8
)
random_data = np.random.rand(40, 20, 20)
# as intensities increase so does the class
for i in list(np.linspace(random_data.min(), random_data.max(), 13))[1:]:
for i in list(np.linspace(random_data.min(), random_data.max(), 7))[1:]:
gt_dataset[:] += random_data > i
raw_dataset[:] = random_data
raw_dataset.attrs["offset"] = (0, 0, 0)
Expand Down

0 comments on commit 3f5cd5b

Please sign in to comment.