diff --git a/geo_inference/geo_inference.py b/geo_inference/geo_inference.py index e798248..0af462a 100644 --- a/geo_inference/geo_inference.py +++ b/geo_inference/geo_inference.py @@ -1,6 +1,7 @@ import os import gc import re +import sys import time import torch import pystac @@ -40,6 +41,7 @@ logger = logging.getLogger(__name__) + class GeoInference: """ @@ -114,7 +116,7 @@ async def run_async(): # Start the periodic garbage collection task self.gc_task = asyncio.create_task(self.constant_gc(5)) # Calls gc.collect() every 5 seconds # Run the main computation asynchronously - await self.async_run_inference( + self.mask_layer_name = await self.async_run_inference( inference_input=inference_input, bands_requested=bands_requested, patch_size=patch_size, @@ -126,9 +128,11 @@ async def run_async(): try: await self.gc_task except asyncio.CancelledError: - logger.info("The End of Inference") + pass asyncio.run(run_async()) + return self.mask_layer_name + async def async_run_inference(self, inference_input: Union[Path, str], @@ -189,6 +193,8 @@ async def async_run_inference(self, yolo_csv_path = self.work_dir.joinpath(prefix_base_name + "_yolo.csv") coco_json_path = self.work_dir.joinpath(prefix_base_name + "_coco.json") stride_patch_size = int(patch_size / 2) + + """ Processing starts""" start_time = time.time() try: @@ -332,6 +338,7 @@ async def async_run_inference(self, ) ) torch.cuda.empty_cache() + return mask_path.name except Exception as e: print(f"Processing on the Dask cluster failed due to: {e}") @@ -356,14 +363,15 @@ def main() -> None: num_classes=arguments["classes"], prediction_threshold=arguments["prediction_threshold"] ) - geo_inference( + inference_mask_layer_name = geo_inference( inference_input=arguments["image"], bands_requested=arguments["bands_requested"], patch_size=arguments["patch_size"], workers=arguments["workers"], bbox=arguments["bbox"], ) + if __name__ == "__main__": - main() + main() \ No newline at end of file