diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py index b8c661dd9..9c8238571 100644 --- a/tcn_hpl/callbacks/plot_metrics.py +++ b/tcn_hpl/callbacks/plot_metrics.py @@ -337,7 +337,6 @@ def on_test_batch_end( batch: Any, batch_idx: int, dataloader_idx: int, - preds_dset_output_fpath: Path = "./tcn_activity_predictions.kwcoco.json" ) -> None: """Called when the test batch ends.""" # Re-using validation lists since test phase does not collide with @@ -347,7 +346,7 @@ def on_test_batch_end( self._val_all_targets.append(outputs["targets"].cpu()) self._val_all_source_vids.append(outputs["source_vid"].cpu()) self._val_all_source_frames.append(outputs["source_frame"].cpu()) - self._preds_dset_output_fpath = preds_dset_output_fpath + self._preds_dset_output_fpath = self.output_dir / "tcn_activity_predictions.kwcoco.json" def on_test_epoch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" @@ -378,29 +377,26 @@ def on_test_epoch_end( acts_dset.dataset['categories'] = truth_dset.dataset['categories'] # Create numpy lookup tables for i in range(len(all_preds)): - frame_index = all_source_frames[i] - video_id = all_source_vids[i] - ''' - # This list could be as long as the number of videos in the dset - matching_frame_indexes = torch.where(all_source_frames == frame_index)[0] - assert video_id in all_source_vids[matching_frame_indexes] - sub_index = torch.where(all_source_vids[matching_frame_indexes] == video_id) - frame_index = int(matching_frame_indexes[sub_index]) - ''' + frame_index = all_source_frames[i].item() + video_id = all_source_vids[i].item() # Now get the image_id that matches the frame_index and video_id. sorted_img_ids_for_one_video = acts_dset.index.vidid_to_gids[int(video_id)] image_id = sorted_img_ids_for_one_video[frame_index] # Sanity check: this image_id corresponds to the frame_index and video_id - assert acts_dset.index.imgs[image_id]['frame_index'] == frame_index - assert acts_dset.index.imgs[image_id]['video_id'] == video_id + try: + assert acts_dset.index.imgs[image_id]['frame_index'] == frame_index + assert acts_dset.index.imgs[image_id]['video_id'] == video_id + except: + import ipdb; ipdb.set_trace() ann = { "image_id": image_id, - "category_id": all_preds[i], - "score": all_probs[i][all_preds[i]], - "prob": all_probs[i], + "category_id": all_preds[i].item(), + "score": all_probs[i][all_preds[i]].item(), + "prob": all_probs[i].numpy().tolist(), } acts_dset.add_annotation(**ann) + print(f"Dumping activities file to {acts_dset.fpath}") acts_dset.dump(acts_dset.fpath, newlines=True)