diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9a8b1928c..4b7d3174b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,36 @@ No changes to highlight.
No changes to highlight.
+# v1.0.3
+
+## New Features:
+
+- Add RandomResize2 by `@illian01` in [PR 550](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/550)
+- Add confidence score on detection visualization by `@hglee98` in [PR 552](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/552)
+- Add save_best_only option for saving model by `@hglee98` in [PR 555](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/555), [PR 567](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/567)
+- Add option to select best model saving criterion by `@hglee98` and `@illian01` in [PR 557](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/557), [PR 573](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/573), [PR 574](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/574)
+- Add MultiStepLR scheduler by `@hglee98` in [PR 559](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/559)
+- Add class-wise metric analysis option by `@illian01` in [PR 568](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/568)
+- Add ReLU6 by `@hglee98` in [PR 566](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/566)
+- Add TFLite model evaluation feature by `@hglee98` in [PR 563](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/563)
+- Add YOLO-Fastest-v2 by `@hglee98` in [PR 548](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/548)
+- Add tabulating step for metric standard outputs by `@illian01` in [PR 570](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/570)
+
+## Bug Fixes:
+
+- Fix keyword error of segmentation training by `@illian01` in [PR 551](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/551)
+- Fix typo when saving optimizer state_dict by `@hglee98` in [PR 553](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/553)
+- Fix not initialized save_dtype error by `@hglee98` in [PR 565](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/565)
+- Fix mAP error in case of certain classes object is not in the dataset `@hglee98` in [PR 571](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/571), [PR 572](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/572)
+
+## Breaking Changes:
+
+- Massive refactoring metric modules and add flexible metric selecting option by `@illian01` in [PR 564](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/564)
+
+## Other Changes:
+
+No changes to highlight.
+
# v1.0.2
## New Features:
diff --git a/README.md b/README.md
index 899d245b0..19e462cd5 100644
--- a/README.md
+++ b/README.md
@@ -58,7 +58,7 @@ pip install -e netspresso-trainer
### Set-up with docker
Please clone this repository and refer to [`Dockerfile`](./Dockerfile) and [`docker-compose-example.yml`](./docker-compose-example.yml).
-For docker users, we provide more detailed guide in our [Docs](https://nota-netspresso.github.io/netspresso-trainer).
+For docker users, we provide more detailed guide in our [Docs](https://nota-netspresso.github.io/netspresso-trainer/getting_started/installation/docker_installation).
## Getting started
@@ -142,4 +142,4 @@ tensorboard --logdir ./outputs --port 50001 --bind_all
```
where `PORT` for tensorboard is 50001.
-Note that the default directory of saving result will be `./outputs` directory.
\ No newline at end of file
+Note that the default directory of saving result will be `./outputs` directory.
diff --git a/config/benchmark_examples/classification-imagenet1k-resnet18/logging.yaml b/config/benchmark_examples/classification-imagenet1k-resnet18/logging.yaml
index 072071f7f..2eac8c15b 100644
--- a/config/benchmark_examples/classification-imagenet1k-resnet18/logging.yaml
+++ b/config/benchmark_examples/classification-imagenet1k-resnet18/logging.yaml
@@ -4,8 +4,14 @@ logging:
tensorboard: true
image: true
stdout: true
- save_optimizer_state: true
- sample_input_size: [224, 224] # Used for flops and onnx export
- onnx_export_opset: 13 # Recommend in range [13, 17]
- validation_epoch: &validation_epoch 5
- save_checkpoint_epoch: *validation_epoch # Multiplier of `validation_epoch`.
\ No newline at end of file
+ model_save_options:
+ save_optimizer_state: true
+ save_best_only: false
+ best_model_criterion: loss # metric
+ sample_input_size: [224, 224] # Used for flops and onnx export
+ onnx_export_opset: 13 # Recommend in range [13, 17]
+ validation_epoch: &validation_epoch 5
+ save_checkpoint_epoch: *validation_epoch # Multiplier of `validation_epoch`.
+ metrics:
+ classwise_analysis: False
+ metric_names: ~ # None for default settings
\ No newline at end of file
diff --git a/config/benchmark_examples/classification-imagenet1k-resnet34/logging.yaml b/config/benchmark_examples/classification-imagenet1k-resnet34/logging.yaml
index 072071f7f..2eac8c15b 100644
--- a/config/benchmark_examples/classification-imagenet1k-resnet34/logging.yaml
+++ b/config/benchmark_examples/classification-imagenet1k-resnet34/logging.yaml
@@ -4,8 +4,14 @@ logging:
tensorboard: true
image: true
stdout: true
- save_optimizer_state: true
- sample_input_size: [224, 224] # Used for flops and onnx export
- onnx_export_opset: 13 # Recommend in range [13, 17]
- validation_epoch: &validation_epoch 5
- save_checkpoint_epoch: *validation_epoch # Multiplier of `validation_epoch`.
\ No newline at end of file
+ model_save_options:
+ save_optimizer_state: true
+ save_best_only: false
+ best_model_criterion: loss # metric
+ sample_input_size: [224, 224] # Used for flops and onnx export
+ onnx_export_opset: 13 # Recommend in range [13, 17]
+ validation_epoch: &validation_epoch 5
+ save_checkpoint_epoch: *validation_epoch # Multiplier of `validation_epoch`.
+ metrics:
+ classwise_analysis: False
+ metric_names: ~ # None for default settings
\ No newline at end of file
diff --git a/config/benchmark_examples/classification-imagenet1k-resnet50/logging.yaml b/config/benchmark_examples/classification-imagenet1k-resnet50/logging.yaml
index 8eca4d377..9d20c3303 100644
--- a/config/benchmark_examples/classification-imagenet1k-resnet50/logging.yaml
+++ b/config/benchmark_examples/classification-imagenet1k-resnet50/logging.yaml
@@ -4,8 +4,14 @@ logging:
tensorboard: true
image: true
stdout: true
- save_optimizer_state: true
- sample_input_size: [224, 224] # Used for flops and onnx export
- onnx_export_opset: 13 # Recommend in range [13, 17]
- validation_epoch: &validation_epoch 10
- save_checkpoint_epoch: *validation_epoch # Multiplier of `validation_epoch`.
\ No newline at end of file
+ model_save_options:
+ save_optimizer_state: true
+ save_best_only: false
+ best_model_criterion: loss # metric
+ sample_input_size: [224, 224] # Used for flops and onnx export
+ onnx_export_opset: 13 # Recommend in range [13, 17]
+ validation_epoch: &validation_epoch 10
+ save_checkpoint_epoch: *validation_epoch # Multiplier of `validation_epoch`.
+ metrics:
+ classwise_analysis: False
+ metric_names: ~ # None for default settings
\ No newline at end of file
diff --git a/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/augmentation.yaml b/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/augmentation.yaml
new file mode 100644
index 000000000..2d6584127
--- /dev/null
+++ b/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/augmentation.yaml
@@ -0,0 +1,56 @@
+augmentation:
+ train:
+ -
+ name: mosaicdetection
+ size: [640, 640]
+ mosaic_prob: 1.0
+ affine_scale: [0.5, 1.5]
+ degrees: 10.0
+ translate: 0.1
+ shear: 2.0
+ enable_mixup: True
+ mixup_prob: 1.0
+ mixup_scale: [0.1, 2.0]
+ fill: 0
+ mosaic_off_epoch: 285
+ -
+ name: hsvjitter
+ h_mag: 5
+ s_mag: 30
+ v_mag: 30
+ -
+ name: randomhorizontalflip
+ p: 0.5
+ -
+ name: resize
+ size: 640
+ interpolation: bilinear
+ max_size: ~
+ resize_criteria: long
+ -
+ name: pad
+ size: 640
+ fill: 0
+ -
+ name: randomresize
+ base_size: [640, 640]
+ stride: 32
+ random_range: 5
+ interpolation: bilinear
+ -
+ name: totensor
+ pixel_range: 1.0
+ inference:
+ -
+ name: resize
+ size: 640
+ interpolation: bilinear
+ max_size: ~
+ resize_criteria: long
+ -
+ name: pad
+ size: 640
+ fill: 0
+ -
+ name: totensor
+ pixel_range: 1.0
\ No newline at end of file
diff --git a/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/data.yaml b/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/data.yaml
new file mode 100644
index 000000000..ff5f5c897
--- /dev/null
+++ b/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/data.yaml
@@ -0,0 +1,21 @@
+data:
+ name: coco2017
+ task: detection
+ format: local # local, huggingface
+ path:
+ root: ~/data/coco2017 # dataset root
+ train:
+ image: images/train # directory for training images
+ label: labels/train # directory for training labels
+ valid:
+ image: images/valid # directory for valid images
+ label: labels/valid # directory for valid labels
+ test:
+ image: ~
+ label: ~
+ pattern:
+ image: ~
+ label: ~
+ id_mapping: id_mapping.json
+ # id_mapping: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
+ pallete: ~
diff --git a/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/environment.yaml b/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/environment.yaml
new file mode 100644
index 000000000..58b3dfd91
--- /dev/null
+++ b/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/environment.yaml
@@ -0,0 +1,5 @@
+environment:
+ seed: 1
+ num_workers: 8
+ gpus: 0, 1, 2, 3
+ batch_size: 64 # Batch size per gpu
\ No newline at end of file
diff --git a/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/logging.yaml b/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/logging.yaml
new file mode 100644
index 000000000..245a6f25c
--- /dev/null
+++ b/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/logging.yaml
@@ -0,0 +1,17 @@
+logging:
+ project_id: ~
+ output_dir: ./outputs
+ tensorboard: true
+ image: true
+ stdout: true
+ model_save_options:
+ save_optimizer_state: true
+ save_best_only: false
+ best_model_criterion: loss # metric
+ sample_input_size: [640, 640] # Used for flops and onnx export
+ onnx_export_opset: 13 # Recommend in range [13, 17]
+ validation_epoch: &validation_epoch 10
+ save_checkpoint_epoch: *validation_epoch # Multiplier of `validation_epoch`.
+ metrics:
+ classwise_analysis: False
+ metric_names: ~ # None for default settings
diff --git a/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/model.yaml b/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/model.yaml
new file mode 100644
index 000000000..94bf7afc2
--- /dev/null
+++ b/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/model.yaml
@@ -0,0 +1,41 @@
+model:
+ task: detection
+ name: yolo_fastest_v2
+ checkpoint:
+ use_pretrained: False
+ load_head: False
+ path: ~
+ optimizer_path: ~
+ freeze_backbone: False
+ architecture:
+ full: ~ # auto
+ backbone:
+ name: shufflenetv2
+ params:
+ model_size: 0.5x
+ stage_params:
+ ~
+ neck:
+ name: lightfpn
+ params:
+ out_channels: 72
+ head:
+ name: yolo_fastest_head_v2
+ params:
+ anchors:
+ &anchors
+ - [12.,18., 37.,49., 52.,132.] # P2
+ - [115.,73., 119.,199., 242.,238.] # P3
+ postprocessor:
+ params:
+ # postprocessor - decode
+ score_thresh: 0.01
+ # postprocessor - nms
+ nms_thresh: 0.65
+ anchors: *anchors
+ class_agnostic: False
+ losses:
+ - criterion: yolofastest_loss
+ anchors: *anchors
+ l1_activate_epoch: ~
+ weight: ~
diff --git a/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/training.yaml b/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/training.yaml
new file mode 100644
index 000000000..edf63b9d3
--- /dev/null
+++ b/config/benchmark_examples/detection-coco2017-yolo_fastest_v2/training.yaml
@@ -0,0 +1,21 @@
+training:
+ epochs: 300
+ mixed_precision: True
+ max_norm: ~
+ ema:
+ name: exp_decay
+ decay: 0.9999
+ beta: 2000
+ optimizer:
+ name: sgd
+ lr: 0.001
+ momentum: 0.949
+ weight_decay: 0.0005 # No bias and norm decay
+ nesterov: True
+ no_bias_decay: True
+ no_norm_weight_decay: True
+ overwrite: ~
+ scheduler:
+ name: multi_step
+ milestones: [150, 250]
+ gamma: 0.1
diff --git a/config/benchmark_examples/detection-coco2017-yolox_s/logging.yaml b/config/benchmark_examples/detection-coco2017-yolox_s/logging.yaml
index d1f80afbc..040394158 100644
--- a/config/benchmark_examples/detection-coco2017-yolox_s/logging.yaml
+++ b/config/benchmark_examples/detection-coco2017-yolox_s/logging.yaml
@@ -4,8 +4,14 @@ logging:
tensorboard: true
image: true
stdout: true
- save_optimizer_state: true
- sample_input_size: [640, 640] # Used for flops and onnx export
- onnx_export_opset: 13 # Recommend in range [13, 17]
- validation_epoch: &validation_epoch 10
- save_checkpoint_epoch: *validation_epoch # Multiplier of `validation_epoch`.
\ No newline at end of file
+ model_save_options:
+ save_optimizer_state: true
+ save_best_only: false
+ best_model_criterion: loss # metric
+ sample_input_size: [640, 640] # Used for flops and onnx export
+ onnx_export_opset: 13 # Recommend in range [13, 17]
+ validation_epoch: &validation_epoch 10
+ save_checkpoint_epoch: *validation_epoch # Multiplier of `validation_epoch`.
+ metrics:
+ classwise_analysis: False
+ metric_names: ~ # None for default settings
\ No newline at end of file
diff --git a/config/logging.yaml b/config/logging.yaml
index 7021502b1..be24fc967 100644
--- a/config/logging.yaml
+++ b/config/logging.yaml
@@ -4,8 +4,14 @@ logging:
tensorboard: true
image: true
stdout: true
- save_optimizer_state: true
- sample_input_size: [512, 512] # Used for flops and onnx export
- onnx_export_opset: 13 # Recommend in range [13, 17]
- validation_epoch: &validation_epoch 10
- save_checkpoint_epoch: *validation_epoch # Multiplier of `validation_epoch`.
\ No newline at end of file
+ model_save_options:
+ save_optimizer_state: true
+ save_best_only: false
+ best_model_criterion: loss # metric
+ sample_input_size: [512, 512] # Used for flops and onnx export
+ onnx_export_opset: 13 # Recommend in range [13, 17]
+ validation_epoch: &validation_epoch 10
+ save_checkpoint_epoch: *validation_epoch # Multiplier of `validation_epoch`.
+ metrics:
+ classwise_analysis: False
+ metric_names: ~ # None for default settings
\ No newline at end of file
diff --git a/config/model/yolo/yolo-fastest-v2-detection.yaml b/config/model/yolo/yolo-fastest-v2-detection.yaml
index 906610620..163876d21 100644
--- a/config/model/yolo/yolo-fastest-v2-detection.yaml
+++ b/config/model/yolo/yolo-fastest-v2-detection.yaml
@@ -1,8 +1,8 @@
model:
task: detection
- name: yolofastest
+ name: yolo_fastest_v2
checkpoint:
- use_pretrained: False
+ use_pretrained: True
load_head: False
path: ~
optimizer_path: ~
@@ -23,17 +23,19 @@ model:
name: yolo_fastest_head_v2
params:
anchors:
- - [12,18, 37,49, 52,132] # P2
- - [115,73, 119,199, 242,238] # P3
+ &anchors
+ - [12.,18., 37.,49., 52.,132.] # P2
+ - [115.,73., 119.,199., 242.,238.] # P3
postprocessor:
params:
# postprocessor - decode
- topk_candidates: 1000
- score_thresh: 0.05
+ score_thresh: 0.01
# postprocessor - nms
- nms_thresh: 0.45
+ nms_thresh: 0.65
+ anchors: *anchors
class_agnostic: False
- # Temporary loss to test the full YOLOFastestV2 model to work right
losses:
- - criterion: retinanet_loss
+ - criterion: yolofastest_loss
+ anchors: *anchors
+ l1_activate_epoch: ~
weight: ~
diff --git a/docs/benchmarks/benchmarks.md b/docs/benchmarks/benchmarks.md
index 13fd7361a..2a34bfbc8 100644
--- a/docs/benchmarks/benchmarks.md
+++ b/docs/benchmarks/benchmarks.md
@@ -45,4 +45,5 @@ If you have a better recipe, please share with us anytime. We appreciate all eff
| COCO-val | [YOLOX-s](https://github.com/Nota-NetsPresso/netspresso-trainer/blob/master/config/model/yolox/yolox-s-detection.yaml) | [download](https://netspresso-trainer-public.s3.ap-northeast-2.amazonaws.com/checkpoint/yolox/yolox_s_coco.safetensors?versionId=QRLqHKqhv8TSYBrmsQ3M8lCR8w7HEZyA) | (640, 640) | 58.56 | 44.10 | 40.63 | 8.97M | 26.81G | Supported | conf_thresh=0.01, nms_thresh=0.65 |
| COCO-val | [YOLOX-m*](https://github.com/Nota-NetsPresso/netspresso-trainer/blob/master/config/model/yolox/yolox-m-detection.yaml) | [download](https://netspresso-trainer-public.s3.ap-northeast-2.amazonaws.com/checkpoint/yolox/yolox_m_coco.safetensors?versionId=xVUySP8xgVTpa6NhCMQpulqmYeRUAhpS) | (640, 640) | 65.00 | 51.34 | 47.04 | 25.33M | 73.76G | Supported | [Megvii-BaseDetection/YOLOX](https://github.com/Megvii-BaseDetection/YOLOX?tab=readme-ov-file#benchmark), conf_thresh=0.01, nms_thresh=0.65 |
| COCO-val | [YOLOX-l*](https://github.com/Nota-NetsPresso/netspresso-trainer/blob/master/config/model/yolox/yolox-l-detection.yaml) | [download](https://netspresso-trainer-public.s3.ap-northeast-2.amazonaws.com/checkpoint/yolox/yolox_l_coco.safetensors?versionId=1GR6YNRu.yUfnjq8hKPgARyZ6YejdxMB) | (640, 640) | 68.07 | 55.18 | 50.68 | 54.21M | 155.65G | Supported | [Megvii-BaseDetection/YOLOX](https://github.com/Megvii-BaseDetection/YOLOX?tab=readme-ov-file#benchmark), conf_thresh=0.01, nms_thresh=0.65 |
-| COCO-val | [YOLOX-x*](https://github.com/Nota-NetsPresso/netspresso-trainer/blob/master/config/model/yolox/yolox-x-detection.yaml) | [download](https://netspresso-trainer-public.s3.ap-northeast-2.amazonaws.com/checkpoint/yolox/yolox_x_coco.safetensors?versionId=NWskUEbSGviBWskHQ3P1dQZXnRXOR1WN) | (640, 640) | 69.13 | 56.46 | 51.79 | 99.07M | 281.94G | Supported | [Megvii-BaseDetection/YOLOX](https://github.com/Megvii-BaseDetection/YOLOX?tab=readme-ov-file#benchmark), conf_thresh=0.01, nms_thresh=0.65 |
\ No newline at end of file
+| COCO-val | [YOLOX-x*](https://github.com/Nota-NetsPresso/netspresso-trainer/blob/master/config/model/yolox/yolox-x-detection.yaml) | [download](https://netspresso-trainer-public.s3.ap-northeast-2.amazonaws.com/checkpoint/yolox/yolox_x_coco.safetensors?versionId=NWskUEbSGviBWskHQ3P1dQZXnRXOR1WN) | (640, 640) | 69.13 | 56.46 | 51.79 | 99.07M | 281.94G | Supported | [Megvii-BaseDetection/YOLOX](https://github.com/Megvii-BaseDetection/YOLOX?tab=readme-ov-file#benchmark), conf_thresh=0.01, nms_thresh=0.65 |
+| COCO-val | [YOLO-Fastest-v2](https://github.com/Nota-NetsPresso/netspresso-trainer/blob/master/config/model/yolo/yolo-fastest-v2-detection.yaml) | [download](https://netspresso-trainer-public.s3.ap-northeast-2.amazonaws.com/checkpoint/yolofastest/yolo_fastest_v2_coco.safetensors?versionId=CGhNjiZygGVjtHm0M586DzQ6.2FqWvl1) | (640, 640) | 25.03 | 11.60 | 12.78 | 0.25M | 0.74G | Supported | conf_thresh=0.01, nms_thresh=0.65 |
diff --git a/docs/components/augmentation/transforms.md b/docs/components/augmentation/transforms.md
index b561a9f6d..02fe63577 100644
--- a/docs/components/augmentation/transforms.md
+++ b/docs/components/augmentation/transforms.md
@@ -353,6 +353,34 @@ Since applying random resize to every image arises the difficulty of a batch han
+### RandomResize2
+
+RandomResize2 transforms the input image by resizing it based on a randomly selected scaling factor within a specified range. Note that RandomResize2 preserves the original aspect ratio, result image size might be largely different with `base_size`.
+
+Applying random resize to every image arises the difficulty of a batch handling, and RandomResize2 does not support per-batch target size handling function. We recommend to use RandomResize2 with other trasnform method.
+
+| Field
| Description |
+|---|---|
+| `name` | (str) Name must be "randomresize" to use `RandomResize` transform. |
+| `base_size` | (list) The base (height, width) of the target image, which is the initial size before applying randomness. |
+| `random_range` | (int) A range [min_factor, max_factor] within which the random scaling factor is selected. The input image will be resized by a combination of base_size and random factors, while maintaining the aspect ratio. |
+| `interpolation` | (str) Desired interpolation type. Supporting interpolations are 'nearest', 'bilinear' and 'bicubic'. |
+
+
+ RandomResize
+
+ ```yaml
+ augmentation:
+ train:
+ -
+ name: randomresize
+ base_size: [512, 2048]
+ random_range: [0.5, 1.5]
+ interpolation: 'bilinear'
+ ```
+
+
+
### RandomResizedCrop
Crop a random portion of image with different aspect of ratio in width and height, and resize it to a given size. This augmentation follows the [RandomResizedCrop](https://pytorch.org/vision/0.15/generated/torchvision.transforms.RandomResizedCrop.html#torchvision.transforms.RandomResizedCrop) in torchvision library.
diff --git a/docs/components/logging.md b/docs/components/logging.md
index d47ff5155..dad6b9406 100644
--- a/docs/components/logging.md
+++ b/docs/components/logging.md
@@ -9,11 +9,17 @@ logging:
tensorboard: true
image: true
stdout: true
- save_optimizer_state: true
- sample_input_size: [512, 512] # Used for flops and onnx export
- onnx_export_opset: 13 # Recommend in range [13, 17]
- validation_epoch: &validation_epoch 5
- save_checkpoint_epoch: *validation_epoch # Multiplier of `validation_epoch`.
+ model_save_options:
+ save_optimizer_state: true
+ save_best_only: false
+ best_model_criterion: loss # metric
+ sample_input_size: [512, 512] # Used for flops and onnx export
+ onnx_export_opset: 13 # Recommend in range [13, 17]
+ validation_epoch: &validation_epoch 10
+ save_checkpoint_epoch: *validation_epoch # Multiplier of `validation_epoch`.
+ metrics:
+ classwise_analysis: False
+ metric_names: ~ # None for default settings
```
## Tensorboard
@@ -36,8 +42,12 @@ The port number `50001` is same with the port forwarded in example docker setup.
| `logging.tensorboard` | (bool) Whether to use the tensorboard. |
| `logging.image` | (bool) Whether to save the validation results. It is ignored if the task is `classification`. |
| `logging.stdout` | (bool) Whether to log the standard output. |
-| `logging.save_optimizer_state` | (bool) Whether to save optimizer state with model checkpoint to resume training. |
-| `logging.sample_input_size` | (list[int]) The size of the sample input used for calculating FLOPs and exporting the model to ONNX format. |
-| `logging.onnx_export_opset` | (int) The ONNX opset version to be used for model export |
-| `logging.validation_epoch` | (int) Validation frequency in total training process. |
-| `logging.save_checkpoint_epoch` | (int) Checkpoint saving frequency in total training process. |
\ No newline at end of file
+| `logging.model_save_options.save_optimizer_state` | (bool) Whether to save optimizer state with model checkpoint to resume training. |
+| `logging.model_save_options.save_best_only` | (bool) Whether to only the best model. |
+| `logging.model_save_options.best_model_criterion` | (str) Criterion to determine which checkpoint is considered the best. One of 'loss' or 'metric'. |
+| `logging.model_save_options.sample_input_size` | (list[int]) The size of the sample input used for calculating FLOPs and exporting the model to ONNX format. |
+| `logging.model_save_options.onnx_export_opset` | (int) The ONNX opset version to be used for model export |
+| `logging.model_save_options.validation_epoch` | (int) Validation frequency in total training process. |
+| `logging.model_save_options.save_checkpoint_epoch` | (int) Checkpoint saving frequency in total training process. |
+| `logging.metrics.classwise_analysis` | (bool) Whether to perform class-wise analysis of metrics during validation. |
+| `logging.metrics.metric_names` | (list(str), optional) List of metric names to be logged. If not specified, default metrics for the task will be used. |
\ No newline at end of file
diff --git a/docs/components/training/schedulers.md b/docs/components/training/schedulers.md
index ff177b170..843307603 100644
--- a/docs/components/training/schedulers.md
+++ b/docs/components/training/schedulers.md
@@ -108,6 +108,28 @@ training:
```
+### Multi step
+
+This scheduler follows the [MultiStepLR](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR.html) in torch library.
+
+| Field
| Description |
+|---|---|
+| `name` | (str) Name must be "multi_step" to use `MultiStepLR` scheduler. |
+| `milestones` | (list) List of epoch indices. Must be increasing. |
+| `gamma` | (float) Multiplicative factor of learning rate decay. |
+
+
+ Step example
+```yaml
+training:
+ scheduler:
+ name: multi_step
+ milestones: [30, 80]
+ gamma: 0.1
+```
+
+
+
## Gradio demo for simulating the learning rate scheduler
In many training feature repositories, it is recommended to perform the entire training pipeline and check the log to see how the learning rate scheduler works.
diff --git a/src/netspresso_trainer/VERSION b/src/netspresso_trainer/VERSION
index e6d5cb833..e4c0d46e5 100644
--- a/src/netspresso_trainer/VERSION
+++ b/src/netspresso_trainer/VERSION
@@ -1 +1 @@
-1.0.2
\ No newline at end of file
+1.0.3
\ No newline at end of file
diff --git a/src/netspresso_trainer/dataloaders/augmentation/custom/image_proc.py b/src/netspresso_trainer/dataloaders/augmentation/custom/image_proc.py
index 7c6a452d1..3e230bc07 100644
--- a/src/netspresso_trainer/dataloaders/augmentation/custom/image_proc.py
+++ b/src/netspresso_trainer/dataloaders/augmentation/custom/image_proc.py
@@ -793,6 +793,40 @@ def __repr__(self):
)
+class RandomResize2:
+ visualize = True
+
+ def __init__(
+ self,
+ base_size: List[int],
+ random_range: Union[int, float],
+ interpolation: str,
+ ):
+ self.base_size = base_size
+ self.random_range = random_range
+ # Temporarily assign 0 size to Resize class
+ self.resize = Resize(size=self.base_size, interpolation=interpolation, max_size=None, resize_criteria=None)
+
+ def random_set(self, image):
+ target_h, target_w = self.base_size
+ resize_factor = random.random() * (self.random_range[1] - self.random_range[0]) + self.random_range[0]
+ max_size = (int(target_w * resize_factor), int(target_h * resize_factor))
+ w, h = image.size
+ resize_factor = min(max_size[0] / w, max_size[1] / h)
+ size = [int(h * resize_factor), int(w * resize_factor)]
+ self.resize.size = size
+
+ def __call__(self, image, label=None, mask=None, bbox=None, keypoint=None, dataset=None):
+ self.random_set(image)
+ image, label, mask, bbox, keypoint = self.resize(image, label, mask, bbox, keypoint, dataset)
+ return image, label, mask, bbox, keypoint
+
+ def __repr__(self):
+ return self.__class__.__name__ + "(max_size={0}, random_range={1})".format(
+ self.max_size, self.random_range
+ )
+
+
class PoseTopDownAffine:
"""
Based on the mmpose implementation.
diff --git a/src/netspresso_trainer/dataloaders/augmentation/registry.py b/src/netspresso_trainer/dataloaders/augmentation/registry.py
index 19e80f75e..60e8fe551 100644
--- a/src/netspresso_trainer/dataloaders/augmentation/registry.py
+++ b/src/netspresso_trainer/dataloaders/augmentation/registry.py
@@ -29,6 +29,7 @@
RandomHorizontalFlip,
RandomIoUCrop,
RandomResize,
+ RandomResize2,
RandomResizedCrop,
RandomVerticalFlip,
RandomZoomOut,
@@ -47,6 +48,7 @@
'randomresizedcrop': RandomResizedCrop,
'randomhorizontalflip': RandomHorizontalFlip,
'randomresize': RandomResize,
+ 'randomresize2': RandomResize2,
'randomverticalflip': RandomVerticalFlip,
'randomerasing': RandomErasing,
'randomioucrop': RandomIoUCrop,
diff --git a/src/netspresso_trainer/dataloaders/builder.py b/src/netspresso_trainer/dataloaders/builder.py
index b75838e4a..3d06c5f29 100644
--- a/src/netspresso_trainer/dataloaders/builder.py
+++ b/src/netspresso_trainer/dataloaders/builder.py
@@ -178,7 +178,6 @@ def build_dataloader(conf, task: str, model_name: str, dataset, phase, profile=F
#TODO: Temporarily set ``cache_data`` as optional since this is experimental
cache_data = conf.environment.cache_data if hasattr(conf.environment, 'cache_data') else False
-
if task == 'classification':
# TODO: ``phase`` should be removed later.
transforms = getattr(conf.augmentation, phase, None)
diff --git a/src/netspresso_trainer/loggers/base.py b/src/netspresso_trainer/loggers/base.py
index 34fe18781..e021a20e6 100644
--- a/src/netspresso_trainer/loggers/base.py
+++ b/src/netspresso_trainer/loggers/base.py
@@ -58,10 +58,11 @@ def __init__(
self.use_tensorboard: bool = self.conf.logging.tensorboard
self.use_imagesaver: bool = self.conf.logging.image
self.use_stdout: bool = self.conf.logging.stdout
+ self._save_best_only: bool = self.conf.logging.model_save_options.save_best_only
self.loggers = []
if self.use_imagesaver:
- self.loggers.append(ImageSaver(model=model, result_dir=self._result_dir))
+ self.loggers.append(ImageSaver(model=model, result_dir=self._result_dir, save_best_only=self._save_best_only))
if self.use_tensorboard:
self.tensorboard_logger = TensorboardLogger(task=task, model=model, result_dir=self._result_dir,
step_per_epoch=step_per_epoch, num_sample_images=num_sample_images)
@@ -100,6 +101,9 @@ def _convert_scalar_as_readable(self, scalar_dict: Dict):
v_new = v.avg
scalar_dict.update({k: v_new})
continue
+ if isinstance(v, dict):
+ pass
+ continue
raise TypeError(f"Unsupported type for {k}!!! Current type: {type(v)}")
return scalar_dict
diff --git a/src/netspresso_trainer/loggers/image.py b/src/netspresso_trainer/loggers/image.py
index 253b0cebd..1c42f7890 100644
--- a/src/netspresso_trainer/loggers/image.py
+++ b/src/netspresso_trainer/loggers/image.py
@@ -22,11 +22,12 @@
class ImageSaver:
- def __init__(self, model, result_dir) -> None:
+ def __init__(self, model, result_dir, save_best_only: Optional[bool]=None) -> None:
super(ImageSaver, self).__init__()
self.model = model
self.save_dir: Path = Path(result_dir) / "result_image"
self.save_dir.mkdir(exist_ok=True)
+ self.save_best_only = save_best_only
def save_ndarray_as_image(self, image_array: np.ndarray, filename: Union[str, Path], dataformats: Literal['HWC', 'CHW'] = 'HWC'):
assert image_array.ndim == 3
@@ -47,6 +48,8 @@ def save_result(self, image_dict: Dict, prefix, epoch):
assert isinstance(v, np.ndarray)
if epoch is None:
self.save_ndarray_as_image(v, f"{prefix_dir}/{idx:03d}_{k}.png", dataformats='HWC')
+ elif self.save_best_only:
+ self.save_ndarray_as_image(v, f"{prefix_dir}/best_{idx:03d}_{k}.png", dataformats='HWC')
else:
self.save_ndarray_as_image(v, f"{prefix_dir}/{epoch:04d}_{idx:03d}_{k}.png", dataformats='HWC')
diff --git a/src/netspresso_trainer/loggers/stdout.py b/src/netspresso_trainer/loggers/stdout.py
index 85c264f4c..854abb564 100644
--- a/src/netspresso_trainer/loggers/stdout.py
+++ b/src/netspresso_trainer/loggers/stdout.py
@@ -18,6 +18,7 @@
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from loguru import logger
+from tabulate import tabulate
LOG_FILENAME = "result.log"
@@ -49,4 +50,20 @@ def __call__(
if losses is not None:
logger.info(f"{prefix} loss: {losses['total']:.7f}")
if metrics is not None:
- logger.info(f"{prefix} metric: {[(name, value) for name, value in metrics.items()]}")
+ metric_std_log = f'{prefix} metric:\n'
+
+ headers = ['Class number', 'Class name', *list(metrics.keys())]
+
+ rows = []
+ if 'classwise' in metrics[headers[-1]]: # If classwise analysis is activated
+ rows += [class_info.split('_', 1) for class_info in list(metrics[headers[-1]]['classwise'].keys())]
+ rows += [['-', 'Mean', ]]
+
+ for _metric_name, score_dict in metrics.items():
+ if 'classwise' in score_dict: # If classwise analysis is activated
+ for cls_num, item in enumerate(score_dict['classwise']):
+ rows[cls_num].append(score_dict['classwise'][item])
+ rows[-1].append(score_dict['mean'])
+
+ metric_std_log += tabulate(rows, headers=headers, tablefmt='grid', numalign='left', stralign='left')
+ logger.info(metric_std_log) # tabulaate is already contained as pandas dependency
diff --git a/src/netspresso_trainer/loggers/tensorboard.py b/src/netspresso_trainer/loggers/tensorboard.py
index 00a964d8c..bf17c5048 100644
--- a/src/netspresso_trainer/loggers/tensorboard.py
+++ b/src/netspresso_trainer/loggers/tensorboard.py
@@ -131,7 +131,8 @@ def __call__(
if losses is not None:
self.log_scalars_with_dict(losses, mode=prefix)
if metrics is not None:
- self.log_scalars_with_dict(metrics, mode=prefix)
+ for k, v in metrics.items(): # Only mean values
+ self.log_scalar(k, v['mean'], mode=prefix)
if isinstance(images, dict): # TODO: array with multiple dicts
self.log_images_with_dict(images, mode=prefix)
diff --git a/src/netspresso_trainer/loggers/visualizer.py b/src/netspresso_trainer/loggers/visualizer.py
index 6ad54e256..984365b0d 100644
--- a/src/netspresso_trainer/loggers/visualizer.py
+++ b/src/netspresso_trainer/loggers/visualizer.py
@@ -107,13 +107,14 @@ def __call__(self, results: List[Tuple[np.ndarray, np.ndarray]], images=None):
y1 = int(bbox_label[1])
x2 = int(bbox_label[2])
y2 = int(bbox_label[3])
+ conf_score = "" if len(bbox_label) <= 4 else " " + str(round(bbox_label[4], 2))
color = self.cmap[class_label].tolist()
image = cv2.rectangle(image, (x1, y1), (x2, y2), color=color, thickness=2)
- text_size, _ = cv2.getTextSize(str(class_name), cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
+ text_size, _ = cv2.getTextSize(f"{class_name}{conf_score}", cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
text_w, text_h = text_size
image = cv2.rectangle(image, (x1, y1-5-text_h), (x1+text_w, y1), color=color, thickness=-1)
- image = cv2.putText(image, str(class_name), (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
+ image = cv2.putText(image, f"{class_name}{conf_score}", (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
return_images.append(image[np.newaxis, ...])
return_images = np.concatenate(return_images, axis=0)
diff --git a/src/netspresso_trainer/losses/detection/__init__.py b/src/netspresso_trainer/losses/detection/__init__.py
index 1358f9b18..9b068dd04 100644
--- a/src/netspresso_trainer/losses/detection/__init__.py
+++ b/src/netspresso_trainer/losses/detection/__init__.py
@@ -16,4 +16,5 @@
from .retinanet import RetinaNetLoss
from .rtdetr import DETRLoss
+from .yolo import YOLOFastestLoss
from .yolox import YOLOXLoss
diff --git a/src/netspresso_trainer/losses/detection/yolo.py b/src/netspresso_trainer/losses/detection/yolo.py
new file mode 100644
index 000000000..127fd0a21
--- /dev/null
+++ b/src/netspresso_trainer/losses/detection/yolo.py
@@ -0,0 +1,149 @@
+# Copyright (C) 2024 Nota Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ----------------------------------------------------------------------------
+
+import math
+from typing import Dict, List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .yolox import IOUloss, YOLOXLoss, xyxy2cxcywh
+
+
+def xyxy2cxcywhn(bboxes, img_size):
+ new_bboxes = bboxes.clone() / img_size
+ new_bboxes[:, 2] = new_bboxes[:, 2] - new_bboxes[:, 0]
+ new_bboxes[:, 3] = new_bboxes[:, 3] - new_bboxes[:, 1]
+ new_bboxes[:, 0] = new_bboxes[:, 0] + new_bboxes[:, 2] * 0.5
+ new_bboxes[:, 1] = new_bboxes[:, 1] + new_bboxes[:, 3] * 0.5
+ return new_bboxes
+
+def bboxes_iou(bboxes_a, bboxes_b, xyxy=False):
+ if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
+ raise IndexError
+
+ if xyxy:
+ tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
+ br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
+ area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
+ area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
+ else:
+ tl = torch.max(
+ (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
+ (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
+ )
+ br = torch.min(
+ (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
+ (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
+ )
+
+ area_a = torch.prod(bboxes_a[:, 2:], 1)
+ area_b = torch.prod(bboxes_b[:, 2:], 1)
+ en = (tl < br).type(tl.type()).prod(dim=2)
+ area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all())
+ return area_i / (area_a[:, None] + area_b - area_i)
+
+
+class YOLOFastestLoss(YOLOXLoss):
+ def __init__(self, anchors, l1_activate_epoch=None, cur_epoch=None, **kwargs) -> None:
+ super().__init__(l1_activate_epoch, cur_epoch, **kwargs)
+ self.iou_loss = IOUloss(reduction="none", loss_type="giou")
+ self.anchors = [torch.tensor(anchor, dtype=torch.float).view(-1, 2) for anchor in anchors]
+ self.num_anchors = self.anchors[0].size(0)
+
+ def use_l1_update(self):
+ return False
+
+ def get_output_and_grid(self, output, k, stride, dtype):
+ grid = self.grids[k]
+ device = output.device
+ batch_size = output.shape[0]
+ hsize, wsize = output.shape[-2:]
+ if grid.shape[2:4] != output.shape[2:4]:
+ yv, xv = torch.meshgrid(torch.arange(hsize), torch.arange(wsize), indexing="ij")
+ grid = torch.stack((xv, yv), 2).repeat(self.num_anchors,1,1,1).view(1, self.num_anchors, hsize, wsize, 2).type(dtype)
+ self.grids[k] = grid
+ anchors = self.anchors[k].view(1, self.num_anchors, 1, 1, 2).to(device)
+ output = output.permute(0, 1, 3, 4, 2)
+ output = torch.cat([
+ (output[..., :2].sigmoid() + grid) * stride,
+ 2. * (torch.tanh(output[..., 2:4]/2 -.549306) + 1.) * anchors,
+ output[..., 4:]
+ ], dim=-1).reshape(
+ batch_size, hsize * wsize * self.num_anchors, -1
+ )
+ return output, grid.view(1, -1, 2)
+
+ def forward(self, out: List, target: Dict) -> torch.Tensor:
+ self.use_l1 = self.use_l1_update()
+
+ out = out['pred']
+ x_shifts = []
+ y_shifts = []
+ expanded_strides = []
+
+ self.grids = [torch.zeros(1)] * len(out)
+ self.num_classes = target['num_classes']
+ img_size = target['img_size']
+
+ target = target['gt']
+
+ out_for_loss = []
+ for k, o in enumerate(out):
+ stride_this_level = img_size // o.size(-1)
+
+ o, grid = self.get_output_and_grid(
+ o, k, stride_this_level, o.type()
+ )
+ x_shifts.append(grid[:, :, 0])
+ y_shifts.append(grid[:, :, 1])
+ expanded_strides.append(
+ torch.zeros(1, grid.shape[1])
+ .fill_(stride_this_level)
+ .type_as(o)
+ )
+ out_for_loss.append(o)
+
+ # YOLOX model learns box cxcywh format directly,
+ # but our detection dataloader gives xyxy format.
+ for i in range(len(target)):
+ target[i]['boxes'] = xyxy2cxcywh(target[i]['boxes'])
+
+ # Ready for l1 loss
+ origin_preds = []
+ for o in out:
+ out_for_l1 = o.view(o.shape[0], self.num_anchors, -1, o.shape[-2], o.shape[-1]).permute(0, 1, 3, 4, 2)
+ reg_output = out_for_l1[..., :4]
+ batch_size = reg_output.shape[0]
+ reg_output = reg_output.reshape(
+ batch_size, -1, 4
+ )
+ origin_preds.append(reg_output.clone())
+
+ total_loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.get_losses(
+ None,
+ x_shifts,
+ y_shifts,
+ expanded_strides,
+ target,
+ torch.cat(out_for_loss, 1),
+ origin_preds,
+ dtype=out[0].dtype,
+ )
+
+ # TODO: return as dict
+ return total_loss
diff --git a/src/netspresso_trainer/losses/registry.py b/src/netspresso_trainer/losses/registry.py
index dc296a317..aede74790 100644
--- a/src/netspresso_trainer/losses/registry.py
+++ b/src/netspresso_trainer/losses/registry.py
@@ -15,7 +15,7 @@
# ----------------------------------------------------------------------------
from .common import CrossEntropyLoss, SigmoidFocalLoss
-from .detection import DETRLoss, RetinaNetLoss, YOLOXLoss
+from .detection import DETRLoss, RetinaNetLoss, YOLOFastestLoss, YOLOXLoss
from .pose_estimation import RTMCCLoss
from .segmentation import PIDNetLoss, SegCrossEntropyLoss
@@ -24,6 +24,7 @@
'seg_cross_entropy': SegCrossEntropyLoss,
'pidnet_loss': PIDNetLoss,
'yolox_loss': YOLOXLoss,
+ 'yolofastest_loss': YOLOFastestLoss,
'retinanet_loss': RetinaNetLoss,
'detr_loss': DETRLoss,
'focal_loss': SigmoidFocalLoss,
diff --git a/src/netspresso_trainer/metrics/base.py b/src/netspresso_trainer/metrics/base.py
index 719562566..2dc3e37be 100644
--- a/src/netspresso_trainer/metrics/base.py
+++ b/src/netspresso_trainer/metrics/base.py
@@ -14,17 +14,61 @@
#
# ----------------------------------------------------------------------------
-from typing import Dict, List
+from typing import Any, Dict, List
+
+import torch
from ..utils.record import MetricMeter
class BaseMetric:
- def __init__(self, metric_names, primary_metric, **kwargs):
- assert primary_metric in metric_names
- self.metric_names = metric_names
- self.primary_metric = primary_metric
- self.metric_meter = {metric_name: MetricMeter(metric_name, ':6.2f') for metric_name in metric_names}
+ def __init__(self, metric_name, num_classes, classwise_analysis, **kwargs):
+ self.metric_name = metric_name
+ self.num_classes = num_classes
+ self.classwise_analysis = classwise_analysis
+ if self.classwise_analysis:
+ self.classwise_metric_meters = [MetricMeter(f'{metric_name}_{i}', ':6.2f') for i in range(num_classes)]
+ self.metric_meter = MetricMeter(metric_name, ':6.2f')
def calibrate(self, pred, target, **kwargs):
- pass
+ raise NotImplementedError
+
+
+class MetricFactory:
+ def __init__(self, task, metrics, metric_adaptor, classwise_analysis) -> None:
+ self.task = task
+ self.metrics = metrics
+ self.metric_adaptor = metric_adaptor
+ self.classwise_analysis = classwise_analysis
+
+ def reset_values(self):
+ for phase in self.metrics:
+ [metric.metric_meter.reset() for metric in self.metrics[phase]]
+
+ def update(self, pred: torch.Tensor, target: torch.Tensor, phase: str, **kwargs: Any) -> None:
+ if len(pred) == 0: # Removed dummy batch has 0 len
+ return
+ kwargs.update(self.metric_adaptor(pred, target))
+ for metric in self.metrics[phase.lower()]:
+ metric.calibrate(pred, target, **kwargs)
+
+ def result(self, phase='train'):
+ ret = {metric.metric_name: {} for metric in self.metrics[phase]} # Initialize with empty dict
+
+ if phase == 'valid' and self.classwise_analysis: # Add classwise results only for valid phase
+ for metric in self.metrics[phase]:
+ classwise_result_dict = {i:classwise_meter.avg for i, classwise_meter in enumerate(metric.classwise_metric_meters)}
+ ret[metric.metric_name] = {'classwise': classwise_result_dict}
+
+ for metric in self.metrics[phase]:
+ ret[metric.metric_name]['mean'] = metric.metric_meter.avg # Add mean score
+
+ return ret
+
+ @property
+ def metric_names(self):
+ return [metric.metric_name for metric in self.metrics[list(self.metrics.keys())[0]]]
+
+ @property
+ def primary_metric(self):
+ return self.metric_names[0]
diff --git a/src/netspresso_trainer/metrics/builder.py b/src/netspresso_trainer/metrics/builder.py
index 8a7793213..f305101f8 100644
--- a/src/netspresso_trainer/metrics/builder.py
+++ b/src/netspresso_trainer/metrics/builder.py
@@ -16,46 +16,32 @@
from typing import Any, Dict
-import torch
+from .base import MetricFactory
+from .registry import METRIC_ADAPTORS, METRIC_LIST, PHASE_LIST, TASK_AVAILABLE_METRICS, TASK_DEFUALT_METRICS
-from .registry import PHASE_LIST, TASK_METRIC
+def build_metrics(task: str, model_conf, metrics_conf, num_classes, **kwargs) -> MetricFactory:
+ metric_names = metrics_conf.metric_names
+ classwise_analysis = metrics_conf.classwise_analysis
-class MetricFactory:
- def __init__(self, task, conf_model, **kwargs) -> None:
- self.task = task
- self.conf_model = conf_model
+ if metric_names is None:
+ metric_names = TASK_DEFUALT_METRICS[task]
+ metric_names = [m.lower() for m in metric_names]
+ assert all(metric in TASK_AVAILABLE_METRICS[task] for metric in metric_names), \
+ f"Available metrics for {task} are {TASK_AVAILABLE_METRICS[task]}"
- assert self.task in TASK_METRIC
- self.metric_cls = TASK_METRIC[self.task]
+ # TODO: This code assumes there is only one loss module. Fix here later.
+ if hasattr(model_conf.losses[0], 'ignore_index'):
+ kwargs['ignore_index'] = model_conf.losses[0].ignore_index
- # TODO: This code assumes there is only one loss module. Fix here later.
- if hasattr(conf_model.losses[0], 'ignore_index'):
- kwargs['ignore_index'] = conf_model.losses[0].ignore_index
- self.metrics = {phase: self.metric_cls(**kwargs) for phase in PHASE_LIST}
+ metrics = {}
+ for phase in PHASE_LIST:
+ if phase == 'valid': # classwise_analysis is only available in valid phase
+ metrics[phase] = [METRIC_LIST[name](num_classes=num_classes, classwise_analysis=classwise_analysis, **kwargs) for name in metric_names]
+ else:
+ metrics[phase] = [METRIC_LIST[name](num_classes=num_classes, classwise_analysis=False, **kwargs) for name in metric_names]
- def reset_values(self):
- for phase in PHASE_LIST:
- [meter.reset() for _, meter in self.metrics[phase].metric_meter.items()]
+ metric_adaptor = METRIC_ADAPTORS[task](metric_names)
- def update(self, pred: torch.Tensor, target: torch.Tensor, phase: str, **kwargs: Any) -> None:
- if len(pred) == 0: # Removed dummy batch has 0 len
- return
- phase = phase.lower()
- self.metrics[phase].calibrate(pred, target)
-
- def result(self, phase='train'):
- return {metric_name: meter.avg for metric_name, meter in self.metrics[phase].metric_meter.items()}
-
- @property
- def metric_names(self):
- return self.metrics[list(self.metrics.keys())[0]].metric_names
-
- @property
- def primary_metric(self):
- return self.metrics[list(self.metrics.keys())[0]].primary_metric
-
-
-def build_metrics(task: str, conf_model, **kwargs) -> MetricFactory:
- metric_handler = MetricFactory(task, conf_model, **kwargs)
+ metric_handler = MetricFactory(task, metrics, metric_adaptor, classwise_analysis)
return metric_handler
diff --git a/src/netspresso_trainer/metrics/classification/__init__.py b/src/netspresso_trainer/metrics/classification/__init__.py
index d4433a76a..1e4ccae3a 100644
--- a/src/netspresso_trainer/metrics/classification/__init__.py
+++ b/src/netspresso_trainer/metrics/classification/__init__.py
@@ -14,4 +14,4 @@
#
# ----------------------------------------------------------------------------
-from .metric import ClassificationMetric
+from .metric import ClassificationMetricAdaptor, Top1Accuracy, Top5Accuracy
diff --git a/src/netspresso_trainer/metrics/classification/metric.py b/src/netspresso_trainer/metrics/classification/metric.py
index 2fa11fbfa..2b91f0f1f 100644
--- a/src/netspresso_trainer/metrics/classification/metric.py
+++ b/src/netspresso_trainer/metrics/classification/metric.py
@@ -21,35 +21,64 @@
from ..base import BaseMetric
-TOPK_MAX = 20
-
@torch.no_grad()
def accuracy_topk(pred, target):
"""Computes the accuracy over the k top predictions for the specified values of k"""
- maxk = pred.shape[-1]
pred = pred.T
class_num = pred.shape[0]
correct = np.equal(pred, np.tile(target, (class_num, 1)))
- return lambda topk: correct[:min(topk, maxk)].reshape(-1).astype('float').sum(0)
+ return correct
+
+
+class ClassificationMetricAdaptor:
+ '''
+ Adapter to process redundant operations for the metrics.
+ '''
+ def __init__(self, metric_names) -> None:
+ self.metric_names = metric_names
+
+ def __call__(self, predictions: List[dict], targets: List[dict]):
+ ret = {}
+ if 'top1_accuracy' in self.metric_names or 'top5_accuracy' in self.metric_names:
+ correct = accuracy_topk(predictions, targets)
+ ret['correct'] = correct
+
+ return ret
+
+class Top1Accuracy(BaseMetric):
+ def __init__(self, num_classes, classwise_analysis, **kwargs):
+ metric_name = 'Acc@1' # Name for logging
+ super().__init__(metric_name=metric_name, num_classes=num_classes, classwise_analysis=classwise_analysis)
+
+ def calibrate(self, pred, target, **kwargs):
+ topk = 1
+ correct = kwargs['correct']
+ correct = np.logical_or.reduce(correct[:topk])
+ if self.classwise_analysis: # Classwise analysis
+ for correct_class, c in zip(target, correct):
+ count = 1 if c else 0
+ self.classwise_metric_meters[correct_class].update(count, n=1)
-class ClassificationMetric(BaseMetric):
- SUPPORT_METRICS: List[str] = ['Acc@1', 'Acc@5']
+ Acc1_correct = correct.sum()
+ self.metric_meter.update(Acc1_correct, n=pred.shape[0])
- def __init__(self, **kwargs):
- # TODO: Select metrics by user
- metric_names = ['Acc@1', 'Acc@5']
- primary_metric = 'Acc@1'
- assert set(metric_names).issubset(ClassificationMetric.SUPPORT_METRICS)
- super().__init__(metric_names=metric_names, primary_metric=primary_metric)
+class Top5Accuracy(BaseMetric):
+ def __init__(self, num_classes, classwise_analysis, **kwargs):
+ metric_name = 'Acc@5' # Name for logging
+ super().__init__(metric_name=metric_name, num_classes=num_classes, classwise_analysis=classwise_analysis)
def calibrate(self, pred, target, **kwargs):
- topk_callable = accuracy_topk(pred, target)
+ topk = 5
+ correct = kwargs['correct']
+ correct = np.logical_or.reduce(correct[:topk])
- Acc1_correct = topk_callable(topk=1)
- Acc5_correct = topk_callable(topk=5)
+ if self.classwise_analysis: # Classwise analysis
+ for correct_class, c in zip(target, correct):
+ count = 1 if c else 0
+ self.classwise_metric_meters[correct_class].update(count, n=1)
- self.metric_meter['Acc@1'].update(Acc1_correct, n=pred.shape[0])
- self.metric_meter['Acc@5'].update(Acc5_correct, n=pred.shape[0])
+ Acc1_correct = correct.sum()
+ self.metric_meter.update(Acc1_correct, n=pred.shape[0])
diff --git a/src/netspresso_trainer/metrics/detection/__init__.py b/src/netspresso_trainer/metrics/detection/__init__.py
index cd02cecd9..4dca99168 100644
--- a/src/netspresso_trainer/metrics/detection/__init__.py
+++ b/src/netspresso_trainer/metrics/detection/__init__.py
@@ -14,4 +14,4 @@
#
# ----------------------------------------------------------------------------
-from .metric import DetectionMetric
+from .metric import DetectionMetricAdaptor, mAP50, mAP50_95, mAP75
diff --git a/src/netspresso_trainer/metrics/detection/metric.py b/src/netspresso_trainer/metrics/detection/metric.py
index eb7eabd68..1ff539ac8 100644
--- a/src/netspresso_trainer/metrics/detection/metric.py
+++ b/src/netspresso_trainer/metrics/detection/metric.py
@@ -132,6 +132,7 @@ def average_precisions_per_class(
prediction_confidence: np.ndarray,
prediction_class_ids: np.ndarray,
true_class_ids: np.ndarray,
+ num_classes,
eps: float = 1e-16,
) -> np.ndarray:
"""
@@ -143,6 +144,7 @@ def average_precisions_per_class(
prediction_confidence (np.ndarray): Objectness value from 0-1.
prediction_class_ids (np.ndarray): Predicted object classes.
true_class_ids (np.ndarray): True object classes.
+ num_classes (int): The number of classes.
eps (float, optional): Small value to prevent division by zero.
Returns:
@@ -153,16 +155,22 @@ def average_precisions_per_class(
prediction_class_ids = prediction_class_ids[sorted_indices]
unique_classes, class_counts = np.unique(true_class_ids, return_counts=True)
- num_classes = unique_classes.shape[0]
- average_precisions = np.zeros((num_classes, matches.shape[1]))
+ average_precisions = np.full((num_classes, matches.shape[1]), np.nan)
for class_idx, class_id in enumerate(unique_classes):
is_class = prediction_class_ids == class_id
total_true = class_counts[class_idx]
- total_prediction = is_class.sum()
+ total_predictions = is_class.sum()
- if total_prediction == 0 or total_true == 0:
+ if total_true == 0:
+ continue
+
+ if total_predictions == 0:
+ for iou_level_idx in range(matches.shape[1]):
+ average_precisions[
+ int(class_id), iou_level_idx
+ ] = 0.0
continue
false_positives = (1 - matches[is_class]).cumsum(0)
@@ -172,7 +180,7 @@ def average_precisions_per_class(
for iou_level_idx in range(matches.shape[1]):
average_precisions[
- class_idx, iou_level_idx
+ int(class_id), iou_level_idx
] = compute_average_precision(
recall[:, iou_level_idx], precision[:, iou_level_idx]
)
@@ -180,19 +188,14 @@ def average_precisions_per_class(
return average_precisions
-class DetectionMetric(BaseMetric):
- SUPPORT_METRICS: List[str] = ['map50', 'map75', 'map50_95']
-
- def __init__(self, **kwargs):
- # TODO: Select metrics by user
- metric_names: List[str] = ['map50', 'map75', 'map50_95']
- primary_metric: str = 'map50'
- assert set(metric_names).issubset(DetectionMetric.SUPPORT_METRICS)
- super().__init__(metric_names=metric_names, primary_metric=primary_metric)
-
- def calibrate(self, predictions, targets, **kwargs):
- result_dict = {k: 0. for k in self.metric_names}
+class DetectionMetricAdaptor:
+ '''
+ Adapter to process redundant operations for the metrics.
+ '''
+ def __init__(self, metric_names) -> None:
+ self.metric_names = metric_names
+ def __call__(self, predictions: List[dict], targets: List[dict]):
iou_thresholds = np.linspace(0.5, 0.95, 10)
stats = []
@@ -224,20 +227,69 @@ def calibrate(self, predictions, targets, **kwargs):
)
)
+ return {'stats': stats}
+
+
+class mAP50(BaseMetric):
+ def __init__(self, num_classes, classwise_analysis, **kwargs):
+ metric_name = 'mAP50' # Name for logging
+ super().__init__(metric_name=metric_name, num_classes=num_classes, classwise_analysis=classwise_analysis)
+
+ def calibrate(self, predictions, targets, **kwargs):
+ stats = kwargs['stats'] # Get from DetectionMetricAdapter
+
# Compute average precisions if any matches exist
if stats:
concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)]
- average_precisions = average_precisions_per_class(*concatenated_stats)
- #result_dict['map50'] = average_precisions[:, 0].mean()
- #result_dict['map75'] = average_precisions[:, 5].mean()
- #result_dict['map50_95'] = average_precisions.mean()
- self.metric_meter['map50'].update(average_precisions[:, 0].mean())
- self.metric_meter['map75'].update(average_precisions[:, 5].mean())
- self.metric_meter['map50_95'].update(average_precisions.mean())
+ average_precisions = average_precisions_per_class(*concatenated_stats, num_classes=self.num_classes)
+
+ if self.classwise_analysis:
+ for i, classwise_meter in enumerate(self.classwise_metric_meters):
+ classwise_meter.update(average_precisions[i, 0])
+ self.metric_meter.update(np.nanmean(average_precisions[:, 0]))
else:
- #result_dict['map50'], result_dict['map75'], result_dict['map50_95'] = 0, 0, 0
- self.metric_meter['map50'].update(0)
- self.metric_meter['map75'].update(0)
- self.metric_meter['map50_95'].update(0)
+ self.metric_meter.update(0)
+
+
+class mAP75(BaseMetric):
+ def __init__(self, num_classes, classwise_analysis, **kwargs):
+ # TODO: Select metrics by user
+ metric_name = 'mAP75'
+ super().__init__(metric_name=metric_name, num_classes=num_classes, classwise_analysis=classwise_analysis)
+
+ def calibrate(self, predictions, targets, **kwargs):
+ stats = kwargs['stats'] # Get from DetectionMetricAdapter
- return result_dict
+ # Compute average precisions if any matches exist
+ if stats:
+ concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)]
+ average_precisions = average_precisions_per_class(*concatenated_stats, num_classes=self.num_classes)
+
+ if self.classwise_analysis:
+ for i, classwise_meter in enumerate(self.classwise_metric_meters):
+ classwise_meter.update(average_precisions[i, 5])
+ self.metric_meter.update(np.nanmean(average_precisions[:, 5]))
+ else:
+ self.metric_meter.update(0)
+
+
+class mAP50_95(BaseMetric):
+ def __init__(self, num_classes, classwise_analysis, **kwargs):
+ # TODO: Select metrics by user
+ metric_name = 'mAP50_95'
+ super().__init__(metric_name=metric_name, num_classes=num_classes, classwise_analysis=classwise_analysis)
+
+ def calibrate(self, predictions, targets, **kwargs):
+ stats = kwargs['stats'] # Get from DetectionMetricAdapter
+
+ # Compute average precisions if any matches exist
+ if stats:
+ concatenated_stats = [np.concatenate(items, 0) for items in zip(*stats)]
+ average_precisions = average_precisions_per_class(*concatenated_stats, num_classes=self.num_classes)
+
+ if self.classwise_analysis:
+ for i, classwise_meter in enumerate(self.classwise_metric_meters):
+ classwise_meter.update(np.nanmean(average_precisions[i, :]))
+ self.metric_meter.update(np.nanmean(average_precisions))
+ else:
+ self.metric_meter.update(0)
diff --git a/src/netspresso_trainer/metrics/pose_estimation/__init__.py b/src/netspresso_trainer/metrics/pose_estimation/__init__.py
index 5317e75fa..c7cb5da7c 100644
--- a/src/netspresso_trainer/metrics/pose_estimation/__init__.py
+++ b/src/netspresso_trainer/metrics/pose_estimation/__init__.py
@@ -14,4 +14,4 @@
#
# ----------------------------------------------------------------------------
-from .metric import PoseEstimationMetric
+from .metric import PCK, PoseEstimationMetricAdaptor
diff --git a/src/netspresso_trainer/metrics/pose_estimation/metric.py b/src/netspresso_trainer/metrics/pose_estimation/metric.py
index 4605792e5..d3cb83608 100644
--- a/src/netspresso_trainer/metrics/pose_estimation/metric.py
+++ b/src/netspresso_trainer/metrics/pose_estimation/metric.py
@@ -21,16 +21,24 @@
from ..base import BaseMetric
-class PoseEstimationMetric(BaseMetric):
- SUPPORT_METRICS: List[str] = ['pck']
+class PoseEstimationMetricAdaptor:
+ '''
+ Adapter to process redundant operations for the metrics.
+ '''
+ def __init__(self, metric_names) -> None:
+ self.metric_names = metric_names
+
+ def __call__(self, predictions: List[dict], targets: List[dict]):
+ return {} # Do nothing
- def __init__(self, **kwargs):
- # TODO: Select metrics by user
- metric_names: List[str] = ['pck']
- primary_metric: str = 'pck'
- assert set(metric_names).issubset(PoseEstimationMetric.SUPPORT_METRICS)
- super().__init__(metric_names=metric_names, primary_metric=primary_metric)
+class PCK(BaseMetric):
+ def __init__(self, num_classes, classwise_analysis, **kwargs):
+ # TODO: Select metrics by user
+ metric_name = 'PCK'
+ if classwise_analysis: # TODO: Implement classwise analysis
+ raise NotImplementedError('Classwise analysis is not supported for PCK metric')
+ super().__init__(metric_name=metric_name, num_classes=num_classes, classwise_analysis=classwise_analysis)
# TODO: Get from config
self.thr = 0.05
self.input_size = (256, 256)
@@ -75,4 +83,4 @@ def calibrate(self, pred, target, **kwargs):
normalize = np.tile(np.array([[self.input_size[0], self.input_size[1]]]), (N, 1))
acc, avg_acc, cnt = self.keypoint_pck_accuracy(pred, target, mask, self.thr, normalize)
- self.metric_meter['pck'].update(avg_acc)
+ self.metric_meter.update(avg_acc)
diff --git a/src/netspresso_trainer/metrics/registry.py b/src/netspresso_trainer/metrics/registry.py
index 3adcfa276..2753a165c 100644
--- a/src/netspresso_trainer/metrics/registry.py
+++ b/src/netspresso_trainer/metrics/registry.py
@@ -17,16 +17,41 @@
from typing import Callable, Dict, Literal, Type
from .base import BaseMetric
-from .classification import ClassificationMetric
-from .detection import DetectionMetric
-from .pose_estimation import PoseEstimationMetric
-from .segmentation import SegmentationMetric
+from .classification import ClassificationMetricAdaptor, Top1Accuracy, Top5Accuracy
+from .detection import DetectionMetricAdaptor, mAP50, mAP50_95, mAP75
+from .pose_estimation import PCK, PoseEstimationMetricAdaptor
+from .segmentation import PixelAccuracy, SegmentationMetricAdaptor, mIoU
-TASK_METRIC: Dict[Literal['classification', 'segmentation', 'detection'], Type[BaseMetric]] = {
- 'classification': ClassificationMetric,
- 'segmentation': SegmentationMetric,
- 'detection': DetectionMetric,
- 'pose_estimation': PoseEstimationMetric,
+METRIC_LIST: Dict[str, Type[BaseMetric]] = {
+ 'top1_accuracy': Top1Accuracy,
+ 'top5_accuracy': Top5Accuracy,
+ 'miou': mIoU,
+ 'pixel_accuracy': PixelAccuracy,
+ 'map50': mAP50,
+ 'map75': mAP75,
+ 'map50_95': mAP50_95,
+ 'pck': PCK,
+}
+
+METRIC_ADAPTORS = {
+ 'classification': ClassificationMetricAdaptor,
+ 'segmentation': SegmentationMetricAdaptor,
+ 'detection': DetectionMetricAdaptor,
+ 'pose_estimation': PoseEstimationMetricAdaptor,
}
PHASE_LIST = ['train', 'valid', 'test']
+
+TASK_AVAILABLE_METRICS = {
+ 'classification': ['top1_accuracy', 'top5_accuracy'],
+ 'segmentation': ['miou', 'pixel_accuracy'],
+ 'detection': ['map50', 'map75', 'map50_95'],
+ 'pose_estimation': ['pck'],
+}
+
+TASK_DEFUALT_METRICS = {
+ 'classification': ['top1_accuracy', 'top5_accuracy'],
+ 'segmentation': ['miou', 'pixel_accuracy'],
+ 'detection': ['map50', 'map75', 'map50_95'],
+ 'pose_estimation': ['pck'],
+}
diff --git a/src/netspresso_trainer/metrics/segmentation/__init__.py b/src/netspresso_trainer/metrics/segmentation/__init__.py
index 59f2c5cbd..b46a6dee5 100644
--- a/src/netspresso_trainer/metrics/segmentation/__init__.py
+++ b/src/netspresso_trainer/metrics/segmentation/__init__.py
@@ -14,4 +14,4 @@
#
# ----------------------------------------------------------------------------
-from .metric import SegmentationMetric
+from .metric import PixelAccuracy, SegmentationMetricAdaptor, mIoU
diff --git a/src/netspresso_trainer/metrics/segmentation/metric.py b/src/netspresso_trainer/metrics/segmentation/metric.py
index 28b1e5e05..1c3cec33e 100644
--- a/src/netspresso_trainer/metrics/segmentation/metric.py
+++ b/src/netspresso_trainer/metrics/segmentation/metric.py
@@ -49,20 +49,24 @@ def avg(self) -> float:
return np.nanmean(self.intersection / self.union)
-class SegmentationMetric(BaseMetric):
- SUPPORT_METRICS: List[str] = ['iou', 'pixel_acc']
-
- def __init__(self, num_classes=None, ignore_index=IGNORE_INDEX_NONE_VALUE):
- # TODO: Select metrics by user
- metric_names = ['iou', 'pixel_acc']
- primary_metric = 'iou'
-
- assert set(metric_names).issubset(SegmentationMetric.SUPPORT_METRICS)
- super().__init__(metric_names=metric_names, primary_metric=primary_metric)
+class SegmentationMetricAdaptor:
+ '''
+ Adapter to process redundant operations for the metrics.
+ '''
+ def __init__(self, metric_names) -> None:
+ self.metric_names = metric_names
+
+ def __call__(self, predictions: List[dict], targets: List[dict]):
+ return {} # Do nothing
+
+
+# TODO: Unify repeated code
+class mIoU(BaseMetric):
+ def __init__(self, num_classes, classwise_analysis, ignore_index=IGNORE_INDEX_NONE_VALUE, **kwargs):
+ metric_name = 'mIoU' # Name for logging
+ super().__init__(metric_name=metric_name, num_classes=num_classes, classwise_analysis=classwise_analysis)
self.ignore_index = ignore_index if ignore_index is not None else IGNORE_INDEX_NONE_VALUE
- self.K = num_classes
-
- self.metric_meter['iou'] = IoUMeter(num_classes=self.K, name='iou')
+ self.metric_meter = IoUMeter(num_classes=self.num_classes, name='iou') #TODO: Temporarily added IoUMeter.
def intersection_and_union(self, output, target):
@@ -77,9 +81,9 @@ def intersection_and_union(self, output, target):
#area_intersection = torch.histc(intersection, bins=self.K, min=0, max=self.K-1)
#area_output = torch.histc(output, bins=self.K, min=0, max=self.K-1)
#area_target = torch.histc(target, bins=self.K, min=0, max=self.K-1)
- area_intersection = np.histogram(intersection, bins=np.linspace(0, self.K, self.K+1))[0]
- area_output = np.histogram(output, bins=np.linspace(0, self.K, self.K+1))[0]
- area_target = np.histogram(target, bins=np.linspace(0, self.K, self.K+1))[0]
+ area_intersection = np.histogram(intersection, bins=np.linspace(0, self.num_classes, self.num_classes+1))[0]
+ area_output = np.histogram(output, bins=np.linspace(0, self.num_classes, self.num_classes+1))[0]
+ area_target = np.histogram(target, bins=np.linspace(0, self.num_classes, self.num_classes+1))[0]
area_union = area_output + area_target - area_intersection
intersection, union, target, output = area_intersection, area_union, area_target, area_output
@@ -91,9 +95,52 @@ def intersection_and_union(self, output, target):
'output': output
}
+ def calibrate(self, pred, target, **kwargs):
+ metrics = self.intersection_and_union(pred, target)
+
+ if self.classwise_analysis: # TODO: Compute in a better way
+ for cls_meter, cls_intersection, cls_union in zip(self.classwise_metric_meters, metrics['intersection'], metrics['union']):
+ if cls_union != 0:
+ cls_meter.update(cls_intersection, cls_union)
+
+ self.metric_meter.update(metrics['intersection'], metrics['union'])
+
+
+class PixelAccuracy(BaseMetric):
+ def __init__(self, num_classes, classwise_analysis, ignore_index=IGNORE_INDEX_NONE_VALUE, **kwargs):
+ metric_name = 'Pixel_acc' # Name for logging
+ super().__init__(metric_name=metric_name, num_classes=num_classes, classwise_analysis=classwise_analysis)
+ self.ignore_index = ignore_index if ignore_index is not None else IGNORE_INDEX_NONE_VALUE
+
+ def intersection_and_union(self, output, target):
+
+ # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
+ assert (len(output.shape) in [1, 2, 3])
+
+ assert output.shape == target.shape
+ output = output.reshape(-1)
+ target = target.reshape(-1)
+ output[target == self.ignore_index] = self.ignore_index
+ intersection = output[output == target]
+
+ area_intersection = np.histogram(intersection, bins=np.linspace(0, self.num_classes, self.num_classes+1))[0]
+ area_target = np.histogram(target, bins=np.linspace(0, self.num_classes, self.num_classes+1))[0]
+
+ intersection, target = area_intersection, area_target
+
+ return {
+ 'intersection': intersection,
+ 'target': target,
+ }
+
def calibrate(self, pred, target, **kwargs):
B = pred.shape[0]
metrics = self.intersection_and_union(pred, target)
- self.metric_meter['iou'].update(metrics['intersection'], metrics['union'])
- self.metric_meter['pixel_acc'].update(sum(metrics['intersection']) / (sum(metrics['target']) + 1e-10), n=B)
+
+ if self.classwise_analysis: # TODO: Compute in a better way
+ for cls_meter, cls_intersection, cls_target in zip(self.classwise_metric_meters, metrics['intersection'], metrics['target']):
+ if cls_target != 0:
+ cls_meter.update(cls_intersection, cls_target)
+
+ self.metric_meter.update(sum(metrics['intersection']) / (sum(metrics['target']) + 1e-10), n=B)
diff --git a/src/netspresso_trainer/models/base.py b/src/netspresso_trainer/models/base.py
index e2559259c..8c5a79df4 100644
--- a/src/netspresso_trainer/models/base.py
+++ b/src/netspresso_trainer/models/base.py
@@ -16,8 +16,9 @@
import os
from abc import abstractmethod
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
+import numpy as np
import torch
import torch.nn as nn
from loguru import logger
@@ -44,6 +45,7 @@ def __init__(self, conf_model, backbone, neck, head, freeze_backbone: bool = Fal
if freeze_backbone:
self._freeze_backbone()
logger.info(f"Freeze! {self.backbone_name} is now freezed. Now only tuning with {self.head_name}.")
+ self.__save_dtype = None
def _freeze_backbone(self):
for m in self.backbone.parameters():
@@ -57,6 +59,14 @@ def head_list(self):
def device(self):
return next(self.parameters()).device
+ @property
+ def save_dtype(self):
+ return self.__save_dtype
+
+ @save_dtype.setter
+ def save_dtype(self, dtype):
+ self.__save_dtype = dtype
+
def _get_name(self):
if hasattr(self, 'neck'):
return f"{self.__class__.__name__}[task={self.task}, backbone={self.backbone_name}, neck={self.neck_name}, head={self.head_name}]"
@@ -159,3 +169,110 @@ def set_provider(self, device):
self.inference_session.set_providers(['CUDAExecutionProvider'])
else:
self.inference_session.set_providers(['CPUExecutionProvider'])
+
+
+class TFLiteModel:
+ """
+ TensorFlow Lite (tflite) wrapper class for inferencing.
+ """
+
+ NUM_THREADS = 4
+
+ def __init__(self, model_conf) -> None:
+ self.tflite = self._import_tflite()
+ self.task = model_conf.task
+ assert self.task == 'detection', f"Task {self.task} is not yet supported in this TensorFlow Lite (tflite) model inference."
+ self.name = model_conf.name + '_tflite'
+ self.tflite_path = model_conf.checkpoint.path
+ self.interpreter = self.tflite.Interpreter(model_path=self.tflite_path, num_threads=self.NUM_THREADS)
+ self.interpreter.allocate_tensors()
+ self.input_details = self.interpreter.get_input_details()
+ self.output_details = self.interpreter.get_output_details()
+ self.input_dtype = self.input_details[0]['dtype']
+ self.input_shape = tuple(self.input_details[0]['shape'])
+
+ self.output_dtype = self.output_details[0]['dtype']
+ self.quantized_input = self.input_dtype in [np.int8, np.uint8]
+ self.quantized_output = self.output_dtype in [np.int8, np.uint8]
+ if self.quantized_input:
+ self.input_scale, self.input_zero_point = self.input_details[0]['quantization']
+
+ def _import_tflite(self):
+ try:
+ import tflite_runtime.interpreter as tflite
+ except ImportError:
+ try:
+ import tensorflow.lite as tflite
+ except ImportError as e:
+ raise ImportError("Failed to import tensorflow lite. Please install tflite_runtime or tensorflow") from e
+ return tflite
+
+ def get_name(self) -> str:
+ """Get the name of the model."""
+ return f"{self.__class__.__name__}[model={self.name}]"
+
+ def __call__(self, x: Union[np.ndarray, torch.Tensor], label_size=None, targets=None):
+ """
+ Perform inference on the input tensor.
+
+ Args:
+ x (Union[np.ndarray, torch.Tensor]): Input tensor
+ label_size: Not used in this implementation
+ targets: Not used in this implementation
+
+ Returns:
+ ModelOutput: Output of the model
+ """
+ try:
+ device = x.device if hasattr(x, 'device') else 'cpu'
+ x = self._prepare_input(x)
+
+ if self.quantized_input:
+ x = x / self.input_scale + self.input_zero_point
+ x = x.astype(self.input_dtype)
+
+ assert x.shape == self.input_shape, f"Your input shape {x.shape} does not match with the expected input shape {self.input_shape}"
+ self.interpreter.set_tensor(self.input_details[0]['index'], x)
+ self.interpreter.invoke()
+
+ output = self._process_output(device)
+ return ModelOutput(pred=output)
+ except Exception as e:
+ raise RuntimeError(f"Error during inference: {str(e)}") from e
+
+ def _process_output(self, device: str) -> List[torch.Tensor]:
+ """Process the output of the interpreter."""
+ output = []
+ for details in self.output_details:
+ o = self.interpreter.get_tensor(details["index"])
+ if self.quantized_output:
+ output_quantization_params = details['quantization']
+ o = (o.astype(np.float32) - output_quantization_params[1]) * output_quantization_params[0]
+ output.append(torch.tensor(np.transpose(o, (0, 3, 1, 2))).to(device))
+
+ if len(output) > 1:
+ output.sort(key=lambda x: sum(x.shape), reverse=True)
+ return output
+
+ def eval(self):
+ """Set the model to evaluation mode."""
+ pass # Do nothing for TFLite model
+
+ def _prepare_input(self, x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
+ """
+ Prepare input tensor for inference.
+
+ Args:
+ x (Union[np.ndarray, torch.Tensor]): Input tensor
+
+ Returns:
+ np.ndarray: Prepared input tensor
+ """
+ if isinstance(x, torch.Tensor):
+ x = x.detach().cpu().numpy()
+ elif not isinstance(x, np.ndarray):
+ raise TypeError(f"Unsupported input type: {type(x)}")
+
+ if x.shape[1] == 3:
+ x = np.transpose(x, (0, 2, 3, 1))
+ return x
diff --git a/src/netspresso_trainer/models/builder.py b/src/netspresso_trainer/models/builder.py
index 77ba01db9..64a80d2ac 100644
--- a/src/netspresso_trainer/models/builder.py
+++ b/src/netspresso_trainer/models/builder.py
@@ -24,7 +24,7 @@
from omegaconf import OmegaConf
from torch.nn.parallel import DistributedDataParallel as DDP
-from .base import ClassificationModel, DetectionModel, ONNXModel, SegmentationModel, TaskModel
+from .base import ClassificationModel, DetectionModel, ONNXModel, SegmentationModel, TaskModel, TFLiteModel
from .registry import (
MODEL_BACKBONE_DICT,
MODEL_FULL_DICT,
@@ -45,9 +45,8 @@ def load_full_model(conf_model, model_name, num_classes, model_checkpoint, use_p
model, model_checkpoint,
load_checkpoint_head=conf_model.checkpoint.load_head,
)
- # TODO: Move to model property
- model.save_dtype = next(model.parameters()).dtype # If loaded model is float16, save it as float16
- model = model.float() # Train with float32
+ model.save_dtype = next(model.parameters()).dtype # If loaded model is float16, save it as float16
+ model = model.float() # Train with float32
return model
@@ -89,9 +88,8 @@ def load_backbone_and_head_model(
model, model_checkpoint,
load_checkpoint_head=conf_model.checkpoint.load_head,
)
- # TODO: Move to model property
- model.save_dtype = next(model.parameters()).dtype # If loaded model is float16, save it as float16
- model = model.float() # Train with float32
+ model.save_dtype = next(model.parameters()).dtype
+ model = model.float() # Train with float32
return model
@@ -138,4 +136,8 @@ def build_model(model_conf, num_classes, devices, distributed) -> nn.Module:
model = ONNXModel(model_conf)
model.set_provider(devices)
+ elif model_format == 'tflite':
+ assert Path(model_conf.checkpoint.path).exists()
+ model = TFLiteModel(model_conf)
+
return model
diff --git a/src/netspresso_trainer/models/heads/detection/experimental/yolo_fastest_head.py b/src/netspresso_trainer/models/heads/detection/experimental/yolo_fastest_head.py
index ec0258920..44e37aa9f 100644
--- a/src/netspresso_trainer/models/heads/detection/experimental/yolo_fastest_head.py
+++ b/src/netspresso_trainer/models/heads/detection/experimental/yolo_fastest_head.py
@@ -21,7 +21,7 @@
import math
from ....op.custom import SeparableConvLayer
-from ....utils import AnchorBasedDetectionModelOutput
+from ....utils import ModelOutput
from .detection import AnchorGenerator
@@ -32,29 +32,23 @@ def __init__(
intermediate_features_dim: List[int],
params: DictConfig) -> None:
super().__init__()
- num_anchors = 3 # TODO
anchors = params.anchors
- num_anchors = len(anchors[0]) // 2
- self.anchors = anchors
- tmp_cell_anchors = []
- for a in self.anchors:
- a = torch.tensor(a).view(-1, 2)
- wa = a[:, 0:1]
- ha = a[:, 1:]
- base_anchors = torch.cat([-wa, -ha, wa, ha], dim=-1)/2
- tmp_cell_anchors.append(base_anchors)
- self.anchor_generator = AnchorGenerator(sizes=((128),)) # TODO: dynamic image_size, and anchor_size as a parameters
- self.anchor_generator.cell_anchors = tmp_cell_anchors
- num_anchors = self.anchor_generator.num_anchors_per_location()[0]
- in_channel = intermediate_features_dim[0]
- self.cls_head = YOLOFastestClassificationHead(in_channel, num_anchors, num_classes)
- self.reg_head = YOLOFastestRegressionHead(in_channel, num_anchors)
+ self.num_anchors = len(anchors[0]) // 2
+ hidden_dim = int(intermediate_features_dim[0])
+ self.cls_head = YOLOFastestClassificationHead(hidden_dim, self.num_anchors, num_classes)
+ self.reg_head = YOLOFastestRegressionHead(hidden_dim, self.num_anchors)
- def forward(self, x):
- anchors = torch.cat(self.anchor_generator(x), dim=0)
- cls_logits, objs = self.cls_head(x)
+ def forward(self, x, target=None):
+ cls_logits, objs = self.cls_head(x)
bbox_regression = self.reg_head(x)
- return AnchorBasedDetectionModelOutput(anchors=anchors, cls_logits=cls_logits, bbox_regression=bbox_regression)
+ outputs = list()
+ for reg, obj, logits in zip(bbox_regression, objs, cls_logits):
+ reg = reg.view(reg.shape[0], self.num_anchors, -1, reg.shape[-2], reg.shape[-1])
+ obj = obj.view(obj.shape[0], self.num_anchors, -1, obj.shape[-2], obj.shape[-1])
+ logits = logits.repeat(1, self.num_anchors, 1, 1).view(logits.shape[0], self.num_anchors, -1, logits.shape[-2], logits.shape[-1])
+ output = torch.cat([reg, obj, logits], 2)
+ outputs.append(output)
+ return ModelOutput(pred=outputs)
def yolo_fastest_head_v2(num_classes, intermediate_features_dim, conf_model_head) -> YOLOFastestHeadV2:
return YOLOFastestHeadV2(num_classes=num_classes,
@@ -68,93 +62,77 @@ def __init__(
in_channels,
num_anchors,
num_classes,
- prior_prob = 0.01,
+ prior_prob = 1e-2,
+ num_layers = 2,
) -> None:
super().__init__()
-
- self.layer_1 = nn.Sequential(*[
- SeparableConvLayer(in_channels, in_channels, 5, padding=2, no_out_act=True),
- SeparableConvLayer(in_channels, in_channels, 5, padding=2, no_out_act=True),
- ])
- self.layer_2 = nn.Sequential(*[
- SeparableConvLayer(in_channels, in_channels, 5, padding=2, no_out_act=True),
- SeparableConvLayer(in_channels, in_channels, 5, padding=2, no_out_act=True),
- ])
-
- self.cls_logits = nn.Conv2d(in_channels, num_classes * num_anchors, 1, 1, 0, bias=True)
- self.obj = nn.Conv2d(in_channels, num_anchors, 1, 1, 0, bias=True)
- nn.init.normal_(self.cls_logits.weight, std=0.01)
- nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_prob) / prior_prob))
-
self.num_classes = num_classes
self.num_anchors = num_anchors
+ self.layer = nn.ModuleList()
+ self.cls_logits = nn.ModuleList()
+ self.obj = nn.ModuleList()
+ for _ in range(num_layers):
+ self.layer.append(
+ nn.Sequential(*[
+ SeparableConvLayer(in_channels, in_channels, 5, padding=2, no_out_act=True),
+ SeparableConvLayer(in_channels, in_channels, 5, padding=2, no_out_act=True),
+ ])
+ )
+ self.cls_logits.append(nn.Conv2d(in_channels, num_classes, 1, 1, 0, bias=True))
+ self.obj.append(nn.Conv2d(in_channels, num_anchors, 1, 1, 0, bias=True))
+
+ self.initialize_biases(prior_prob=prior_prob)
+
+
+ def initialize_biases(self, prior_prob):
+ for conv in self.cls_logits:
+ b = conv.bias.view(1, -1)
+ b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
+ conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+ for conv in self.obj:
+ b = conv.bias.view(1, -1)
+ b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
+ conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
def forward(self, x):
all_cls_logits = []
all_objs = []
- out1 = self.layer_1(x[0])
- out2 = self.layer_2(x[1])
- outputs = [out1, out2]
-
- for idx, features in enumerate(x):
- cls_logits = outputs[idx]
- objectness = cls_logits
- cls_logits = self.cls_logits(cls_logits)
- objectness = self.obj(objectness)
-
- # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
- N, _, H, W = cls_logits.shape
- cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
- cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
- cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, K)
-
+ for idx, out in enumerate(x):
+ features = self.layer[idx](out)
+ cls_logits = features
+ objectness = features
+ cls_logits = self.cls_logits[idx](cls_logits)
+ objectness = self.obj[idx](objectness)
all_cls_logits.append(cls_logits)
-
- # Permute objectness output from (N, A, H, W) to (N, HWA, 1).
- N, _, H, W = objectness.shape
- objectness = objectness.view(N, -1, 1, H, W)
- objectness = objectness.permute(0, 3, 4, 1, 2)
- objectness = objectness.reshape(N, -1, 1) # Size=(N, HWA, 1)
all_objs.append(objectness)
-
return all_cls_logits, all_objs
-
class YOLOFastestRegressionHead(nn.Module):
def __init__(
self,
in_channels,
- num_anchors,) -> None:
+ num_anchors,
+ num_layers=2,
+ ) -> None:
super().__init__()
-
- self.layer_1 = nn.Sequential(*[
- SeparableConvLayer(in_channels, in_channels, 5, padding=2, no_out_act=True),
- SeparableConvLayer(in_channels, in_channels, 5, padding=2, no_out_act=True),
- ])
- self.layer_2 = nn.Sequential(*[
- SeparableConvLayer(in_channels, in_channels, 5, padding=2, no_out_act=True),
- SeparableConvLayer(in_channels, in_channels, 5, padding=2, no_out_act=True),
- ])
- self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, 1, 1, 0, bias=True)
+ self.layer = nn.ModuleList()
+ self.bbox_reg = nn.ModuleList()
+ for _ in range(num_layers):
+ self.layer.append(
+ nn.Sequential(*[
+ SeparableConvLayer(in_channels, in_channels, 5, padding=2, no_out_act=True),
+ SeparableConvLayer(in_channels, in_channels, 5, padding=2, no_out_act=True)
+ ]))
+ self.bbox_reg.append(nn.Conv2d(in_channels, 4 * num_anchors, 1, 1, 0))
- def forward(self, x, targets=None):
+ def forward(self, x):
all_bbox_regression = []
- out1 = self.layer_1(x[0])
- out2 = self.layer_2(x[1])
- outputs = [out1, out2]
-
- for idx, features in enumerate(x):
- bbox_regression = outputs[idx]
- bbox_regression = self.bbox_reg(bbox_regression)
-
- # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
- N, _, H, W = bbox_regression.shape
- bbox_regression = bbox_regression.view(N, -1, 4, H, W)
- bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
- bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4)
-
+ for idx, out in enumerate(x):
+ features = self.layer[idx](out)
+ bbox_regression = features
+ bbox_regression = self.bbox_reg[idx](bbox_regression)
all_bbox_regression.append(bbox_regression)
-
return all_bbox_regression
diff --git a/src/netspresso_trainer/models/necks/experimental/fpn.py b/src/netspresso_trainer/models/necks/experimental/fpn.py
index c15fce1e6..e0c08bbaa 100644
--- a/src/netspresso_trainer/models/necks/experimental/fpn.py
+++ b/src/netspresso_trainer/models/necks/experimental/fpn.py
@@ -178,16 +178,17 @@ def __init__(
) -> None:
super().__init__()
- self.in_channels = intermediate_features_dim
+ self.input2_depth = intermediate_features_dim[-2] + intermediate_features_dim[-1]
+ self.input3_depth = intermediate_features_dim[-1]
self.out_channels = params.out_channels
- self.num_ins = len(self.in_channels)
+ self.num_ins = len(intermediate_features_dim)
self._intermediate_features_dim = [self.out_channels for _ in range(self.num_ins)]
# TODO: Make sure this module can process multi-scale features greater than 2.
- self.conv_C2 = ConvLayer(self.in_channels[0]+self.in_channels[1], self.out_channels, 1, 1, padding=0)
- self.conv_C3 = ConvLayer(self.in_channels[1], self.out_channels, 1, 1, padding=0)
+ self.conv_C2 = ConvLayer(self.input2_depth, self.out_channels, 1, 1, padding=0)
+ self.conv_C3 = ConvLayer(self.input3_depth, self.out_channels, 1, 1, padding=0)
- def forward(self, inputs):
- C2, C3 = inputs[0], inputs[1]
+ def forward(self, inputs):
+ C2, C3 = inputs[-2], inputs[-1]
S3 = self.conv_C3(C3)
P2 = F.interpolate(C3, scale_factor=2)
P2 = torch.cat((P2, C2), dim=1)
diff --git a/src/netspresso_trainer/models/op/registry.py b/src/netspresso_trainer/models/op/registry.py
index 19150b6b6..c0c04eb94 100644
--- a/src/netspresso_trainer/models/op/registry.py
+++ b/src/netspresso_trainer/models/op/registry.py
@@ -26,6 +26,7 @@
ACTIVATION_REGISTRY: Dict[str, Type[nn.Module]] = {
'relu': nn.ReLU,
+ 'relu6': nn.ReLU6,
'prelu': nn.PReLU,
'leaky_relu': nn.LeakyReLU,
'gelu': nn.GELU,
diff --git a/src/netspresso_trainer/models/utils.py b/src/netspresso_trainer/models/utils.py
index aa1d09183..2b1c080c0 100644
--- a/src/netspresso_trainer/models/utils.py
+++ b/src/netspresso_trainer/models/utils.py
@@ -58,6 +58,7 @@
'yolox_x': 'coco',
'rtdetr_res18': 'coco',
'rtdetr_res50': 'coco',
+ 'yolo_fastest_v2': 'coco',
}
MODEL_CHECKPOINT_URL_DICT = {
@@ -139,6 +140,9 @@
'rtdetr_res50': {
'coco': "https://netspresso-trainer-public.s3.ap-northeast-2.amazonaws.com/checkpoint/rtdetr/rtdetr_res50_coco.safetensors?versionId=JHmnjY13BEflpnDCYPFJ1c17UwpqDrLQ",
},
+ 'yolo_fastest_v2': {
+ 'coco': "https://netspresso-trainer-public.s3.ap-northeast-2.amazonaws.com/checkpoint/yolofastest/yolo_fastest_v2_coco.safetensors?versionId=CGhNjiZygGVjtHm0M586DzQ6.2FqWvl1"
+ }
}
@@ -257,5 +261,7 @@ def get_model_format(model_conf: omegaconf.DictConfig):
return 'torch.fx'
elif ext == '.onnx':
return 'onnx'
+ elif ext == '.tflite':
+ return 'tflite'
else:
raise ValueError(f"Unsupported model format: {model_conf.checkpoint.path}")
diff --git a/src/netspresso_trainer/pipelines/base.py b/src/netspresso_trainer/pipelines/base.py
index db61bec0c..658a88a9c 100644
--- a/src/netspresso_trainer/pipelines/base.py
+++ b/src/netspresso_trainer/pipelines/base.py
@@ -48,7 +48,7 @@ def __init__(
@property
def sample_input(self):
- return torch.randn((1, 3, self.conf.logging.sample_input_size[0], self.conf.logging.sample_input_size[1]))
+ return torch.randn((1, 3, self.conf.logging.model_save_options.sample_input_size[0], self.conf.logging.model_save_options.sample_input_size[1]))
def log_results(
self,
diff --git a/src/netspresso_trainer/pipelines/builder.py b/src/netspresso_trainer/pipelines/builder.py
index cbd90598f..262926af4 100644
--- a/src/netspresso_trainer/pipelines/builder.py
+++ b/src/netspresso_trainer/pipelines/builder.py
@@ -99,7 +99,7 @@ def build_pipeline(
# Build loss and metric modules
loss_factory = build_losses(conf.model, cur_epoch=cur_epoch)
- metric_factory = build_metrics(task, conf.model, num_classes=train_dataloader.dataset.num_classes)
+ metric_factory = build_metrics(task, conf.model, conf.logging.metrics, train_dataloader.dataset.num_classes)
# Set model EMA
model_ema = None
@@ -143,7 +143,7 @@ def build_pipeline(
# Build modules for evaluation
loss_factory = build_losses(conf.model)
- metric_factory = build_metrics(task, conf.model, num_classes=eval_dataloader.dataset.num_classes)
+ metric_factory = build_metrics(task, conf.model, conf.logging.metrics, eval_dataloader.dataset.num_classes)
# Build logger
single_gpu_or_rank_zero = (not conf.distributed) or (conf.distributed and dist.get_rank() == 0)
diff --git a/src/netspresso_trainer/pipelines/evaluation.py b/src/netspresso_trainer/pipelines/evaluation.py
index a40f05111..c247dd2de 100644
--- a/src/netspresso_trainer/pipelines/evaluation.py
+++ b/src/netspresso_trainer/pipelines/evaluation.py
@@ -97,6 +97,18 @@ def log_end_evaluation(
):
losses = self.loss_factory.result('valid')
metrics = self.metric_factory.result('valid')
+
+ # TODO: Move to logger
+ # If class-wise metrics, convert to class names
+ if 'classwise' in metrics[list(metrics.keys())[0]]:
+ tmp_metrics = {}
+ for metric_name, metric in metrics.items():
+ tmp_metrics[metric_name] = {'mean': metric['mean'], 'classwise': {}}
+ for cls_num, score in metric['classwise'].items():
+ cls_name = self.logger.class_map[cls_num] if cls_num in self.logger.class_map else 'mean'
+ tmp_metrics[metric_name]['classwise'][f'{cls_num}_{cls_name}'] = score
+ metrics = tmp_metrics
+
self.log_results(
prefix='evaluation',
samples=valid_samples,
diff --git a/src/netspresso_trainer/pipelines/task_processors/segmentation.py b/src/netspresso_trainer/pipelines/task_processors/segmentation.py
index bd3ef3264..01bb8e95a 100644
--- a/src/netspresso_trainer/pipelines/task_processors/segmentation.py
+++ b/src/netspresso_trainer/pipelines/task_processors/segmentation.py
@@ -39,7 +39,7 @@ def train_step(self, train_model, batch, optimizer, loss_factory, metric_factory
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=self.mixed_precision):
- out = train_model(images, target=target)
+ out = train_model(images, targets=target)
loss_factory.calc(out, target, phase='train')
loss_factory.backward(self.grad_scaler)
diff --git a/src/netspresso_trainer/pipelines/train.py b/src/netspresso_trainer/pipelines/train.py
index 6812bec3b..b34f9cf85 100644
--- a/src/netspresso_trainer/pipelines/train.py
+++ b/src/netspresso_trainer/pipelines/train.py
@@ -88,28 +88,56 @@ def __init__(
Literal['train_losses', 'valid_losses', 'train_metrics', 'valid_metrics'], Dict[str, float]
]] = {}
- # TODO: These will be removed
- self.save_optimizer_state = True
@final
def _is_ready(self):
assert self.model is not None, "`self.model` is not defined!"
assert self.optimizer is not None, "`self.optimizer` is not defined!"
"""Append here if you need more assertion checks!"""
- assert self.conf.logging.save_checkpoint_epoch % self.conf.logging.validation_epoch == 0, \
+ assert self.conf.logging.model_save_options.save_checkpoint_epoch % self.conf.logging.model_save_options.validation_epoch == 0, \
"`save_checkpoint_epoch` should be the multiplier of `validation_epoch`."
+ assert self.conf.logging.model_save_options.best_model_criterion.lower() in ['loss', 'metric'], \
+ "`best_model_criterion` should be selected from ['loss', 'metric']"
return True
def epoch_with_valid_logging(self, epoch: int):
- validation_freq = self.conf.logging.validation_epoch
+ validation_freq = self.conf.logging.model_save_options.validation_epoch
last_epoch = epoch == self.conf.training.epochs
return (epoch % validation_freq == 1 % validation_freq) or last_epoch
def epoch_with_checkpoint_saving(self, epoch: int):
- checkpoint_freq = self.conf.logging.save_checkpoint_epoch
+ checkpoint_freq = self.conf.logging.model_save_options.save_checkpoint_epoch
last_epoch = epoch == self.conf.training.epochs
return (epoch % checkpoint_freq == 1 % checkpoint_freq) or last_epoch
+ def _get_valid_records(self, best_model_criterion):
+ if best_model_criterion == 'loss':
+ return {
+ epoch: record['valid_losses'].get('total')
+ for epoch, record in self.training_history.items()
+ if 'valid_losses' in record and 'total' in record['valid_losses']
+ }
+ elif best_model_criterion == 'metric':
+ metric_key = self.metric_factory.primary_metric
+ return {
+ epoch: record['valid_metrics'].get(metric_key)['mean'] # Only mean value is considered
+ for epoch, record in self.training_history.items()
+ if 'valid_metrics' in record and metric_key in record['valid_metrics']
+ }
+ else:
+ raise ValueError("best_model_criterion should be either 'loss' or 'metric'")
+
+ def get_best_epoch(self):
+ best_model_criterion = self.conf.logging.model_save_options.best_model_criterion.lower()
+
+ valid_records = self._get_valid_records(best_model_criterion)
+
+ if not valid_records:
+ return
+
+ comparison_func = min if best_model_criterion == 'loss' else max # TODO: It may depends on the specific metric
+ return comparison_func(valid_records, key=valid_records.get)
+
@property
def learning_rate(self):
return mean([param_group['lr'] for param_group in self.optimizer.param_groups])
@@ -226,6 +254,18 @@ def log_end_epoch(
if valid_logging:
valid_losses = self.loss_factory.result('valid') if valid_logging else None
valid_metrics = self.metric_factory.result('valid') if valid_logging else None
+
+ # TODO: Move to logger
+ # If class-wise metrics, convert to class names
+ if 'classwise' in valid_metrics[list(valid_metrics.keys())[0]]:
+ tmp_metrics = {}
+ for metric_name, metric in valid_metrics.items():
+ tmp_metrics[metric_name] = {'mean': metric['mean'], 'classwise': {}}
+ for cls_num, score in metric['classwise'].items():
+ cls_name = self.logger.class_map[cls_num] if cls_num in self.logger.class_map else 'mean'
+ tmp_metrics[metric_name]['classwise'][f'{cls_num}_{cls_name}'] = score
+ valid_metrics = tmp_metrics
+
self.log_results(prefix='validation', epoch=epoch, samples=valid_samples, losses=valid_losses, metrics=valid_metrics)
summary_record = {'train_losses': train_losses, 'train_metrics': train_metrics}
@@ -238,17 +278,29 @@ def save_checkpoint(self, epoch: int):
model = self.model_ema.ema_model
else:
model = self.model.module if hasattr(self.model, 'module') else self.model
+
if hasattr(model, 'deploy'):
model.deploy()
- save_dtype = model.save_dtype
+ save_dtype = model.save_dtype
if save_dtype == torch.float16:
model = copy.deepcopy(model).type(save_dtype)
+
logging_dir = self.logger.result_dir
- model_path = Path(logging_dir) / f"{self.task}_{self.model_name}_epoch_{epoch}.ext"
- optimizer_path = Path(logging_dir) / f"{self.task}_{self.model_name}_epoch_{epoch}_optimzer.pth"
+ save_best_only = self.conf.logging.model_save_options.save_best_only
- if self.save_optimizer_state:
+ if save_best_only:
+ if epoch == self.get_best_epoch():
+ self._save_model(model=model, epoch=epoch, model_name_tag="best", logging_dir=logging_dir)
+ self._save_model(model=model, epoch=epoch, model_name_tag="last", logging_dir=logging_dir)
+ else:
+ self._save_model(model=model, epoch=epoch, model_name_tag=f"epoch_{epoch}", logging_dir=logging_dir)
+
+ def _save_model(self, model, epoch: int, model_name_tag: str, logging_dir: Path):
+ model_path = Path(logging_dir) / f"{self.task}_{self.model_name}_{model_name_tag}.ext"
+ optimizer_path = Path(logging_dir) / f"{self.task}_{self.model_name}_{model_name_tag}_optimizer.pth"
+
+ if self.conf.logging.model_save_options.save_optimizer_state:
optimizer = self.optimizer.module if hasattr(self.optimizer, 'module') else self.optimizer
save_dict = {'optimizer': optimizer.state_dict(), 'last_epoch': epoch}
torch.save(save_dict, optimizer_path)
@@ -259,47 +311,49 @@ def save_checkpoint(self, epoch: int):
torch.save(model, model_path.with_suffix(".pt"))
logger.debug(f"PyTorch FX model saved at {str(model_path.with_suffix('.pt'))}")
return
+
pytorch_model_state_dict_path = model_path.with_suffix(".safetensors")
save_checkpoint(model.state_dict(), pytorch_model_state_dict_path)
logger.debug(f"PyTorch model saved at {str(pytorch_model_state_dict_path)}")
def save_best(self):
- valid_losses = {epoch: record['valid_losses'].get('total') for epoch, record in self.training_history.items()
- if 'valid_losses' in record}
- if not valid_losses:
- return # No validation loss recorded
- best_epoch = min(valid_losses, key=valid_losses.get)
-
+ opset_version = self.conf.logging.model_save_options.onnx_export_opset
logging_dir = self.logger.result_dir
+ best_epoch = self.get_best_epoch()
- best_checkpoint_path = Path(logging_dir) / f"{self.task}_{self.model_name}_epoch_{best_epoch}.ext"
- best_model_save_path = Path(logging_dir) / f"{self.task}_{self.model_name}_best.ext"
+ if not best_epoch:
+ return
model = self.model.module if hasattr(self.model, 'module') else self.model
- save_dtype = model.save_dtype
- best_model_to_save = copy.deepcopy(model)
- if hasattr(best_model_to_save, 'deploy'):
- best_model_to_save.deploy()
+ best_model = copy.deepcopy(model)
+ if hasattr(best_model, 'deploy'):
+ best_model.deploy()
- if self.is_graphmodule_training:
- best_model_to_save.load_state_dict(load_checkpoint(best_checkpoint_path.with_suffix('.pt')).state_dict())
- save_onnx(best_model_to_save, best_model_save_path.with_suffix(".onnx"), sample_input=self.sample_input.type(save_dtype), opset_version=self.conf.logging.onnx_export_opset)
- logger.info(f"ONNX model converting and saved at {str(best_model_save_path.with_suffix('.onnx'))}")
- torch.save(best_model_to_save, best_model_save_path.with_suffix(".pt"))
- logger.info(f"Best model saved at {str(best_model_save_path.with_suffix('.pt'))}")
- return
+ save_dtype = best_model.save_dtype
+ if save_dtype == torch.float16:
+ best_model = best_model.type(save_dtype)
- best_model_to_save.load_state_dict(load_checkpoint(best_checkpoint_path.with_suffix('.safetensors')))
- pytorch_best_model_state_dict_path = best_model_save_path.with_suffix(".safetensors")
- save_checkpoint(best_model_to_save.state_dict(), pytorch_best_model_state_dict_path)
- logger.info(f"Best model saved at {str(pytorch_best_model_state_dict_path)}")
+ model_name_tag = "best" if self.conf.logging.model_save_options.save_best_only else f"epoch_{best_epoch}"
+ checkpoint_path = Path(logging_dir) / f"{self.task}_{self.model_name}_{model_name_tag}.ext"
- try:
- save_onnx(best_model_to_save, best_model_save_path.with_suffix(".onnx"), sample_input=self.sample_input.type(save_dtype), opset_version=self.conf.logging.onnx_export_opset)
- logger.info(f"ONNX model converting and saved at {str(best_model_save_path.with_suffix('.onnx'))}")
+ model_checkpoint = (load_checkpoint(checkpoint_path.with_suffix('.pt')).state_dict() if self.is_graphmodule_training else load_checkpoint(checkpoint_path.with_suffix('.safetensors')))
+ best_model.load_state_dict(model_checkpoint)
+
+ self._save_model(model=best_model, epoch=best_epoch, model_name_tag="best", logging_dir=logging_dir)
- save_graphmodule(best_model_to_save, (best_model_save_path.parent / f"{best_model_save_path.stem}_fx").with_suffix(".pt"))
- logger.info(f"PyTorch FX model tracing and saved at {str(best_model_save_path.with_suffix('.pt'))}")
+ try:
+ model_save_path = Path(logging_dir) / f"{self.task}_{self.model_name}_best.ext"
+
+ save_onnx(best_model,
+ model_save_path.with_suffix(".onnx"),
+ sample_input=self.sample_input.type(save_dtype),
+ opset_version=opset_version)
+ logger.info(f"ONNX model converting and saved at {str(model_save_path.with_suffix('.onnx'))}")
+
+ if not self.is_graphmodule_training:
+ save_graphmodule(best_model,
+ (model_save_path.parent / f"{model_save_path.stem}_fx").with_suffix(".pt"))
+ logger.info(f"PyTorch FX model tracing and saved at {str(model_save_path.with_suffix('.pt'))}")
except Exception as e:
logger.error(e)
pass
diff --git a/src/netspresso_trainer/postprocessors/detection.py b/src/netspresso_trainer/postprocessors/detection.py
index 0e8ef6995..cfae27e76 100644
--- a/src/netspresso_trainer/postprocessors/detection.py
+++ b/src/netspresso_trainer/postprocessors/detection.py
@@ -162,6 +162,58 @@ def anchor_free_decoupled_head_decode(pred, original_shape, score_thresh=0.7):
return detections
+def yolo_fastest_head_decode(pred, original_shape, score_thresh=0.7, anchors=None):
+ pred = pred['pred']
+ dtype = pred[0].type()
+ stage_strides = [original_shape[-1] // o.shape[-1] for o in pred]
+ hw = [x.shape[-2:] for x in pred]
+ device = pred[0].device
+ anchors = [torch.tensor(anchor, dtype=torch.float).view(-1, 2) for anchor in anchors]
+ num_anchors = anchors[0].shape[0]
+ anchors = torch.stack(anchors, dim=0).to(device)
+
+ grids = []
+ strides = []
+ for (hsize, wsize), stride in zip(hw, stage_strides):
+ yv, xv = torch.meshgrid(torch.arange(hsize), torch.arange(wsize), indexing='ij')
+ grid = torch.stack((xv, yv), 2).repeat(num_anchors, 1,1,1).view(1, num_anchors, hsize, wsize, 2).type(dtype).to(device)
+ grids.append(grid)
+ shape = grid.shape[:-1]
+ strides.append(torch.full((*shape, 1), stride).to(device))
+
+ preds = []
+ for idx, p in enumerate(pred):
+ p = p.permute(0, 1, 3, 4, 2)
+ p = torch.cat([
+ (p[..., 0:2].sigmoid() + grids[idx]) * strides[idx],
+ 2. * (torch.tanh(p[..., 2:4]/2 -.549306) + 1.) * anchors[idx].view(1, num_anchors, 1, 1, 2),
+ p[..., 4:].sigmoid()
+ ], dim=-1).flatten(start_dim=1, end_dim=-2)
+ preds.append(p)
+ pred = torch.cat(preds, dim=1)
+
+ box_corner = pred.new(pred.shape)
+ box_corner[:, :, 0] = pred[:, :, 0] - pred[:, :, 2] / 2
+ box_corner[:, :, 1] = pred[:, :, 1] - pred[:, :, 3] / 2
+ box_corner[:, :, 2] = pred[:, :, 0] + pred[:, :, 2] / 2
+ box_corner[:, :, 3] = pred[:, :, 1] + pred[:, :, 3] / 2
+ pred[:, :, :4] = box_corner[:, :, :4]
+
+ # Discard boxes with low score
+ detections = []
+ for p in pred:
+ class_conf, class_pred = torch.max(p[:, 5:], 1, keepdim=True)
+
+ conf_mask = (p[:, 4] * class_conf.squeeze() >= score_thresh).squeeze()
+
+ # x1, y1, x2, y2, obj_conf, pred_score, pred_label
+ detections.append(
+ torch.cat((p[:, :5], class_conf, class_pred.float()), 1)[conf_mask]
+ )
+
+ return detections
+
+
def nms(prediction, nms_thresh=0.45, class_agnostic=False):
output = [torch.zeros(0, 7).to(prediction[0].device) for i in range(len(prediction))]
for i, image_pred in enumerate(prediction):
@@ -197,9 +249,12 @@ def __init__(self, conf_model):
if head_name == 'anchor_free_decoupled_head':
self.decode_outputs = partial(anchor_free_decoupled_head_decode, score_thresh=params.score_thresh)
self.postprocess = partial(nms, nms_thresh=params.nms_thresh, class_agnostic=params.class_agnostic)
- elif head_name == 'anchor_decoupled_head' or head_name == 'yolo_fastest_head_v2':
+ elif head_name == 'anchor_decoupled_head':
self.decode_outputs = partial(anchor_decoupled_head_decode, topk_candidates=params.topk_candidates, score_thresh=params.score_thresh)
self.postprocess = partial(nms, nms_thresh=params.nms_thresh, class_agnostic=params.class_agnostic)
+ elif head_name == 'yolo_fastest_head_v2':
+ self.decode_outputs = partial(yolo_fastest_head_decode, score_thresh=params.score_thresh, anchors=params.anchors)
+ self.postprocess = partial(nms, nms_thresh=params.nms_thresh, class_agnostic=params.class_agnostic)
elif head_name == 'rtdetr_head':
self.decode_outputs = partial(rtdetr_decode, num_top_queries=params.num_top_queries, score_thresh=params.score_thresh)
self.postprocess = None
diff --git a/src/netspresso_trainer/schedulers/multi_step_lr.py b/src/netspresso_trainer/schedulers/multi_step_lr.py
new file mode 100644
index 000000000..2c9af8f3f
--- /dev/null
+++ b/src/netspresso_trainer/schedulers/multi_step_lr.py
@@ -0,0 +1,69 @@
+# Copyright (C) 2024 Nota Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ----------------------------------------------------------------------------
+
+import warnings
+from bisect import bisect_right
+from collections import Counter
+
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class MultiStepLR(_LRScheduler):
+ """Decays the learning rate of each parameter group by gamma once the
+ number of epochs reaches one of the milestones.
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ milestones (list): List of epoch indices. Must be increasing.
+ gamma (float): Multiplicative factor of learning rate decay.
+ Default: 0.1.
+ Example:
+ >>> # Assuming optimizer uses lr = 0.05 for all groups
+ >>> # lr = 0.05 if epoch < 150
+ >>> # lr = 0.005 if 150 <= epoch < 250
+ >>> # lr = 0.0005 if epoch >= 250
+ >>> for epoch in range(300):
+ >>> train(...)
+ >>> validate(...)
+ >>> scheduler.step()
+ """
+ def __init__(
+ self,
+ optimizer,
+ scheduler_conf,
+ training_epochs,
+ ):
+ self.milestones = Counter(scheduler_conf.milestones)
+ self.gamma = scheduler_conf.gamma
+ super().__init__(optimizer)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ warnings.warn("To get the last learning rate computed by the scheduler, "
+ "please use `get_last_lr()`.", UserWarning, stacklevel=2)
+
+ if self.last_epoch not in self.milestones:
+ return [group['lr'] for group in self.optimizer.param_groups]
+ return [
+ group["lr"] * self.gamma ** self.milestones[self.last_epoch]
+ for group in self.optimizer.param_groups
+ ]
+
+ def _get_closed_form_lr(self):
+ milestones = sorted(self.milestones.elements())
+ return [
+ base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
+ for base_lr in self.base_lrs
+ ]
diff --git a/src/netspresso_trainer/schedulers/registry.py b/src/netspresso_trainer/schedulers/registry.py
index 4847bd204..ddd5e6a44 100644
--- a/src/netspresso_trainer/schedulers/registry.py
+++ b/src/netspresso_trainer/schedulers/registry.py
@@ -21,12 +21,14 @@
from .cosine_lr import CosineAnnealingLRWithCustomWarmUp
from .cosine_warm_restart import CosineAnnealingWarmRestartsWithCustomWarmUp
+from .multi_step_lr import MultiStepLR
from .poly_lr import PolynomialLRWithWarmUp
from .step_lr import StepLR
SCHEDULER_DICT: Dict[str, Type[_LRScheduler]] = {
'cosine': CosineAnnealingWarmRestartsWithCustomWarmUp,
'cosine_no_sgdr': CosineAnnealingLRWithCustomWarmUp,
+ 'multi_step': MultiStepLR,
'poly': PolynomialLRWithWarmUp,
'step': StepLR
}
diff --git a/src/netspresso_trainer/utils/record.py b/src/netspresso_trainer/utils/record.py
index 4cf34ca02..e7e544801 100644
--- a/src/netspresso_trainer/utils/record.py
+++ b/src/netspresso_trainer/utils/record.py
@@ -156,7 +156,7 @@ def __post_init__(self):
@dataclass
class EvaluationSummary:
losses: float
- metrics: float
+ metrics: dict
metrics_list: List[str]
primary_metric: str
flops: Optional[int] = None