Skip to content

Commit

Permalink
add a debug option to the TCN node in order to see the inputs it has …
Browse files Browse the repository at this point in the history
…when it decides not to create a classification
  • Loading branch information
josephvanpeltkw committed Oct 25, 2024
1 parent f3cf62a commit ccf8032
Showing 1 changed file with 46 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@
# activity prediction for the "live" image will not occur until object
# detections are predicted for that frame.
PARAM_WINDOW_LEADS_WITH_OBJECTS = "window_leads_with_objects"
# Debug file saved out to the filesystem for understanding the node's
# inputs when it decides not to create an activity classification.
# the format will be csv with a list of the object detections and the pose
PARAM_DEBUG_FILE = "debug_file"


class NoActivityClassification(Exception):
Expand Down Expand Up @@ -148,6 +152,7 @@ def __init__(self):
(PARAM_TOPIC, "medical"),
(PARAM_POSE_REPEAT_RATE, 0),
(PARAM_WINDOW_LEADS_WITH_OBJECTS, False),
(PARAM_DEBUG_FILE, ""),
],
)
self._img_ts_topic = param_values[PARAM_IMG_TS_TOPIC]
Expand All @@ -166,6 +171,12 @@ def __init__(self):

self._window_lead_with_objects = param_values[PARAM_WINDOW_LEADS_WITH_OBJECTS]

self._debug_file = param_values[PARAM_DEBUG_FILE]
# clear the file if it exists (since we are appending to it)
if self._debug_file != "":
with open(self._debug_file, "w") as f:
f.write("")

self.topic = param_values[PARAM_TOPIC]
# Load in TCN classification model and weights
with SimpleTimer("Loading inference module", log.info):
Expand Down Expand Up @@ -655,6 +666,12 @@ def rt_loop(self):
"not yield an activity classification for "
"publishing."
)
if self._debug_file != "":
# save the info for why this window was not processed
repr = window.__repr__()
with open(self._debug_file, "a") as f:
f.write(f"timestamp: {self.get_clock().now().to_msg()}\n")
f.write(f"{repr}\n")

# This window has completed processing - record its leading
# timestamp now.
Expand Down Expand Up @@ -888,5 +905,34 @@ def destroy_node(self):
main = make_default_main(ActivityClassifierTCN, multithreaded_executor=4)


if __name__ == "__main__":
main()
"""
Save results if we have been initialized to do that.
This method does nothing if this node has not been initialized to
collect results.
"""
rc = self._results_collector
if rc is not None:
self.get_logger().info(
f"Writing classification results to: {self._output_kwcoco_path}"
)
self._results_collector.write_file()

def destroy_node(self):
log = self.get_logger()
log.info("Stopping node runtime")
self.rt_stop()
with SimpleTimer("Shutting down runtime thread...", log.info):
self._rt_active.clear() # make RT active flag "False"
self._rt_thread.join()
self._save_results()
super().destroy_node()


main = make_default_main(ActivityClassifierTCN, multithreaded_executor=4)


if __name__ == "__main__":
main()

0 comments on commit ccf8032

Please sign in to comment.