Skip to content

Commit

Permalink
torch tensors were changing when being observed! Assigning needed var…
Browse files Browse the repository at this point in the history
…iables to ints fixed the problem.
  • Loading branch information
cameron-a-johnson committed Nov 20, 2024
1 parent f63f997 commit 65dc076
Showing 1 changed file with 12 additions and 16 deletions.
28 changes: 12 additions & 16 deletions tcn_hpl/callbacks/plot_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,6 @@ def on_test_batch_end(
batch: Any,
batch_idx: int,
dataloader_idx: int,
preds_dset_output_fpath: Path = "./tcn_activity_predictions.kwcoco.json"
) -> None:
"""Called when the test batch ends."""
# Re-using validation lists since test phase does not collide with
Expand All @@ -347,7 +346,7 @@ def on_test_batch_end(
self._val_all_targets.append(outputs["targets"].cpu())
self._val_all_source_vids.append(outputs["source_vid"].cpu())
self._val_all_source_frames.append(outputs["source_frame"].cpu())
self._preds_dset_output_fpath = preds_dset_output_fpath
self._preds_dset_output_fpath = self.output_dir / "tcn_activity_predictions.kwcoco.json"

def on_test_epoch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
Expand Down Expand Up @@ -378,29 +377,26 @@ def on_test_epoch_end(
acts_dset.dataset['categories'] = truth_dset.dataset['categories']
# Create numpy lookup tables
for i in range(len(all_preds)):
frame_index = all_source_frames[i]
video_id = all_source_vids[i]
'''
# This list could be as long as the number of videos in the dset
matching_frame_indexes = torch.where(all_source_frames == frame_index)[0]
assert video_id in all_source_vids[matching_frame_indexes]
sub_index = torch.where(all_source_vids[matching_frame_indexes] == video_id)
frame_index = int(matching_frame_indexes[sub_index])
'''
frame_index = all_source_frames[i].item()
video_id = all_source_vids[i].item()
# Now get the image_id that matches the frame_index and video_id.
sorted_img_ids_for_one_video = acts_dset.index.vidid_to_gids[int(video_id)]
image_id = sorted_img_ids_for_one_video[frame_index]
# Sanity check: this image_id corresponds to the frame_index and video_id
assert acts_dset.index.imgs[image_id]['frame_index'] == frame_index
assert acts_dset.index.imgs[image_id]['video_id'] == video_id
try:
assert acts_dset.index.imgs[image_id]['frame_index'] == frame_index
assert acts_dset.index.imgs[image_id]['video_id'] == video_id
except:
import ipdb; ipdb.set_trace()

ann = {
"image_id": image_id,
"category_id": all_preds[i],
"score": all_probs[i][all_preds[i]],
"prob": all_probs[i],
"category_id": all_preds[i].item(),
"score": all_probs[i][all_preds[i]].item(),
"prob": all_probs[i].numpy().tolist(),
}
acts_dset.add_annotation(**ann)
print(f"Dumping activities file to {acts_dset.fpath}")
acts_dset.dump(acts_dset.fpath, newlines=True)


Expand Down

0 comments on commit 65dc076

Please sign in to comment.