Skip to content

Commit

Permalink
add svtr large model (#10937)
Browse files Browse the repository at this point in the history
* add svtr large model

* [WIP]add svtr large model
  • Loading branch information
zhangyubo0722 authored Sep 26, 2023
1 parent 2751cb3 commit e49e491
Show file tree
Hide file tree
Showing 7 changed files with 458 additions and 19 deletions.
144 changes: 144 additions & 0 deletions configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_svtr_large.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
Global:
debug: false
use_gpu: true
epoch_num: 200
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/svtr_large/
save_epoch_step: 10
# evaluation is run every 2000 iterations after the 0th iteration
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
max_text_length: &max_text_length 40
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_svtr_large.txt


Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.99
epsilon: 1.0e-08
weight_decay: 0.05
no_weight_decay_name: norm pos_embed char_node_embed pos_node_embed char_pos_embed vis_pos_embed
one_dim_param_no_weight_decay: true
lr:
name: Cosine
learning_rate: 0.00025 # 8gpus 64bs
warmup_epoch: 5


Architecture:
model_type: rec
algorithm: SVTR_LCNet
Transform: null
Backbone:
name: SVTRNet
img_size:
- 48
- 320
out_char_num: 40
out_channels: 512
patch_merging: Conv
embed_dim: [192, 256, 512]
depth: [6, 6, 9]
num_heads: [6, 8, 16]
mixer: ['Conv','Conv','Conv','Conv','Conv','Conv','Conv','Conv','Conv','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
local_mixer: [[5, 5], [5, 5], [5, 5]]
last_stage: False
prenorm: True
Head:
name: MultiHead
use_pool: true
use_pos: true
head_list:
- CTCHead:
Neck:
name: svtr
dims: 256
depth: 2
hidden_dims: 256
kernel_size: [1, 3]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 512
max_text_length: *max_text_length

Loss:
name: MultiLoss
loss_config_list:
- CTCLoss:
- NRTRLoss:

PostProcess:
name: CTCLabelDecode

Metric:
name: RecMetric
main_indicator: acc
ignore_space: true

Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
ext_op_transform_idx: 1
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug:
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: true
batch_size_per_card: 64
drop_last: true
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- SVTRRecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 4
2 changes: 1 addition & 1 deletion ppocr/data/imaug/rec_img_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __call__(self, data):
if h >= 20 and w >= 20:
img = tia_distort(img, random.randint(3, 6))
img = tia_stretch(img, random.randint(3, 6))
img = tia_perspective(img)
img = tia_perspective(img)

# bda
data['image'] = img
Expand Down
3 changes: 2 additions & 1 deletion ppocr/modeling/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def build_backbone(config, model_type):
from .det_pp_lcnet import PPLCNet
from .rec_lcnetv3 import PPLCNetV3
from .rec_hgnet import PPHGNet_small
from .rec_vit import ViT
support_dict = [
"MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet",
"PPLCNetV3", "PPHGNet_small"
Expand Down Expand Up @@ -55,7 +56,7 @@ def build_backbone(config, model_type):
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small', 'ViTParseQ'
'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small', 'ViTParseQ', 'ViT'
]
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
Expand Down
Loading

0 comments on commit e49e491

Please sign in to comment.