# preprocess.py — 109 lines (83 loc) · 4.05 KB
import numpy as np
import os
import nibabel as nib
import h5py
import pickle
import itertools
def get_vals_vecs():
    """Read acquisition metadata from the first training subject.

    Returns:
        vals: (288,) int array of b-values.
        vecs: (288, 3) float array of gradient directions (bvecs file stores
            one axis per line, 288 values each).
        dif_indexes_0: indices of b0 volumes (b-value below 100).
        dif_indexes_lr: pickled volume indices used for the low-res dataset.
    """
    subject = os.listdir("./data/train")[0]

    with open(f"./data/train/{subject}/bvals", "r") as bval_file:
        vals = np.array([int(tok) for tok in bval_file.readline().split()])

    vecs = np.empty((288, 3))
    with open(f"./data/train/{subject}/bvecs", "r") as bvec_file:
        for axis in range(3):
            vecs[:, axis] = [float(tok) for tok in bvec_file.readline().split()]

    # b0 volumes are those with (near-)zero b-value.
    dif_indexes_0 = np.flatnonzero(vals < 100)

    with open("./data/diffusion_indexes_train", "rb") as idx_file:
        dif_indexes_lr = pickle.load(idx_file)

    return vals, vecs, dif_indexes_0, dif_indexes_lr
def make_lr(dwi, mask, dif_indexes_lr):
    """Downsample a padded HR DWI volume and its brain mask by 2x per axis.

    Args:
        dwi: padded HR volume, expected (144+8, 174+8, 144+8, n_vols)
            float array — TODO confirm n_vols covers max(dif_indexes_lr).
        mask: padded HR brain mask, same spatial shape as ``dwi``.
        dif_indexes_lr: volume indices to keep in the low-res output.

    Returns:
        lr: (76, 91, 76, 36) float32, 2x2x2 block means of the selected
            volumes; the last two indices per axis are left zero (they fall
            entirely inside the zero padding, matching the original loop).
        mask_lr: (76, 91, 76) uint8, 2x2x2 block max of ``mask``.
    """
    lr = np.zeros((72+4, 87+4, 72+4, 36), dtype=np.float32)
    mask_lr = np.zeros((72+4, 87+4, 72+4), dtype=np.uint8)
    # Vectorized 2x2x2 block pooling replaces the original per-voxel Python
    # loop (~487k iterations). Only the region the loop covered is written:
    # i, j, k in [0, 74) x [0, 89) x [0, 74), i.e. HR [0:148, 0:178, 0:148].
    ni, nj, nk = 72 + 2, 87 + 2, 72 + 2
    sub = dwi[:ni * 2, :nj * 2, :nk * 2][..., dif_indexes_lr]
    sub = sub.reshape(ni, 2, nj, 2, nk, 2, sub.shape[-1])
    lr[:ni, :nj, :nk] = sub.mean(axis=(1, 3, 5))
    msub = mask[:ni * 2, :nj * 2, :nk * 2].reshape(ni, 2, nj, 2, nk, 2)
    mask_lr[:ni, :nj, :nk] = msub.max(axis=(1, 3, 5))
    return lr, mask_lr
def make_train_dataset():
    """Build ./data/train.h5 with HR volumes, LR volumes, and LR mask indices.

    For every subject under ./data/train: load the DWI and brain mask, crop
    the last x/z slice to even dimensions, zero-pad 4 voxels per side,
    normalize by the mean b0 image, clip to [0, 1], and downsample to LR.
    """
    train_list = os.listdir("./data/train")
    # Size arrays from the actual subject count (the original hardcoded 5,
    # which crashes with more subjects and stores all-zero rows with fewer).
    n_subjects = len(train_list)
    train_hrs = np.zeros((n_subjects, 144+8, 174+8, 144+8, 288), dtype=np.float32)
    train_lrs = np.zeros((n_subjects, 72+4, 87+4, 72+4, 36), dtype=np.float32)
    mask_lrs = np.zeros((n_subjects, 72+4, 87+4, 72+4), dtype=np.uint8)
    _, _, dif_indexes_0, dif_indexes_lr = get_vals_vecs()
    for idx, subject in enumerate(train_list):
        print(f"\r{idx+1} / {len(train_list)} {subject}", end='')
        dwi = nib.load(f"./data/train/{subject}/data.nii.gz")
        # Crop last x/z slice -> even spatial dims (presumably 144x174x144;
        # TODO confirm against the raw HCP volume shape).
        dwi = np.array(dwi.get_fdata(), dtype=np.float32)[:-1, :, :-1]
        mask = nib.load(f"./data/train/{subject}/nodif_brain_mask.nii.gz")
        mask = np.array(mask.get_fdata(), dtype=np.uint8)[:-1, :, :-1]
        dwi = np.pad(dwi, ((4, 4), (4, 4), (4, 4), (0, 0)), "constant", constant_values=0)
        mask = np.pad(mask, ((4, 4), (4, 4), (4, 4)), "constant", constant_values=0)
        # Normalize all 288 volumes by the mean b0 image in one broadcast
        # divide (the original divided channel-by-channel in a Python loop).
        # 0/0 outside the brain yields NaN/inf, removed by nan_to_num + clip.
        dwi_b0 = np.mean(dwi[..., dif_indexes_0], axis=3)
        dwi /= dwi_b0[..., np.newaxis]
        np.nan_to_num(dwi, copy=False)
        np.clip(dwi, 0, 1, out=dwi)
        train_hrs[idx] = dwi
        train_lrs[idx], mask_lrs[idx] = make_lr(dwi, mask, dif_indexes_lr)
    # (subject, i, j, k) coordinates of every brain voxel on the LR grid.
    mask_index = np.array(np.where(mask_lrs == 1)).T
    # Context manager guarantees the HDF5 file is closed even on error.
    with h5py.File("./data/train.h5", "w") as hf:
        hf.create_dataset("hr", data=train_hrs)
        hf.create_dataset("lr", data=train_lrs)
        hf.create_dataset("mask", data=mask_index)
def make_test_dataset():
    """Write one HDF5 file per test subject under ./data/test_h5/<subject>/.

    Each file stores the LR volume, the HR mean-b0 image, the LR brain-voxel
    indices, and the padded HR mask; the NIfTI header is pickled alongside
    so predictions can be saved back in the original space.
    """
    test_list = os.listdir("./data/test")
    _, _, dif_indexes_0, dif_indexes_lr = get_vals_vecs()
    for idx, subject in enumerate(test_list):
        print(f"\r{idx+1} / {len(test_list)} {subject}", end='')
        dwi = nib.load(f"./data/test/{subject}/data.nii.gz")
        # Keep the header so outputs can be written with original affine/zooms.
        dwi_header = dwi.header.copy()
        # Same crop/pad/normalize pipeline as the training set.
        dwi = np.array(dwi.get_fdata(), dtype=np.float32)[:-1, :, :-1]
        mask = nib.load(f"./data/test/{subject}/nodif_brain_mask.nii.gz")
        mask = np.array(mask.get_fdata(), dtype=np.uint8)[:-1, :, :-1]
        dwi = np.pad(dwi, ((4, 4), (4, 4), (4, 4), (0, 0)), "constant", constant_values=0)
        mask = np.pad(mask, ((4, 4), (4, 4), (4, 4)), "constant", constant_values=0)
        # Broadcast divide replaces the original per-channel Python loop;
        # 0/0 outside the brain is cleaned up by nan_to_num + clip.
        dwi_b0 = np.mean(dwi[..., dif_indexes_0], axis=3)
        dwi /= dwi_b0[..., np.newaxis]
        np.nan_to_num(dwi, copy=False)
        np.clip(dwi, 0, 1, out=dwi)
        test_lr, mask_lr = make_lr(dwi, mask, dif_indexes_lr)
        mask_index = np.array(np.where(mask_lr == 1)).T
        os.makedirs(f"./data/test_h5/{subject}", exist_ok=True)
        # Context manager guarantees the HDF5 file is closed even on error.
        with h5py.File(f"./data/test_h5/{subject}/data.h5", "w") as hf:
            hf.create_dataset("lr", data=test_lr)
            hf.create_dataset("hr_b0", data=dwi_b0)
            hf.create_dataset("mask_index", data=mask_index)
            hf.create_dataset("mask_hr", data=mask)
        with open(f"./data/test_h5/{subject}/header", "wb") as f:
            pickle.dump(dwi_header, f)
if __name__ == "__main__":
    # Only the test-set preprocessing is enabled; uncomment the line below
    # to (re)build ./data/train.h5 as well.
    # make_train_dataset()
    make_test_dataset()