Skip to content

Commit

Permalink
Add missing mean std propagation to mmdeploy exporter
Browse files Browse the repository at this point in the history
Signed-off-by: Kim, Vinnam <[email protected]>
  • Loading branch information
vinnamkim committed Apr 9, 2024
1 parent fd0c9a1 commit 4fda8cb
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/otx/algo/detection/atss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.detection import MMDetCompatibleModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.utils.utils import get_mean_std_from_data_processing

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
Expand Down Expand Up @@ -54,13 +55,17 @@ def _exporter(self) -> OTXModelExporter:
if self.image_size is None:
raise ValueError(self.image_size)

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.atss",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="standard",
pad_value=0,
swap_rgb=False,
Expand Down
5 changes: 5 additions & 0 deletions src/otx/algo/detection/rtmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.detection import MMDetCompatibleModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.utils.utils import get_mean_std_from_data_processing

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
Expand Down Expand Up @@ -54,13 +55,17 @@ def _exporter(self) -> OTXModelExporter:
if self.image_size is None:
raise ValueError(self.image_size)

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.rtmdet",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=False,
Expand Down
5 changes: 5 additions & 0 deletions src/otx/algo/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from otx.core.model.detection import MMDetCompatibleModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.utils.build import build_mm_model, modify_num_classes
from otx.core.utils.utils import get_mean_std_from_data_processing

if TYPE_CHECKING:
import torch
Expand Down Expand Up @@ -269,13 +270,17 @@ def _exporter(self) -> OTXModelExporter:
if self.image_size is None:
raise ValueError(self.image_size)

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.ssd_mobilenetv2",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="standard",
pad_value=0,
swap_rgb=False,
Expand Down
9 changes: 9 additions & 0 deletions src/otx/algo/detection/yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.detection import MMDetCompatibleModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.utils.utils import get_mean_std_from_data_processing

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
Expand Down Expand Up @@ -54,13 +55,17 @@ def _exporter(self) -> OTXModelExporter:
if self.image_size is None:
raise ValueError(self.image_size)

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.yolox",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=True,
Expand Down Expand Up @@ -102,13 +107,17 @@ def _exporter(self) -> OTXModelExporter:
if self.image_size is None:
raise ValueError(self.image_size)

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.yolox_tiny",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=False,
Expand Down
9 changes: 9 additions & 0 deletions src/otx/algo/instance_segmentation/maskrcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.instance_segmentation import MMDetInstanceSegCompatibleModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.utils.utils import get_mean_std_from_data_processing

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
Expand Down Expand Up @@ -54,13 +55,17 @@ def _exporter(self) -> OTXModelExporter:
if self.image_size is None:
raise ValueError(self.image_size)

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.instance_segmentation.mmdeploy.maskrcnn",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="standard", # [TODO](@Eunwoo): need to revert it to fit_to_window after resolving
pad_value=0,
swap_rgb=False,
Expand Down Expand Up @@ -102,13 +107,17 @@ def _exporter(self) -> OTXModelExporter:
if self.image_size is None:
raise ValueError(self.image_size)

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.instance_segmentation.mmdeploy.maskrcnn_swint",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="standard", # [TODO](@Eunwoo): need to revert it to fit_to_window after resolving
pad_value=0,
swap_rgb=False,
Expand Down
5 changes: 5 additions & 0 deletions src/otx/algo/instance_segmentation/rtmdet_inst.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.instance_segmentation import MMDetInstanceSegCompatibleModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.utils.utils import get_mean_std_from_data_processing

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
Expand Down Expand Up @@ -53,13 +54,17 @@ def _exporter(self) -> OTXModelExporter:
if self.image_size is None:
raise ValueError(self.image_size)

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.instance_segmentation.mmdeploy.rtmdet_inst",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=False,
Expand Down

0 comments on commit 4fda8cb

Please sign in to comment.