From f3c36340cc573d5aa7c7f68dffda7e9d4046d212 Mon Sep 17 00:00:00 2001 From: "marjan.asgari" Date: Wed, 4 Sep 2024 11:36:23 +0000 Subject: [PATCH 1/3] return_mask_name --- geo_inference/geo_inference.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/geo_inference/geo_inference.py b/geo_inference/geo_inference.py index e798248..2568aab 100644 --- a/geo_inference/geo_inference.py +++ b/geo_inference/geo_inference.py @@ -114,7 +114,7 @@ async def run_async(): # Start the periodic garbage collection task self.gc_task = asyncio.create_task(self.constant_gc(5)) # Calls gc.collect() every 5 seconds # Run the main computation asynchronously - await self.async_run_inference( + mask_layer_name = await self.async_run_inference( inference_input=inference_input, bands_requested=bands_requested, patch_size=patch_size, @@ -127,8 +127,10 @@ async def run_async(): await self.gc_task except asyncio.CancelledError: logger.info("The End of Inference") + + return mask_layer_name - asyncio.run(run_async()) + mask_layer_name = asyncio.run(run_async()) async def async_run_inference(self, inference_input: Union[Path, str], @@ -189,6 +191,8 @@ async def async_run_inference(self, yolo_csv_path = self.work_dir.joinpath(prefix_base_name + "_yolo.csv") coco_json_path = self.work_dir.joinpath(prefix_base_name + "_coco.json") stride_patch_size = int(patch_size / 2) + + """ Processing starts""" start_time = time.time() try: @@ -332,6 +336,7 @@ async def async_run_inference(self, ) ) torch.cuda.empty_cache() + return mask_path.name except Exception as e: print(f"Processing on the Dask cluster failed due to: {e}") From 2d440817f0a277a167550b4ad7577ecec9fe1d0e Mon Sep 17 00:00:00 2001 From: "marjan.asgari" Date: Wed, 4 Sep 2024 11:57:08 +0000 Subject: [PATCH 2/3] return_mask_name --- geo_inference/geo_inference.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/geo_inference/geo_inference.py b/geo_inference/geo_inference.py index 2568aab..064d48a 100644 --- a/geo_inference/geo_inference.py +++ b/geo_inference/geo_inference.py @@ -1,6 +1,7 @@ import os import gc import re +import sys import time import torch import pystac @@ -17,7 +18,7 @@ from omegaconf import ListConfig from rasterio.windows import from_bounds from typing import Union, Sequence, List -from dask.diagnostics import ProgressBar +ffrom dask.diagnostics import ProgressBar from multiprocessing.pool import ThreadPool @@ -40,6 +41,7 @@ logger = logging.getLogger(__name__) + class GeoInference: """ @@ -114,7 +116,7 @@ async def run_async(): # Start the periodic garbage collection task self.gc_task = asyncio.create_task(self.constant_gc(5)) # Calls gc.collect() every 5 seconds # Run the main computation asynchronously - mask_layer_name = await self.async_run_inference( + self.mask_layer_name = await self.async_run_inference( inference_input=inference_input, bands_requested=bands_requested, patch_size=patch_size, @@ -126,11 +128,11 @@ async def run_async(): try: await self.gc_task except asyncio.CancelledError: - logger.info("The End of Inference") - - return mask_layer_name + pass + + asyncio.run(run_async()) + return self.mask_layer_name - mask_layer_name = asyncio.run(run_async()) async def async_run_inference(self, inference_input: Union[Path, str], @@ -361,14 +363,15 @@ def main() -> None: num_classes=arguments["classes"], prediction_threshold=arguments["prediction_threshold"] ) - geo_inference( + inference_mask_layer_name = geo_inference( inference_input=arguments["image"], bands_requested=arguments["bands_requested"], patch_size=arguments["patch_size"], workers=arguments["workers"], bbox=arguments["bbox"], ) + if __name__ == "__main__": - main() + main() \ No newline at end of file From 5e999ab55fb2c2b7ec3d08c47344da302abbed0e Mon Sep 17 00:00:00 2001 From: "marjan.asgari" Date: Wed, 4 Sep 2024 12:01:26 +0000 Subject: [PATCH 3/3] return_mask_name --- geo_inference/geo_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/geo_inference/geo_inference.py b/geo_inference/geo_inference.py index 064d48a..0af462a 100644 --- a/geo_inference/geo_inference.py +++ b/geo_inference/geo_inference.py @@ -18,7 +18,7 @@ from omegaconf import ListConfig from rasterio.windows import from_bounds from typing import Union, Sequence, List -ffrom dask.diagnostics import ProgressBar +from dask.diagnostics import ProgressBar from multiprocessing.pool import ThreadPool