preprocess_mitoses.py
"""Preprocessing - mitosis detection"""
import argparse
import glob
import json
import math
import os
import shutil
import sys
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
import tensorflow as tf
from train_mitoses import normalize
#from deephistopath.inference import gen_batches
# TODO: update library so that this can be imported without accidentally pulling in spark
def gen_batches(iterator, batch_size, include_partial=True):
""" generate the tile batches from the tile iterator
Args:
iterator: the tile iterator
batch_size: batch size
include_partial: boolean value to keep the partial batch or not
Return:
the iterator for the tile batches
"""
batch = []
for item in iterator:
batch.append(item)
if len(batch) == batch_size:
yield batch
batch = []
if len(batch) > 0 and include_partial:
yield batch
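# Illustrative usage sketch for `gen_batches` (hypothetical values):
#   list(gen_batches(iter(range(5)), batch_size=2, include_partial=True))
#   # -> [[0, 1], [2, 3], [4]]
#   list(gen_batches(iter(range(5)), batch_size=2, include_partial=False))
#   # -> [[0, 1], [2, 3]]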
def create_mask(h, w, coords, radius):
"""Create a binary image mask with locations of mitosis patches.
Pixels equal to zero indicate normal regions, while pixels equal to one
indicate mitosis regions. More specifically, all locations within a
Euclidean distance <= `radius` from the center of a true mitosis are
set to a value of one, and all other locations are set to a value of
zero.
Args:
h: Integer height of the mask.
w: Integer width of the mask.
coords: A list-like collection of (row, col) mitosis coordinates.
radius: An integer radius of the circular patches to place on the
mask for each mitosis location.
Returns:
A binary mask of shape (h, w) indicating where the
mitosis patches are located.
"""
# check that row, col, and size are within the image bounds
#assert 1 < size <= min(h, w), "size must be >1 and within the bounds of the image"
# create mitosis patch mask
mask = np.zeros((h, w), dtype=bool)
for row, col in coords:
assert 0 <= row <= h, "row is outside of the image height"
assert 0 <= col <= w, "col is outside of the image width"
# mitosis mask as a circle with radius `radius` pixels centered on the given location
y, x = np.ogrid[:h, :w]
mitosis_mask = np.sqrt((y-row)**2 + (x-col)**2) <= radius
# indicate mitosis patch area on mask
mask = np.logical_or(mask, mitosis_mask)
return mask
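# Illustrative sketch of `create_mask` behavior (hypothetical values):
#   mask = create_mask(h=100, w=200, coords=[(50, 40)], radius=32)
#   mask[50, 40]  # True: center of the mitosis
#   mask[50, 72]  # True: exactly `radius` pixels from the center
#   mask[50, 73]  # False: just outside the circle
#   mask[0, 0]    # False: normal region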
def extract_patch(im, row, col, size):
"""Extract a patch centered at (row, col).
If the (row, col) is at the edge of the image, the image will be
reflected to yield a patch of the desired size.
Args:
im: An image stored as a NumPy array of shape (h, w, ...).
row: An integer row number.
col: An integer col number.
size: An integer size of the square patch to extract.
Returns:
A NumPy array of shape (size, size, ...).
"""
# check that row, col, and size are within the image bounds
dims = np.ndim(im)
assert dims >= 2, "image must be of shape (h, w, ...)"
h, w = im.shape[0:2]
assert 0 <= row <= h, "row {} is outside of the image height {}".format(row, h)
assert 0 <= col <= w, "col {} is outside of the image width {}".format(col, w)
#assert 1 < size <= min(h, w), "size must be >1 and within the bounds of the image"
# (row, col) is the center, so compute upper and lower bounds of patch
half_size = round(size / 2)
row_lower = row - half_size
row_upper = row + half_size
col_lower = col - half_size
col_upper = col + half_size
# clip the bounds to the size of the image and compute padding to add to patch
row_pad_lower = abs(row_lower) if row_lower < 0 else 0
row_pad_upper = row_upper - h if row_upper > h else 0
col_pad_lower = abs(col_lower) if col_lower < 0 else 0
col_pad_upper = col_upper - w if col_upper > w else 0
row_lower = max(0, row_lower)
row_upper = min(row_upper, h)
col_lower = max(0, col_lower)
col_upper = min(col_upper, w)
# extract patch
patch = im[row_lower:row_upper, col_lower:col_upper]
# pad with reflection on the height and width as needed to yield a patch of the desired size
# NOTE: all remaining dimensions (such as channels) receive 0 padding
padding = ((row_pad_lower, row_pad_upper), (col_pad_lower, col_pad_upper)) + ((0, 0),) * (dims-2)
# Note: the padding content starts from the second row/col of the
# input patch instead of the first row/col
patch_padded = np.pad(patch, padding, 'reflect')
return patch_padded
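# Illustrative sketch of the reflection behavior at image edges (hypothetical values):
#   im = np.random.rand(100, 200, 3)
#   patch = extract_patch(im, row=10, col=10, size=32)
#   patch.shape  # (32, 32, 3); the 6 rows and 6 cols that fall outside the image
#                # are filled by reflecting the in-bounds content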
def gen_dense_coords(h, w, stride):
"""Generate centered (row, col) coordinates of patches densely from an
image with striding.
This slides across the image from left to right, top to bottom by
`stride` number of pixels, yielding (row, col) centered coordinates.
Args:
h: Integer height of the image.
w: Integer width of the image.
stride: An integer number of pixels by which to shift in the
sliding window for normal patches.
Returns:
Yields (row, col) integer coordinates of the center of a patch.
"""
assert stride > 0, "stride must be an integer > 0"
# generate coordinates
for row in range(0, h, stride):
for col in range(0, w, stride):
yield row, col # centered coordinates for this patch
def gen_normal_coords(mask, stride):
"""Generate (row, col) coordinates for normal patches.
This generates coordinates for normal patches in a sliding window
fashion with the given stride, skipping any coordinate whose center
falls inside a masked (mitosis) region.
Args:
mask: A binary mask, indicating where the mitosis patches are
located, of the same height and width as the region image.
stride: An integer number of pixels by which to shift in the
sliding window for normal patches.
Returns:
Yields (row, col) coordinates of a normal patch.
"""
assert np.ndim(mask) == 2, "mask must be of shape (h, w)"
h, w = mask.shape
assert stride > 0, "stride must be an integer > 0"
for row, col in gen_dense_coords(h, w, stride):
# check that the point is not in a mitotic region
if not mask[row, col]:
yield row, col
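# Illustrative sketch (hypothetical values): with a 6x8 mask whose top-left 4x4 block is
# masked and a stride of 2, only centers outside the masked block are yielded:
#   mask = np.zeros((6, 8), dtype=bool); mask[0:4, 0:4] = True
#   list(gen_normal_coords(mask, stride=2))
#   # -> [(0, 4), (0, 6), (2, 4), (2, 6), (4, 0), (4, 2), (4, 4), (4, 6)]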
def gen_fp_coords(im, normal_coords, size, model, model_name, pred_threshold, batch_size):
"""Generate (row, col) coordinates for false-positive patches.
This generates false-positive patch coordinates by making predictions
for each normal patch and yielding coordinates for cases in which the
predicted value is greater than `pred_threshold`. Note: by having the
threshold be a parameter, the model can output the actual probability
value or the logit value and as long as the threshold is set
appropriately it doesn't matter (i.e., for a probability threshold of
0.5, the corresponding logit threshold would be 0).
Args:
im: An image stored as a np.uint8 NumPy array of shape (h, w, c)
with values in [0, 255].
normal_coords: An iterable collection of (row, col) coordinates.
size: An integer size of the square patch to extract.
model: Keras model to use for false-positive oversampling.
model_name: String indicating the model being used, which is used
for determining the correct normalization. TODO: replace this
pred_threshold: Decimal threshold over which the patch is predicted
as a positive case.
batch_size: Size of batches to process, for performance
improvements.
Returns:
Yields (row, col) coordinates of false-positive patches.
"""
patches_rc = ((extract_patch(im, row, col, size), row, col) for row, col in normal_coords)
patch_rc_batches = gen_batches(patches_rc, batch_size, include_partial=True)
for patch_rc_batch in patch_rc_batches:
patch_batch, row_batch, col_batch = zip(*patch_rc_batch)
norm_patch_batch = normalize((np.array(patch_batch) / 255).astype(np.float32), model_name)
out_batch = np.squeeze(model.predict_on_batch(norm_patch_batch), axis=1)
for out, row, col in zip(out_batch, row_batch, col_batch):
if out > pred_threshold:
yield row, col
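# Sketch of the threshold convention described above (hypothetical values): a model that
# outputs probabilities pairs naturally with pred_threshold=0.5, while a model that
# outputs logits pairs with pred_threshold=0, since sigmoid(0) == 0.5.
#   fp_coords = gen_fp_coords(im, normal_coords, size=64, model=model,
#                             model_name=model_name, pred_threshold=0, batch_size=128)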
def gen_random_translation(h, w, row, col, max_shift):
"""Generate (row_shift, col_shift) random translation shifts relative
to (row, col).
Ensures that the shifts are within the bounds of the image.
Args:
h: Integer height of the image.
w: Integer width of the image.
row: An integer row number.
col: An integer col number.
max_shift: Integer upper bound on the spatial shift range for the
random translations.
Returns:
New (row_shift, col_shift) integer relative translations.
"""
# check that row, col, and size are within the image bounds
assert 0 <= row <= h, "row is outside of the image height"
assert 0 <= col <= w, "col is outside of the image width"
assert max_shift >= 0, "max_shift must be >= 0"
# NOTE: np.random.randint has exclusive upper bounds
row_shifted = min(max(0, row + np.random.randint(-max_shift, max_shift + 1)), h)
col_shifted = min(max(0, col + np.random.randint(-max_shift, max_shift + 1)), w)
row_shift = row_shifted - row
col_shift = col_shifted - col
return row_shift, col_shift
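# Illustrative sketch (hypothetical values): the shifted center is clipped to the image,
# so the returned shifts are bounded both by `max_shift` and by the image borders.
#   row_shift, col_shift = gen_random_translation(h=100, w=200, row=5, col=5, max_shift=8)
#   # row_shift and col_shift each lie in [-5, 8] here: at most 8 by `max_shift`, and at
#   # least -5 because the shifted row/col cannot go below 0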
def gen_patches(im, coords, size, rotations, translations, max_shift, p):
"""Generate patches with sampling and augmentation from coordinates.
For every set of (row, col) coordinates in `coords`, this function
yields centered patches sampled with probability `p`, possibly with a
combination of some number of rotations evenly-spaced in [0, 180], and
some number of random translations per rotation.
NOTE: This function will internally create a uint8 version of `im`
in order to use PIL to rotate the image. It will yield patches
converted back to the original type.
Args:
im: An image stored as a NumPy array of shape (h, w, c).
coords: An iterable collection of (row, col) coordinates.
size: An integer size of the square patch to extract.
rotations: Integer number of rotation-augmented patches
evenly-spaced in [0, 180] to extract for each location, in
addition to a 0-degree rotation.
translations: An integer number of random translation augmented
patches to extract for each rotation (including the 0-degree
rotation), in addition to a translation of 0.
max_shift: Integer upper bound on the spatial shift range for
the random translations.
p: A decimal probability of sampling each patch.
Returns:
Yields (patch, row, col, rot, row_shift, col_shift) tuples, where
patch is a NumPy array of shape (size, size, c), row & col are the
original coordinates, rot is the degree of rotation, and row_shift
& col_shift are relative translations from row, col applied after
the rotation.
"""
# check that size is within the image bounds
assert np.ndim(im) == 3, "image must be of shape (h, w, c)"
h, w, c = im.shape
assert 1 < size <= min(h, w), "size must be > 1 and within the bounds of the image"
assert rotations >= 0, "rotations must be >= 0"
assert translations >= 0, "translations must be >= 0"
assert max_shift >= 0, "max_shift must be >= 0"
assert 0 <= p <= 1, "p must be a valid decimal probability"
# convert to uint8 type in order to use PIL to rotate
orig_dtype = im.dtype
im = im.astype(np.uint8)
# We want to extract a rotated image, but if we simply extract a patch and rotate it,
# the corners will be empty. Ideally, we don't want to have empty corners, or have
# to fill those corners in with random noise, mirroring, etc. Since we have access
# to the full image, we can first extract a larger patch from which a regular patch
# size can be extracted after a rotation without including any empty regions.
# If we rotate a patch by theta degrees, then the corners of a centered square patch
# will intersect with the sides of the rotated larger patch at the same angle theta,
# forming a right triangle between the side of the centered patch as the hypotenuse,
# the segment of the side of the rotated patch between the corner and the intersection
# with the centered patch, the corner of the rotated patch, and the segment on the
# next side, which is the complement of the first segment lengthwise. Since we know
# the angle and the length of the centered patch, we can compute the lengths of the
# two segments, and thus the length of the side of the outer patch. A 45 degree
# rotation is the worst case scenario, so we extract a bounding patch for that case.
# Additionally, to support random translations, we add `2*max_shift` to the length,
# and then simply adjust the center coordinates and rotate around that shifted center.
# Instead of random rotations, we extract evenly-spaced rotations in the range [0, 180],
# starting with 0 degrees, which equates to a centered patch.
rads = math.pi / 4 # 45 degrees, which is worst case
bounding_size = math.ceil((size+2*max_shift) * (math.cos(rads) + math.sin(rads)))
row_center = col_center = round(bounding_size / 2)
# TODO: either emit a warning, or add a parameter to allow empty corners
#assert bounding_size < min(h, w), "patch size is too large to avoid empty corners after rotation"
for row, col in coords:
bounding_patch = Image.fromarray(extract_patch(im, row, col, bounding_size)) # PIL for rotation
# rotations
for theta in np.linspace(0, 180, rotations+1, dtype=int): # always include 0 degrees
rotated_patch = np.asarray(bounding_patch.rotate(theta, Image.BILINEAR)) # then back to numpy
# random translations
shifts = [gen_random_translation(h, w, row, col, max_shift) for _ in range(translations)]
for row_shift, col_shift in [(0, 0)] + shifts: # always include 0 shift
patch = extract_patch(rotated_patch, row_center + row_shift, col_center + col_shift, size)
patch = patch.astype(orig_dtype) # convert back to original data type
# sample from a Bernoulli distribution with probability `p`
if np.random.binomial(1, p):
yield patch, row, col, theta, row_shift, col_shift
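# Worked example of the bounding-patch geometry used in `gen_patches` above
# (hypothetical values): with size=64 and max_shift=16, the worst case is a 45-degree
# rotation, so
#   bounding_size = ceil((64 + 2*16) * (cos(pi/4) + sin(pi/4)))
#                 = ceil(96 * 1.4142...) = 136
# i.e. a 136x136 bounding patch is extracted and rotated, and the final 64x64 patch is
# cropped around the (possibly shifted) center without any empty corners.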
def save_patch(patch, path, lab, case, region, row, col, rotation, row_shift, col_shift, suffix="",
ext="png"):
"""Save an image patch with an appropriate filename.
The filename should contain all of the information needed to be able
to extract the same patch again.
Args:
patch: An image patch stored as a NumPy array of shape
(size, size, c).
path: A string path to the folder in which to store the image.
lab: An integer laboratory number from which the patch originated.
case: An integer case number from which the patch originated.
region: An integer region number from which the patch originated.
row: An integer row number at which the patch is centered, before
rotation and translation.
col: An integer column number at which the patch is centered, before
rotation and translation.
rotation: Integer degrees of rotation.
row_shift: Integer relative row translation of a patch that was
centered at (row, col) and then rotated.
col_shift: Integer relative column translation of a patch that was
centered at (row, col) and then rotated.
suffix: An optional string suffix to append to the filename, before
the file extension.
ext: A string file extension.
"""
# lab is a single digit, case and region are two digits with padding if needed
# TODO: extract filename generation and arg extraction into separate functions
filename = f"{lab}_{case}_{region}_{row}_{col}_{rotation}_{row_shift}_{col_shift}_{suffix}.{ext}"
file_path = os.path.join(path, filename)
# NOTE: the subsampling and quality parameters will only affect jpeg images
Image.fromarray(patch).save(file_path, subsampling=0, quality=100)
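# Sketch of the resulting filename layout (hypothetical values):
#   save_patch(patch, "data/mitoses/patches/train/mitosis", lab=1, case="07", region="03",
#              row=120, col=260, rotation=36, row_shift=-4, col_shift=7, suffix=0)
#   # writes data/mitoses/patches/train/mitosis/1_07_03_120_260_36_-4_7_0.png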
def preprocess(images_path, labels_path, dataset, base_save_path, train_size, patch_size, dist,
rotations_train, rotations_val, translations_train, translations_val, max_shift, stride_train,
stride_val, p_train, p_val, fp_path=None, model=None, model_name=None, model_patch_size=None,
model_batch_size=None, pred_threshold=None, fp_rotations=None, fp_translations=None, seed=None):
"""Generate a mitosis detection patch dataset.
This generates train/val datasets of mitosis/normal image patches for
the mitosis detection problem. The mitosis patches will be extracted
with centers at the given coordinates, along with random rotations
and translations from those coordinates. Normal patches will be
extracted in a sliding window fashion with the given stride, possibly
overlapping with mitosis patches up to some given threshold, and
optionally with false-positive oversampling. The train/val split will
be performed on overall cases, stratified by lab. I.e., the cases
from each lab will be separately split into training and validation
sets, and then the associated sets will be combined at the end. In
order to support adversarial training, the generated patch filenames
will each contain information about the laboratory and case from which
the patch originated.
Args:
images_path: Path to folder that contains the mitosis training
images.
labels_path: Path to folder that contains the mitosis training
labels.
dataset: String name of this dataset in {'tupac', 'icpr2012',
'icpr2014'}.
base_save_path: Path to folder in which to write the folders of
output patches.
train_size: Decimal percentage of data to include in the training
set during the train/val split.
patch_size: An integer size of the square patch to extract.
dist: An integer minimum Euclidean distance in pixels between a
normal patch and a mitotic patch.
rotations_train: Integer number of evenly-spaced rotation augmented
patches to extract for each mitosis in the training set, in
addition to the centered mitosis patch.
rotations_val: Integer number of evenly-spaced rotation augmented
patches to extract for each mitosis in the validation set, in
addition to the centered mitosis patch.
translations_train: Integer number of random translation augmented
patches to extract for each rotated mitosis patch in the training
set, in addition to the centered rotated mitosis patch.
translations_val: Integer number of random translation augmented
patches to extract for each rotated mitosis patch in the
validation set, in addition to the centered rotated mitosis patch.
max_shift: Integer upper bound on the spatial shift range for
the random translations.
stride_train: An integer number of pixels by which to shift in the
sliding window for normal patches in the training set.
stride_val: An integer number of pixels by which to shift in the
sliding window for normal patches in the validation set.
p_train: A decimal probability of sampling each normal patch
in the training set.
p_val: A decimal probability of sampling each normal patch
in the validation set.
fp_path: Optional path to a folder that contains false-positive
coordinates, which will be used instead of the model for
false-positive oversampling.
model: Optional Keras Model to use for false-positive oversampling.
model_name: String indicating the model being used, which is used
for determining the correct normalization. TODO: replace this
model_patch_size: An integer size of a square patch that the model
expects as input.
model_batch_size: Size of batches to process, for performance
improvements.
pred_threshold: Decimal threshold over which the patch is predicted
as a positive case.
fp_rotations: Integer number of evenly-spaced rotation augmented
patches to extract for each false-positive patch in the training
set, in addition to the centered patch.
fp_translations: Integer number of random translation
augmented patches to extract for each rotated false-positive patch
in the training set, in addition to the centered rotated patch.
seed: Integer random seed for NumPy.
"""
# set numpy seed
np.random.seed(seed)
# lab info
# TODO: turn this into a class
if dataset == "tupac":
# reformat case to zero-padded 2-character number
lab1 = [f"{n:02d}" for n in range(1, 24)] # cases 1-23
lab2 = [f"{n:02d}" for n in range(24, 49)] # cases 24-48
lab3 = [f"{n:02d}" for n in range(49, 74)] # cases 49-73
labs = {1: lab1, 2: lab2, 3: lab3}
scanners = [""]
region_im_subpath = ""
ext = "tif"
coords_subpath = ""
coords_suffix = ""
elif dataset == "icpr2012":
lab0 = [f"{n:02d}_v2" for n in range(0,5)] # cases 0-4
labs = {0: lab0} # reuse the "labs" idea
scanners = ["A", "H"]
region_im_subpath = ""
ext = "bmp"
coords_subpath = ""
coords_suffix = ""
elif dataset == "icpr2014":
lab0 = [f"{n:02d}" for n in [3,4,5,7,10,11,12,14,15,17,18]] # cases, scanner A
labs = {0: lab0} # reuse the "labs" idea
scanners = ["A", "H"]
region_im_subpath = os.path.join("frames", "x40")
ext = "tiff"
coords_subpath = "mitosis"
coords_suffix = "_mitosis"
# TODO: explore the use of the non-mitosis coords
else:
raise ValueError(f"incompatible dataset: {dataset}")
# generate & save patches
for lab in labs.keys():
# split cases into train/val sets
lab_cases = labs[lab]
# TODO: extract this out into a separate function
if train_size < 1:
train, val = train_test_split(lab_cases, train_size=train_size, test_size=1-train_size,
random_state=seed)
else:
train = lab_cases
val = []
train_args = ('train', train, translations_train, rotations_train, p_train, stride_train)
val_args = ('val', val, translations_val, rotations_val, p_val, stride_val)
for split_args in [train_args, val_args]:
# generate samples for this split
split_name, cases, translations, rotations, p, stride = split_args
for case in cases:
for scanner in scanners:
case_name = f"{scanner}{case}"
case_path = os.path.join(images_path, case_name)
region_im_paths = glob.glob(os.path.join(case_path, region_im_subpath, f"*.{ext}"))
for region_im_path in region_im_paths: # a single case may have many available regions
region, _ = os.path.basename(region_im_path).split('.') # region number, file extension
im = np.array(Image.open(region_im_path)) # get region image in np.uint8 format
h, w, c = im.shape
coords_path = os.path.join(labels_path, case_name, coords_subpath,
f"{region}{coords_suffix}.csv")
if os.path.isfile(coords_path):
if dataset == "tupac":
# the tupac dataset contains a single x,y coordinate pair per line corresponding to
# the center of the mitosis
coords = np.loadtxt(coords_path, dtype=np.int64, delimiter=',', ndmin=2,
usecols=(0,1))
elif dataset == "icpr2012":
# the icpr 2012 dataset contains a set of x,y coordinates per line corresponding to
# the segmentation map of the mitotic nucleus
# therefore, for our purposes here, we read this file into a list of x,y
# lists, one list per mitosis. then, we compute the average x,y value for each
# mitosis, which should correspond to the center of the mitosis. then we form a
# numpy array containing a single x,y value for each mitosis
# TODO: look into using these segmentation maps directly
with open(coords_path, "r") as f:
lines = f.readlines()
coords = [[int(x) for x in l.strip().split(',')] for l in lines]
coords = [[c[i:i+2] for i in range(0, len(c), 2)] for c in coords]
coords = [np.mean(np.array(c), axis=0) for c in coords]
coords = np.array(coords).astype(np.int64)
# NOTE: the ICPR datasets store coordinates in (col, row) order, so swap them to (row, col)
coords[:, [0, 1]] = coords[:, [1, 0]]
else: # dataset == "icpr2014"
# the icpr 2014 dataset contains an x,y,z coordinate per line corresponding to the center
# of a nucleus that is mitotic if z == 1, and non-mitotic if z == 0
coords = np.loadtxt(coords_path, dtype=np.int64, delimiter=',', ndmin=2,
usecols=(0,1))
# NOTE: the ICPR datasets store coordinates in (col, row) order, so swap them to (row, col)
coords[:, [0, 1]] = coords[:, [1, 0]]
else: # a missing file indicates no mitoses
coords = [] # no mitoses
# mitosis samples:
# save a centered patch, as well as rotations and random translations thereof
save_path = os.path.join(base_save_path, split_name, "mitosis")
if not os.path.exists(save_path):
os.makedirs(save_path) # create if necessary
patch_gen = gen_patches(im, coords, patch_size, rotations, translations, max_shift, 1)
for i, (patch, row, col, rot, row_shift, col_shift) in enumerate(patch_gen):
save_patch(patch, save_path, lab, case_name, region, row, col, rot, row_shift,
col_shift, i)
# normal samples:
# sample from all possible normal patches
save_path = os.path.join(base_save_path, split_name, "normal")
if not os.path.exists(save_path):
os.makedirs(save_path) # create if necessary
mask = create_mask(h, w, coords, dist)
# optional false-positive oversampling
if fp_path is not None:
fp_coords_path = os.path.join(fp_path, case_name, "{}.csv".format(region))
if os.path.isfile(fp_coords_path):
fp_coords = np.loadtxt(fp_coords_path, dtype=np.int64, delimiter=',', ndmin=2)
else: # a missing file indicates no false positives
fp_coords = [] # no false positives
elif model is not None and split_name == "train":
# oversample all false-positive cases in the training set
normal_coords_gen = gen_normal_coords(mask, stride)
fp_coords = gen_fp_coords(im, normal_coords_gen, model_patch_size, model, model_name,
pred_threshold, model_batch_size)
else:
fp_coords = []
fp_patch_gen = gen_patches(im, fp_coords, patch_size, fp_rotations, fp_translations,
max_shift, 1)
for i, (patch, row, col, rot, row_shift, col_shift) in enumerate(fp_patch_gen):
save_patch(patch, save_path, lab, case_name, region, row, col, rot, row_shift,
col_shift, i)
# regular sampling for normal cases
# NOTE: This may sample the false-positive patches again, but that's fine for now
if p > 0:
normal_coords_gen = gen_normal_coords(mask, stride)
patch_gen = gen_patches(im, normal_coords_gen, patch_size, 0, 0, max_shift, p)
for patch, row, col, rot, row_shift, col_shift in patch_gen:
save_patch(patch, save_path, lab, case_name, region, row, col, rot, row_shift,
col_shift)
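# Illustrative call sketch for `preprocess` (paths and values are assumptions that mirror
# the CLI defaults defined below):
#   preprocess(images_path="data/mitoses/mitoses_train_image_data",
#              labels_path="data/mitoses/mitoses_train_ground_truth", dataset="tupac",
#              base_save_path="data/mitoses/patches", train_size=0.8, patch_size=64,
#              dist=60, rotations_train=5, rotations_val=0, translations_train=5,
#              translations_val=0, max_shift=16, stride_train=48, stride_val=48,
#              p_train=1, p_val=1, seed=42)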
if __name__ == "__main__":
def check_float_range(x, lb, ub):
"""Argparse utility function for a float type in [lb, ub]."""
try:
x = float(x)
except ValueError as err:
raise argparse.ArgumentTypeError(str(err))
if x < lb or x > ub:
err = "Value should be in [{}, {}]. Got {} instead.".format(lb, ub, x)
raise argparse.ArgumentTypeError(err)
return x
# parse args
parser = argparse.ArgumentParser()
parser.add_argument("--images_path",
default=os.path.join("data", "mitoses", "mitoses_train_image_data"),
help="path to the mitosis training images (default: %(default)s)")
parser.add_argument("--labels_path",
default=os.path.join("data", "mitoses", "mitoses_train_ground_truth"),
help="path to the mitosis training labels (default: %(default)s)")
parser.add_argument("--dataset", default="tupac",
help="name of this dataset in {'tupac', 'icpr2012', 'icpr2014'} (default: %(default)s)")
parser.add_argument("--save_path", default=os.path.join("data", "mitoses", "patches"),
help="path to folder in which to write the folders of output patches (default: %(default)s)")
parser.add_argument("--train_size", type=lambda x: check_float_range(x, 0, 1), default=0.8,
help="decimal percentage of data to include in the training set during the train/val split "\
"(default: %(default)s)")
parser.add_argument("--patch_size", type=int, default=64,
help="integer length of the square patches to extract (default: %(default)s)")
parser.add_argument("--dist", type=int, default=60,
help="minimum distance between the centers of normal and mitotic patches "\
"(default: %(default)s)")
parser.add_argument("--rotations_train", type=int, default=5,
help="number of evenly-spaced rotation augmented patches to extract for each mitosis in the "\
"training set, in addition to the centered mitosis patch (default: %(default)s)")
parser.add_argument("--rotations_val", type=int, default=0,
help="number of evenly-spaced rotation augmented patches to extract for each mitosis in the "\
"validation set, in addition to the centered mitosis patch (default: %(default)s)")
parser.add_argument("--translations_train", type=int, default=5,
help="number of random translation augmented patches to extract for each rotated mitosis "\
"patch in the training set, in addition to the centered rotated mitosis patch "\
"(default: %(default)s)")
parser.add_argument("--translations_val", type=int, default=0,
help="number of random translation augmented patches to extract for each rotated mitosis "\
"patch in the validation set, in addition to the centered rotated mitosis patch "\
"(default: %(default)s)")
parser.add_argument("--max_shift", type=int,
help="upper bound on the spatial shift range for the random translations "\
"(default: `round(patch_size/4)`)")
parser.add_argument("--stride_train", type=int,
help="number of pixels by which to shift in the sliding window for normal patches in the "\
"training set (default: `patch_size*(3/4)`)")
parser.add_argument("--stride_val", type=int,
help="number of pixels by which to shift in the sliding window for normal patches in the "\
"validation set (default: `patch_size*(3/4)`)")
parser.add_argument("--p_train", type=lambda x: check_float_range(x, 0, 1), default=1,
help="probability of sampling each normal patch in the training set (default: %(default)s)")
parser.add_argument("--p_val", type=lambda x: check_float_range(x, 0, 1), default=1,
help="probability of sampling each normal patch in the validation set (default: %(default)s)")
parser.add_argument("--fp_path",
help="path to false-positive locations, which will be used instead of the model "\
"(default: %(default)s)")
parser.add_argument("--model_path",
help="path to a Keras model to use for false-positive oversampling (default: %(default)s)")
# TODO: replace this with unified normalization flag used here and for training
parser.add_argument("--model_name",
help="name of the model being used, which is used for determining the correct normalization "\
"(default: %(default)s)")
parser.add_argument("--model_patch_size", type=int, default=64,
help="integer length of a square patch that the model expects as input "\
"(default: %(default)s)")
parser.add_argument("--model_batch_size", type=int, default=128,
help="size of the batches to predict on (default: %(default)s)")
parser.add_argument("--pred_threshold", type=float, default=0,
help="threshold over which the patch is predicted as a positive case (default: %(default)s)")
parser.add_argument("--fp_rotations", type=int, default=5,
help="number of evenly-spaced rotation augmented patches to extract for each false-positive "\
"patch in the training set, in addition to the centered patch (default: %(default)s)")
parser.add_argument("--fp_translations", type=int, default=5,
help="number of random translation augmented patches to extract for each rotated "\
"false-positive patch in the training set, in addition to the centered rotated patch "\
"(default: %(default)s)")
parser.add_argument("--seed", type=int, help="random seed for numpy (default: %(default)s)")
args = parser.parse_args()
# set any other defaults
if args.max_shift is None:
args.max_shift = round(args.patch_size/4)
if args.stride_train is None:
args.stride_train = round(args.patch_size*(3/4))
if args.stride_val is None:
args.stride_val = round(args.patch_size*(3/4))
# create a random seed if needed
if args.seed is None:
args.seed = np.random.randint(1e9)
# save args to file in save folder
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
with open(os.path.join(args.save_path, 'args.txt'), 'w') as f:
json.dump(args.__dict__, f)
print("", file=f)
# can be read in later with
#with open('args.txt', 'r') as f:
# args = json.load(f)
# save command line invocation to txt file for ease of rerunning the exact experiment
with open(os.path.join(args.save_path, 'invoke.txt'), 'w') as f:
f.write("python3 " + " ".join(sys.argv) + "\n")
# copy this script to the base save folder
shutil.copy2(os.path.realpath(__file__), args.save_path)
# load model for false-positive oversampling
if args.model_path is not None:
model = tf.keras.models.load_model(args.model_path, compile=False)
else:
model = None
# preprocess!
preprocess(images_path=args.images_path, labels_path=args.labels_path, dataset=args.dataset,
base_save_path=args.save_path, train_size=args.train_size, patch_size=args.patch_size,
dist=args.dist, rotations_train=args.rotations_train, rotations_val=args.rotations_val,
translations_train=args.translations_train, translations_val=args.translations_val,
max_shift=args.max_shift, stride_train=args.stride_train, stride_val=args.stride_val,
p_train=args.p_train, p_val=args.p_val, fp_path=args.fp_path, model=model,
model_name=args.model_name, model_patch_size=args.model_patch_size,
model_batch_size=args.model_batch_size, pred_threshold=args.pred_threshold,
fp_rotations=args.fp_rotations, fp_translations=args.fp_translations, seed=args.seed)
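# Example command-line invocation (illustrative; all flags are optional and fall back to
# the defaults defined above):
#   python3 preprocess_mitoses.py --images_path data/mitoses/mitoses_train_image_data \
#     --labels_path data/mitoses/mitoses_train_ground_truth --dataset tupac \
#     --save_path data/mitoses/patches --train_size 0.8 --patch_size 64 --seed 42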
# ---
# tests
# TODO: eventually move these to a separate file.
# `py.test preprocess_mitoses.py`
def test_create_mask():
import pytest
# create image
h, w, c = 100, 200, 3
im = np.random.rand(h, w, c)
# check mask shape and type
coords = [(50, 40)]
radius = 32
mask = create_mask(h, w, coords, radius)
assert mask.shape == (h, w)
assert mask.dtype == bool
# row error
with pytest.raises(AssertionError):
coords = [(-1, 1)]
radius = 32
create_mask(h, w, coords, radius)
# col error
with pytest.raises(AssertionError):
coords = [(1, -1)]
radius = 32
create_mask(h, w, coords, radius)
# radius error
# with pytest.raises(AssertionError):
# coords = [(1, 1)]
# radius = h+1
# create_mask(h, w, coords, radius)
# another radius error
# with pytest.raises(AssertionError):
# coords = [(1, 1)]
# radius = w
# create_mask(h, w, coords, radius)
# another radius error
# with pytest.raises(AssertionError):
# coords = [(1, 1)]
# radius = 1
# create_mask(h, w, coords, radius)
# row, col, radius on boundary
coords = [(0, 0)]
radius = h
half_radius = int(radius / 2)
mask = create_mask(h, w, coords, radius)
correct_mask = np.zeros_like(mask)
for r in range(h):
for c in range(w):
if np.sqrt(r**2 + c**2) <= radius:
correct_mask[r, c] = 1
assert np.array_equal(mask, correct_mask)
# row, col, radius on another boundary
coords = [(h, w)]
radius = h
half_radius = int(radius / 2)
mask = create_mask(h, w, coords, radius)
correct_mask = np.zeros_like(mask)
for r in range(h):
for c in range(w):
if np.sqrt((r-h)**2 + (c-w)**2) <= radius:
correct_mask[r, c] = 1
assert np.array_equal(mask, correct_mask)
# normal row, col, radius
coords = [(50, 40), (60, 50)]
radius = 32
half_radius = int(radius / 2)
mask = create_mask(h, w, coords, radius)
assert mask.shape == (h, w)
correct_mask = np.zeros_like(mask)
for row, col in coords:
for r in range(h):
for c in range(w):
if np.sqrt((r-row)**2 + (c-col)**2) <= radius:
correct_mask[r, c] = 1
assert np.array_equal(mask, correct_mask)
# normal row, col, radius w/ NumPy array
coords = np.array([(50, 40), (60, 50)])
radius = 32
half_radius = int(radius / 2)
mask = create_mask(h, w, coords, radius)
assert mask.shape == (h, w)
correct_mask = np.zeros_like(mask)
for row, col in coords:
for r in range(h):
for c in range(w):
if np.sqrt((r-row)**2 + (c-col)**2) <= radius:
correct_mask[r, c] = 1
assert np.array_equal(mask, correct_mask)
# row, col, radius partially outside bounds
coords = [(50, 40), (10, 190)]
radius = 32
half_radius = int(radius / 2)
mask = create_mask(h, w, coords, radius)
assert mask.shape == (h, w)
correct_mask = np.zeros_like(mask)
for row, col in coords:
for r in range(h):
for c in range(w):
if np.sqrt((r-row)**2 + (c-col)**2) <= radius:
correct_mask[r, c] = 1
assert np.array_equal(mask, correct_mask)
def test_extract_patch():
import pytest
# create image
h, w, c = 100, 200, 3
im = np.random.rand(h, w, c)
im2d = np.random.rand(h, w)
# row error
with pytest.raises(AssertionError):
row, col, size = -1, 1, 32
extract_patch(im, row, col, size)
# col error
with pytest.raises(AssertionError):
row, col, size = 1, -1, 32
extract_patch(im, row, col, size)
# size error
# with pytest.raises(AssertionError):
# row, col, size = 1, 1, h+1
# extract_patch(im, row, col, size)
# another size error
# with pytest.raises(AssertionError):
# row, col, size = 1, 1, w
# extract_patch(im, row, col, size)
# another size error
# with pytest.raises(AssertionError):
# row, col, size = 1, 1, 1
# extract_patch(im, row, col, size)
# row, col, size on boundary
row, col, size = 0, 0, h
patch = extract_patch(im, row, col, size)
patch2d = extract_patch(im2d, row, col, size)
assert patch.shape == (size, size, c)
assert patch2d.shape == (size, size)
# row, col, size on another boundary
row, col, size = h, w, h
patch = extract_patch(im, row, col, size)
patch2d = extract_patch(im2d, row, col, size)
assert patch.shape == (size, size, c)
assert patch2d.shape == (size, size)
# normal row, col, size
row, col, size = 50, 40, 32
patch = extract_patch(im, row, col, size)
assert patch.shape == (size, size, c)
half_size = int(size / 2)
correct_patch = im[row-half_size:row+half_size, col-half_size:col+half_size]
assert np.allclose(patch, correct_patch)
# row, col, size partially outside bounds
row, col, size = 10, 190, 32
patch = extract_patch(im, row, col, size)
assert patch.shape == (size, size, c)
half_size = int(size / 2)
# make sure that the correct patch has actually been extracted
unpadded_patch = patch[6:,:-6]
correct_unpadded_patch = im[0:row+half_size, col-half_size:w]
assert np.array_equal(unpadded_patch, correct_unpadded_patch)
def test_gen_dense_coords():
import types
import pytest
# create image
h, w, c = 100, 200, 3
im = np.random.rand(h, w, c)
stride = 32
# check that it returns a generator object
assert isinstance(gen_dense_coords(h, w, stride), types.GeneratorType)
# stride error
with pytest.raises(AssertionError):
next(gen_dense_coords(h, w, -1))
# another stride error
with pytest.raises(AssertionError):
next(gen_dense_coords(h, w, 0))
# normal
row, col = next(gen_dense_coords(h, w, stride))
assert 0 <= row <= h
assert 0 <= col <= w
# list of coords
coords = list(gen_dense_coords(h, w, stride))
assert len(coords) > 0
# check that stride < size produces more coordinates
coords2 = list(gen_dense_coords(h, w, 1))
assert len(coords2) > len(coords)
# check for correct centered coordinates
h = 6
w = 8
stride = 2
correct_coords = [(0, 0), (0, 2), (0, 4), (0, 6),
(2, 0), (2, 2), (2, 4), (2, 6),
(4, 0), (4, 2), (4, 4), (4, 6)]
coords = list(gen_dense_coords(h, w, stride))
assert coords == correct_coords
def test_gen_normal_coords():
import types
import pytest
# create mask
h, w = 100, 200
size = 32
radius = 30
p = 0.6
stride = size
mask = np.zeros((h, w), dtype=bool)
for r in range(h):
for c in range(w):
if np.sqrt(r**2 + c**2) <= radius:
mask[r, c] = 1
mask[0:size, 0:size] = True
# check that it returns a generator object
assert isinstance(gen_normal_coords(mask, stride), types.GeneratorType)
# mask shape error
with pytest.raises(AssertionError):
next(gen_normal_coords(np.zeros((h, w, 3)), stride))
# stride error
with pytest.raises(AssertionError):
next(gen_normal_coords(mask, -1))
# another stride error
with pytest.raises(AssertionError):
next(gen_normal_coords(mask, 0))
# normal
row, col = next(gen_normal_coords(mask, stride))
assert 0 <= row <= h
assert 0 <= col <= w
# list of coords
coords = list(gen_normal_coords(mask, stride))
assert len(coords) > 0
# check for correct coords
h = 6
w = 8
size = 4
stride = size - 2
mask = np.zeros((h, w))
mask[0:size, 0:size] = 1
correct_coords = []
for r in range(0, h, stride):
for c in range(0, w, stride):
if not mask[r, c]: