Skip to content

Commit

Permalink
Merge pull request #34 from spoonsso/mirrors
Browse files Browse the repository at this point in the history
Mirrored video functionality.

Former-commit-id: d5317f70ae9da532bb745e2b281d6d0829efadd4
  • Loading branch information
spoonsso authored Feb 24, 2021
2 parents 696b4d8 + ad3a2f4 commit 8dd8c01
Show file tree
Hide file tree
Showing 23 changed files with 530 additions and 504 deletions.
5 changes: 4 additions & 1 deletion dannce/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
"augment_bright_val": 0.05,
"augment_rotation_val": 5,
"drop_landmark": None,
"raw_im_h": None,
"raw_im_w": None,
"mirror": False,
"max_num_samples": None,
}
_param_defaults_dannce = {
"metric": ["euclidean_distance_3D"],
Expand Down Expand Up @@ -59,7 +63,6 @@
"nvox": None,
"expval": None,
"com_thresh": None,
"max_num_samples": None,
"start_sample": None,
"new_n_channels_out": None,
"cam3_train": None,
Expand Down
64 changes: 39 additions & 25 deletions dannce/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ def add_shared_args(parser):
help="If true, converts 3-channel video frames into mono grayscale using standard RGB->gray conversion formula (ref. scikit-image).",
)

parser.add_argument(
"--mirror",
dest="mirror",
type=ast.literal_eval,
help="If true, uses a single video file for multiple views.",
)

return parser


Expand Down Expand Up @@ -197,8 +204,21 @@ def add_shared_predict_args(parser):
parser.add_argument(
"--max-num-samples",
dest="max_num_samples",
type=int,
help="Maximum number of samples to predict during COM or DANNCE prediction.",
)
parser.add_argument(
"--start-batch",
dest="start_batch",
type=int,
help="Starting batch number during dannce prediction.",
)
parser.add_argument(
"--start-sample",
dest="start_sample",
type=int,
help="Starting sample number during dannce prediction.",
)
return parser


Expand Down Expand Up @@ -311,6 +331,19 @@ def add_dannce_shared_args(parser):
dest="n_views",
type=int,
help="Sets the absolute number of views (when using fewer than 6 views only)")
parser.add_argument(
"--train-mode",
dest="train_mode",
help="Training modes can be:\n"
"new: initializes and trains a network from scratch\n"
"finetune: loads in pre-trained weights and fine-tuned from there\n"
"continued: initializes a full model, including optimizer state, and continuous training from the last full model checkpoint",
)
parser.add_argument(
"--dannce-finetune-weights",
dest="dannce_finetune_weights",
help="Path to weights of initial model for dannce fine tuning.",
)
return parser


Expand All @@ -326,19 +359,6 @@ def add_dannce_train_args(parser):
type=ast.literal_eval,
help="If True, use rotation augmentation for dannce training.",
)
parser.add_argument(
"--dannce-finetune-weights",
dest="dannce_finetune_weights",
help="Path to weights of initial model for dannce fine tuning.",
)
parser.add_argument(
"--train-mode",
dest="train_mode",
help="Training modes can be:\n"
"new: initializes and trains a network from scratch\n"
"finetune: loads in pre-trained weights and fine-tuned from there\n"
"continued: initializes a full model, including optimizer state, and continuous training from the last full model checkpoint",
)
parser.add_argument(
"--augment-continuous-rotation",
dest="augment_continuous_rotation",
Expand All @@ -365,18 +385,6 @@ def add_dannce_predict_args(parser):
dest="dannce_predict_model",
help="Path to model to use for dannce prediction.",
)
parser.add_argument(
"--start-batch",
dest="start_batch",
type=int,
help="Starting batch number during dannce prediction.",
)
parser.add_argument(
"--start-sample",
dest="start_sample",
type=int,
help="Starting sample number during dannce prediction.",
)
parser.add_argument(
"--predict-model",
dest="predict_model",
Expand All @@ -388,6 +396,12 @@ def add_dannce_predict_args(parser):
type=ast.literal_eval,
help="If True, use expected value network. This is normally inferred from the network name. But because prediction can be decoupled from the net param, expval can be set independently if desired.",
)
parser.add_argument(
"--from-weights",
dest="from_weights",
type=ast.literal_eval,
help="If True, attempt to load in a prediction model without requiring a full model file (i.e. just using weights). May fail for some model types.",
)
return parser


Expand Down
162 changes: 119 additions & 43 deletions dannce/engine/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
chunks=3500,
preload=True,
mono=False,
mirror=False,
):
"""Initialize Generator."""
self.dim_in = dim_in
Expand All @@ -57,7 +58,8 @@ def __init__(
self.samples_per_cluster = samples_per_cluster
self._N_VIDEO_FRAMES = chunks
self.preload = preload
self.mono=mono
self.mono = mono
self.mirror = mirror
self.on_epoch_end()

if self.vidreaders is not None:
Expand Down Expand Up @@ -201,7 +203,8 @@ def __init__(
crop_im=True,
norm_im=True,
chunks=3500,
mono=False
mono=False,
mirror=False,
):
"""Initialize data generator."""
DataGenerator.__init__(
Expand All @@ -222,7 +225,8 @@ def __init__(
vidreaders,
chunks,
preload,
mono
mono,
mirror
)
self.vmin = vmin
self.vmax = vmax
Expand Down Expand Up @@ -373,7 +377,7 @@ def __data_generation(self, list_IDs_temp):
(x_coord_3d.ravel(), y_coord_3d.ravel(), z_coord_3d.ravel()), axis=1
)

for camname in self.camnames[experimentID]:
for _ci, camname in enumerate(self.camnames[experimentID]):
ts = time.time()
# Need this copy so that this_y does not change
this_y = np.round(self.labels[ID]["data"][camname]).copy()
Expand All @@ -385,37 +389,50 @@ def __data_generation(self, list_IDs_temp):
com_precrop = np.nanmean(this_y, axis=1)

# Store sample
# for pre-cropped tifs
if self.immode == "tif":
thisim = imageio.imread(
os.path.join(
self.tifdirs[experimentID],
if not self.mirror or _ci == 0:
# for pre-cropped tifs
if self.immode == "tif":
thisim = imageio.imread(
os.path.join(
self.tifdirs[experimentID],
camname,
"{}.tif".format(sampleID),
)
)

# From raw video, need to crop
elif self.immode == "vid":
thisim = self.load_vid_frame(
self.labels[ID]["frames"][camname],
camname,
"{}.tif".format(sampleID),
self.preload,
extension=self.extension,
)[
self.crop_height[0] : self.crop_height[1],
self.crop_width[0] : self.crop_width[1],
]
# print("Decode frame took {} sec".format(time.time() - ts))
tss = time.time()

# Load in the image file at the specified path
elif self.immode == "arb_ims":
thisim = imageio.imread(
self.tifdirs[experimentID]
+ self.labels[ID]["frames"][camname][0]
+ ".jpg"
)
)

# From raw video, need to crop
elif self.immode == "vid":
thisim = self.load_vid_frame(
self.labels[ID]["frames"][camname],
camname,
self.preload,
extension=self.extension,
)[
self.crop_height[0] : self.crop_height[1],
self.crop_width[0] : self.crop_width[1],
]
# print("Decode frame took {} sec".format(time.time() - ts))
tss = time.time()

# Load in the image file at the specified path
elif self.immode == "arb_ims":
thisim = imageio.imread(
self.tifdirs[experimentID]
+ self.labels[ID]["frames"][camname][0]
+ ".jpg"
)
if self.mirror:
# Save copy of the first image loaded in, so that it can be flipped accordingly.
self.raw_im = thisim.copy()

if self.mirror and self.camera_params[experimentID][camname]["m"] == 1:
thisim = self.raw_im.copy()
thisim = thisim[-1::-1]
elif self.mirror and self.camera_params[experimentID][camname]["m"] == 0:
thisim = self.raw_im
elif self.mirror:
raise Exception("Invalid mirror parameter, m, must be 0 or 1")

if self.immode == "vid" or self.immode == "arb_ims":
this_y[0, :] = this_y[0, :] - self.crop_width[0]
Expand Down Expand Up @@ -651,7 +668,8 @@ def __init__(
crop_im=True,
norm_im=True,
chunks=3500,
mono=False
mono=False,
mirror=False,
):
"""Initialize data generator."""
DataGenerator.__init__(
Expand All @@ -672,7 +690,8 @@ def __init__(
vidreaders,
chunks,
preload,
mono
mono,
mirror,
)
self.vmin = vmin
self.vmax = vmax
Expand Down Expand Up @@ -752,7 +771,6 @@ def rot180(self, X):
return X

def project_grid(self, X_grid, camname, ID, experimentID):
ts = time.time()
# Need this copy so that this_y does not change
this_y = self.torch.as_tensor(
self.labels[ID]["data"][camname],
Expand All @@ -776,10 +794,47 @@ def project_grid(self, X_grid, camname, ID, experimentID):
self.preload,
extension=self.extension,
)[
self.crop_height[0] : self.crop_height[1],
self.crop_width[0] : self.crop_width[1],
self.crop_height[0]: self.crop_height[1],
self.crop_width[0]: self.crop_width[1],
]
return self.pj_grid_post(X_grid, camname, ID, experimentID,
com, com_precrop, thisim)

def pj_grid_mirror(self, X_grid, camname, ID, experimentID, thisim):
    """Project the grid for one camera view of a shared, mirrored frame.

    Counterpart of ``project_grid`` for the ``mirror`` case: the frame is
    supplied by the caller as ``thisim`` rather than being loaded here, so
    one decoded image can serve every camera view.

    Args:
        X_grid: Grid handed through unchanged to ``pj_grid_post``.
        camname: Camera name; used as a key into ``self.labels[ID]["data"]``
            and ``self.camera_params[experimentID]``.
        ID: Sample key into ``self.labels``.
        experimentID: Experiment key into ``self.camera_params``.
        thisim: The already-loaded frame shared by all views
            (presumably a cropped numpy image array -- TODO confirm).

    Returns:
        Whatever ``self.pj_grid_post`` returns for this view.

    Raises:
        Exception: If ``self.mirror`` is falsy, or if the camera's ``"m"``
            parameter is neither 0 nor 1.
    """
    th = self.torch

    # Rounded 2D labels for this camera, as a fresh float32 tensor on
    # self.device.
    labels_2d = th.as_tensor(
        self.labels[ID]["data"][camname],
        dtype=th.float32,
        device=self.device,
    ).round()

    # Mean over axis 1 taken BEFORE the crop offset is removed; when every
    # label is NaN an all-NaN vector is produced instead.
    if th.all(th.isnan(labels_2d)):
        com_precrop = th.zeros_like(labels_2d[:, 0]) * th.nan
    else:
        com_precrop = th.mean(labels_2d, axis=1)

    # Shift row 0 by the crop-width origin and row 1 by the crop-height
    # origin, then take the mean again in cropped coordinates.
    labels_2d[0, :] = labels_2d[0, :] - self.crop_width[0]
    labels_2d[1, :] = labels_2d[1, :] - self.crop_height[0]
    com = th.mean(labels_2d, axis=1)

    if not self.mirror:
        raise Exception("Trying to project onto mirrored images without mirror being set properly")

    # "m" selects whether this view is the flipped copy of the shared frame.
    mirror_flag = self.camera_params[experimentID][camname]["m"]
    if mirror_flag == 1:
        # Reverse along the first axis (a copy, so thisim is untouched).
        view_im = thisim[-1::-1].copy()
    elif mirror_flag == 0:
        view_im = thisim.copy()
    else:
        raise Exception("Invalid mirror parameter, m, must be 0 or 1")

    return self.pj_grid_post(
        X_grid, camname, ID, experimentID, com, com_precrop, view_im
    )

def pj_grid_post(self, X_grid, camname, ID, experimentID,
com, com_precrop, thisim):
# Separate the projection and sampling into its own function so that
# when mirror == True, this can be called directly.
if self.crop_im:
if self.torch.all(self.torch.isnan(com)):
thisim = self.torch.zeros(
Expand Down Expand Up @@ -932,11 +987,30 @@ def __data_generation(self, list_IDs_temp):
ts = time.time()
num_cams = len(self.camnames[experimentID])
arglist = []
for c in range(num_cams):
arglist.append(
[X_grid[i], self.camnames[experimentID][c], ID, experimentID]
)
result = self.threadpool.starmap(self.project_grid, arglist)
if self.mirror:
# Here we only load the video once, and then parallelize the projection
# and sampling after mirror flipping. This is for setups that collect
# all views in a single image with the use of mirrors.
loadim = self.load_vid_frame(
self.labels[ID]["frames"][self.camnames[experimentID][0]],
self.camnames[experimentID][0],
self.preload,
extension=self.extension,
)[
self.crop_height[0]: self.crop_height[1],
self.crop_width[0]: self.crop_width[1],
]
for c in range(num_cams):
arglist.append(
[X_grid[i], self.camnames[experimentID][c], ID, experimentID, loadim]
)
result = self.threadpool.starmap(self.pj_grid_mirror, arglist)
else:
for c in range(num_cams):
arglist.append(
[X_grid[i], self.camnames[experimentID][c], ID, experimentID]
)
result = self.threadpool.starmap(self.project_grid, arglist)

for c in range(num_cams):
ic = c + i * len(self.camnames[experimentID])
Expand Down Expand Up @@ -1098,6 +1172,7 @@ def __init__(
norm_im=True,
chunks=3500,
mono=False,
mirror=False,
):

"""Initialize data generator."""
Expand All @@ -1119,7 +1194,8 @@ def __init__(
vidreaders,
chunks,
preload,
mono
mono,
mirror,
)
self.vmin = vmin
self.vmax = vmax
Expand Down
Loading

0 comments on commit 8dd8c01

Please sign in to comment.