diff --git a/sscma/models/base/general.py b/sscma/models/base/general.py index e7f3527c..1ffef482 100644 --- a/sscma/models/base/general.py +++ b/sscma/models/base/general.py @@ -78,6 +78,7 @@ def __init__( conv_layer: Optional[Callable[..., nn.Module]] or Dict or AnyStr = None, dilation: int = 1, inplace: bool = True, + use_depthwise: bool = False, ) -> None: super().__init__() if padding is None: @@ -86,17 +87,41 @@ def __init__( conv_layer = nn.Conv2d else: conv_layer = get_conv(conv_layer) - conv = conv_layer( - in_channels, - out_channels, - kernel_size, - stride, - padding, - dilation=dilation, - groups=groups, - bias=norm_layer is None if bias is None else bias, - ) - self.add_module('conv', conv) + if use_depthwise: + dw_conv = conv_layer( + in_channels, + in_channels, + kernel_size, + stride, + padding, + dilation=dilation, + groups=in_channels, + bias=norm_layer is None if bias is None else bias, + ) + pw_conv = conv_layer( + in_channels, + out_channels, + 1, + stride, + padding, + dilation=dilation, + groups=1, + bias=norm_layer is None if bias is None else bias, + ) + self.add_module('dw_conv', dw_conv) + self.add_module('pw_conv', pw_conv) + else: + conv = conv_layer( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation=dilation, + groups=groups, + bias=norm_layer is None if bias is None else bias, + ) + self.add_module('conv', conv) if norm_layer is not None: norm_layer = get_norm(norm_layer) self.add_module('norm', norm_layer(out_channels)) diff --git a/sscma/models/detectors/base.py b/sscma/models/detectors/base.py index 5bd75ade..8002818e 100644 --- a/sscma/models/detectors/base.py +++ b/sscma/models/detectors/base.py @@ -1,25 +1,25 @@ # Copyright (c) Seeed Technology Co.,Ltd. All rights reserved. -from typing import Dict, Optional, Union, Tuple, List -from abc import ABCMeta, abstractmethod import copy +from abc import ABCMeta, abstractmethod +from typing import Dict, List, Optional, Tuple, Union +import torch from mmdet.models.detectors import BaseDetector, SemiBaseDetector -from mmdet.structures import DetDataSample, OptSampleList, SampleList -from mmdet.utils import OptConfigType, OptMultiConfig, ConfigType, InstanceList from mmdet.models.utils import rename_loss_dict, reweight_loss_dict -from mmdet.structures import SampleList +from mmdet.structures import DetDataSample, OptSampleList, SampleList +from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig from mmengine.model import BaseModel -import torch -from torch import Tensor from mmengine.optim import OptimWrapper -from ..utils import samplelist_boxtype2tensor +from torch import Tensor -from sscma.registry import MODELS from sscma.models.semi import BasePseudoLabelCreator +from sscma.registry import MODELS +from ..utils import samplelist_boxtype2tensor ForwardResults = Union[Dict[str, torch.Tensor], List[DetDataSample], Tuple[torch.Tensor], torch.Tensor] + @MODELS.register_module() class BaseSsod(SemiBaseDetector): teacher: BaseDetector diff --git a/sscma/models/layers/csp_layer.py b/sscma/models/layers/csp_layer.py index 5951c951..20f6aa34 100644 --- a/sscma/models/layers/csp_layer.py +++ b/sscma/models/layers/csp_layer.py @@ -2,81 +2,60 @@ # Copyright (c) OpenMMLab. import torch import torch.nn as nn -from sscma.models.base import ConvModule, DepthwiseSeparableConvModule +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig from mmengine.model import BaseModule from torch import Tensor -from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from sscma.models.base import ConvNormActivation - -class ChannelAttention(BaseModule): - def __init__(self, channels: int, init_cfg: OptMultiConfig = None) -> None: - super().__init__(init_cfg=init_cfg) - self.global_avgpool = nn.AdaptiveAvgPool2d(1) - self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True) - self.act = nn.Hardsigmoid(inplace=True) - - def forward(self, x: Tensor) -> Tensor: - """Forward function for ChannelAttention.""" - with torch.cuda.amp.autocast(enabled=False): - out = self.global_avgpool(x) - out = self.fc(out) - out = self.act(out) - return x * out +from .attention import ChannelAttention class CSPLayer(BaseModule): - def __init__(self, - in_channels: int, - out_channels: int, - expand_ratio: float = 0.5, - num_blocks: int = 1, - add_identity: bool = True, - use_depthwise: bool = False, - use_cspnext_block: bool = False, - channel_attention: bool = False, - conv_cfg: OptConfigType = None, - norm_cfg: ConfigType = dict( - type='BN', momentum=0.03, eps=0.001), - act_cfg: ConfigType = dict(type='Swish'), - init_cfg: OptMultiConfig = None) -> None: + def __init__( + self, + in_channels: int, + out_channels: int, + expand_ratio: float = 0.5, + num_blocks: int = 1, + add_identity: bool = True, + use_depthwise: bool = False, + use_cspnext_block: bool = False, + channel_attention: bool = False, + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001), + act_cfg: ConfigType = dict(type='Swish'), + init_cfg: OptMultiConfig = None, + ) -> None: super().__init__(init_cfg=init_cfg) block = CSPNeXtBlock if use_cspnext_block else DarknetBottleneck mid_channels = int(out_channels * expand_ratio) self.channel_attention = channel_attention - self.main_conv = ConvModule( - in_channels, - mid_channels, - 1, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - act_cfg=act_cfg) - self.short_conv = ConvModule( - in_channels, - mid_channels, - 1, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - act_cfg=act_cfg) - self.final_conv = ConvModule( - 2 * mid_channels, - out_channels, - 1, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - act_cfg=act_cfg) - - self.blocks = nn.Sequential(*[ - block( - mid_channels, - mid_channels, - 1.0, - add_identity, - use_depthwise, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - act_cfg=act_cfg) for _ in range(num_blocks) - ]) + self.main_conv = ConvNormActivation( + in_channels, mid_channels, 1, conv_layer=conv_cfg, norm_layer=norm_cfg, activation_layer=act_cfg + ) + self.short_conv = ConvNormActivation( + in_channels, mid_channels, 1, conv_layer=conv_cfg, norm_layer=norm_cfg, activation_layer=act_cfg + ) + self.final_conv = ConvNormActivation( + 2 * mid_channels, out_channels, 1, conv_layer=conv_cfg, norm_layer=norm_cfg, activation_layer=act_cfg + ) + + self.blocks = nn.Sequential( + *[ + block( + mid_channels, + mid_channels, + 1.0, + add_identity, + use_depthwise, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + ) + for _ in range(num_blocks) + ] + ) if channel_attention: self.attention = ChannelAttention(2 * mid_channels) @@ -92,40 +71,44 @@ def forward(self, x: Tensor) -> Tensor: if self.channel_attention: x_final = self.attention(x_final) return self.final_conv(x_final) - + + class DarknetBottleneck(BaseModule): - def __init__(self, - in_channels: int, - out_channels: int, - expansion: float = 0.5, - add_identity: bool = True, - use_depthwise: bool = False, - conv_cfg: OptConfigType = None, - norm_cfg: ConfigType = dict( - type='BN', momentum=0.03, eps=0.001), - act_cfg: ConfigType = dict(type='Swish'), - init_cfg: OptMultiConfig = None) -> None: + def __init__( + self, + in_channels: int, + out_channels: int, + expansion: float = 0.5, + add_identity: bool = True, + use_depthwise: bool = False, + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001), + act_cfg: ConfigType = dict(type='Swish'), + init_cfg: OptMultiConfig = None, + ) -> None: super().__init__(init_cfg=init_cfg) hidden_channels = int(out_channels * expansion) - conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule - self.conv1 = ConvModule( + self.conv1 = ConvNormActivation( in_channels, hidden_channels, 1, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - act_cfg=act_cfg) - self.conv2 = conv( + conv_layer=conv_cfg, + norm_layer=norm_cfg, + activation_layer=act_cfg, + use_depthwise=False, + ) + self.conv2 = ConvNormActivation( hidden_channels, out_channels, 3, stride=1, padding=1, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - act_cfg=act_cfg) - self.add_identity = \ - add_identity and in_channels == out_channels + conv_layer=conv_cfg, + norm_layer=norm_cfg, + activation_layer=act_cfg, + use_depthwise=use_depthwise, + ) + self.add_identity = add_identity and in_channels == out_channels def forward(self, x: Tensor) -> Tensor: """Forward function.""" @@ -140,40 +123,43 @@ def forward(self, x: Tensor) -> Tensor: class CSPNeXtBlock(BaseModule): - def __init__(self, - in_channels: int, - out_channels: int, - expansion: float = 0.5, - add_identity: bool = True, - use_depthwise: bool = False, - kernel_size: int = 5, - conv_cfg: OptConfigType = None, - norm_cfg: ConfigType = dict( - type='BN', momentum=0.03, eps=0.001), - act_cfg: ConfigType = dict(type='SiLU'), - init_cfg: OptMultiConfig = None) -> None: + def __init__( + self, + in_channels: int, + out_channels: int, + expansion: float = 0.5, + add_identity: bool = True, + use_depthwise: bool = False, + kernel_size: int = 5, + conv_cfg: OptConfigType = None, + norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001), + act_cfg: ConfigType = dict(type='SiLU'), + init_cfg: OptMultiConfig = None, + ) -> None: super().__init__(init_cfg=init_cfg) hidden_channels = int(out_channels * expansion) - conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule - self.conv1 = conv( + self.conv1 = ConvNormActivation( in_channels, hidden_channels, 3, stride=1, padding=1, - norm_cfg=norm_cfg, - act_cfg=act_cfg) - self.conv2 = DepthwiseSeparableConvModule( + norm_layer=norm_cfg, + activation_layer=act_cfg, + use_depthwise=use_depthwise, + ) + self.conv2 = ConvNormActivation( hidden_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - act_cfg=act_cfg) - self.add_identity = \ - add_identity and in_channels == out_channels + conv_layer=conv_cfg, + norm_layer=norm_cfg, + activation_layer=act_cfg, + use_depthwise=True, + ) + self.add_identity = add_identity and in_channels == out_channels def forward(self, x: Tensor) -> Tensor: """Forward function.""" diff --git a/sscma/models/layers/sppf.py b/sscma/models/layers/sppf.py index 248a5d2f..fdaeb8c7 100644 --- a/sscma/models/layers/sppf.py +++ b/sscma/models/layers/sppf.py @@ -1,63 +1,52 @@ # Copyright (c) Seeed Technology Co.,Ltd. # Copyright (c) OpenMMLab. -import torch -from torch import Tensor from typing import Sequence, Union import torch import torch.nn as nn -from ..base import ConvModule from mmdet.utils import ConfigType, OptMultiConfig from mmengine.model import BaseModule from torch import Tensor +from ..base import ConvModule + class SPPFBottleneck(BaseModule): - def __init__(self, - in_channels: int, - out_channels: int, - kernel_sizes: Union[int, Sequence[int]] = 5, - use_conv_first: bool = True, - mid_channels_scale: float = 0.5, - conv_cfg: ConfigType = None, - norm_cfg: ConfigType = dict( - type='BN', momentum=0.03, eps=0.001), - act_cfg: ConfigType = dict(type='SiLU', inplace=True), - init_cfg: OptMultiConfig = None): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_sizes: Union[int, Sequence[int]] = 5, + use_conv_first: bool = True, + mid_channels_scale: float = 0.5, + conv_cfg: ConfigType = None, + norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001), + act_cfg: ConfigType = dict(type='SiLU', inplace=True), + init_cfg: OptMultiConfig = None, + ): super().__init__(init_cfg) if use_conv_first: mid_channels = int(in_channels * mid_channels_scale) self.conv1 = ConvModule( - in_channels, - mid_channels, - 1, - stride=1, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - act_cfg=act_cfg) + in_channels, mid_channels, 1, stride=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg + ) else: mid_channels = in_channels self.conv1 = None self.kernel_sizes = kernel_sizes if isinstance(kernel_sizes, int): - self.poolings = nn.MaxPool2d( - kernel_size=kernel_sizes, stride=1, padding=kernel_sizes // 2) + self.poolings = nn.MaxPool2d(kernel_size=kernel_sizes, stride=1, padding=kernel_sizes // 2) conv2_in_channels = mid_channels * 4 else: - self.poolings = nn.ModuleList([ - nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) - for ks in kernel_sizes - ]) + self.poolings = nn.ModuleList( + [nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes] + ) conv2_in_channels = mid_channels * (len(kernel_sizes) + 1) self.conv2 = ConvModule( - conv2_in_channels, - out_channels, - 1, - conv_cfg=conv_cfg, - norm_cfg=norm_cfg, - act_cfg=act_cfg) + conv2_in_channels, out_channels, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg + ) def forward(self, x: Tensor) -> Tensor: if self.conv1: @@ -67,7 +56,6 @@ def forward(self, x: Tensor) -> Tensor: y2 = self.poolings(y1) x = torch.cat([x, y1, y2, self.poolings(y2)], dim=1) else: - x = torch.cat( - [x] + [pooling(x) for pooling in self.poolings], dim=1) + x = torch.cat([x] + [pooling(x) for pooling in self.poolings], dim=1) x = self.conv2(x) - return x \ No newline at end of file + return x diff --git a/sscma/models/losses/__init__.py b/sscma/models/losses/__init__.py index 8b817f96..879c3956 100644 --- a/sscma/models/losses/__init__.py +++ b/sscma/models/losses/__init__.py @@ -1,10 +1,10 @@ # Copyright (c) Seeed Technology Co.,Ltd. All rights reserved. from .bce_withlogits_loss import BCEWithLogitsLoss from .classfication_loss import LabelSmoothCrossEntropyLoss +from .domain_focal_loss import DomainFocalLoss, DomainLoss, TargetLoss +from .IouLoss import IoULoss from .nll_loss import NLLLoss from .pfld_loss import PFLDLoss -from .domain_focal_loss import DomainFocalLoss, TargetLoss, DomainLoss -from .IouLoss import * __all__ = [ 'LabelSmoothCrossEntropyLoss', @@ -14,4 +14,5 @@ 'DomainFocalLoss', 'TargetLoss', 'DomainLoss', + 'IoULoss', ] diff --git a/sscma/models/utils/misc.py b/sscma/models/utils/misc.py index 4e51624f..10a2fdd1 100644 --- a/sscma/models/utils/misc.py +++ b/sscma/models/utils/misc.py @@ -1,19 +1,17 @@ # Copyright (c) Seeed Technology Co.,Ltd. # Copyright (c) OpenMMLab. import math -import numpy as np import os import urllib -from typing import Union, List +from typing import List, Union + import numpy as np import torch -from mmengine.utils import scandir from mmdet.structures import SampleList from mmdet.structures.bbox import BaseBoxes +from mmengine.utils import scandir - -IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', - '.tiff', '.webp') +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') def samplelist_boxtype2tensor(batch_data_samples: SampleList) -> SampleList: @@ -41,6 +39,7 @@ def make_round(x: float, deepen_factor: float = 1.0) -> int: """Make sure that x*deepen_factor becomes an integer not less than 1.""" return max(round(x * deepen_factor), 1) if x > 1 else x + def auto_arrange_images(image_list: list, image_column: int = 2) -> np.ndarray: """Auto arrange image to image_column x N row. @@ -57,9 +56,7 @@ def auto_arrange_images(image_list: list, image_column: int = 2) -> np.ndarray: else: # arrange image according to image_column image_row = round(img_count / image_column) - fill_img_list = [np.ones(image_list[0].shape, dtype=np.uint8) * 255 - ] * ( - image_row * image_column - img_count) + fill_img_list = [np.ones(image_list[0].shape, dtype=np.uint8) * 255] * (image_row * image_column - img_count) image_list.extend(fill_img_list) merge_imgs_col = [] for i in range(image_row): @@ -95,8 +92,7 @@ def get_file_list(source_root: str) -> Union[List, dict]: source_file_path_list.append(os.path.join(source_root, file)) elif is_url: # when input source is url - filename = os.path.basename( - urllib.parse.unquote(source_root).split('?')[0]) + filename = os.path.basename(urllib.parse.unquote(source_root).split('?')[0]) file_save_path = os.path.join(os.getcwd(), filename) print(f'Downloading source file to {file_save_path}') torch.hub.download_url_to_file(source_root, file_save_path)