OCR example update (PaddlePaddle/book#927)
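* Made the following changes based on the latest feedback:
1. Added a download link for the dataset archive, with usage instructions
2. Tested whether pillow needs to be installed separately after paddlepaddle is installed
3. Added test data to the directory
4. Replaced the third-party decoder with a lite implementation; this code will be updated again once ctc-decode is updated in 2.0.

* Adapted to 2.0RC0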
GT-ZhangAcer authored and wadefelix committed Jul 30, 2021
1 parent b23bad8 commit 4fcdc1b
Showing 4 changed files with 36 additions and 20 deletions.
56 changes: 36 additions & 20 deletions paddle2.0_docs/image_ocr/OCR.ipynb
@@ -22,7 +22,9 @@
"**Data preview**\n",
"<p align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/865dd55861e24cfaa601d9f87655776c1458d099b487449f946da9b3138fc700\" width=\"400\"><br/>\n",
"</p>"
"</p> \n",
"\n",
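"Click here to [quickly get the dataset for this section](https://aistudio.baidu.com/aistudio/datasetdetail/57285). Once the download finishes, extract it with the `!unzip OCR_Dataset.zip -d data/` command or any extraction tool you are familiar with. When the data is ready, set `DATA_PATH = <path to the extracted dataset>` in the \"Training preparation\" section of this article."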
],
"cell_type": "markdown",
"metadata": {}
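
Review note: the `!unzip` command in the added cell relies on a notebook or shell environment. A minimal sketch of the same extraction step in plain Python, using only the standard library, might look like the following. The helper name `extract_dataset` and the default `data` target directory are assumptions made for illustration; only the archive name `OCR_Dataset.zip` comes from the note above.

```python
import zipfile
from pathlib import Path


def extract_dataset(archive_path, target_dir="data"):
    """Extract a zip archive into target_dir and return that directory as a Path.

    Hypothetical helper: it mirrors `unzip <archive> -d <target_dir>`.
    """
    target = Path(target_dir)
    # Create the target directory if it does not exist yet.
    target.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(target)
    return target
```

With this in place, `DATA_PATH = str(extract_dataset("OCR_Dataset.zip"))` would point the notebook at the extracted files, assuming the archive sits in the working directory.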
@@ -126,16 +128,13 @@
"\n",
"CTC paper: [Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neu](http://people.idsia.ch/~santiago/papers/icml2006.pdf) \n",
"<p align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/50cf1fc38f6b40e596acf71dc43333ff49dcaafb5a9f484b8aeee2db2c08ca67\" width=\"800\"><br/>\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/f9458cedbb4441d682f15fefd3f3cae5e49d499bcf0a4bbdb976dfdff5a2e656\" width=\"800\"><br/>\n",
"</p>\n",
"\n",
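"For the network, the dataset used here is fairly simple and the images are small, so a deep network is not a good fit. When building a model for larger images, consider a deeper network or an attention mechanism. Alternatively, use object detection to locate the text first, then build the OCR model on the detected regions.\n",
"For the network, the dataset used here is fairly simple and the images are small, so a deep network is not a good fit. When building a model for larger images, consider a deeper network or an attention mechanism. Alternatively, use object detection to locate the text first, then build the OCR model on the detected regions. (The sample below comes from [PaddleOCR](v))\n",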
"\n",
"<p align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/6e0665ddfe6a46e1b658da870cd4043f1d50e5f4dc2746018c710c58d2e0c18c\" width=\"400\"><br/>\n",
" \n",
" \n",
"<a href=\"https://github.com/PaddlePaddle/PaddleOCR\">PaddleOCR result image</a>\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/6e0665ddfe6a46e1b658da870cd4043f1d50e5f4dc2746018c710c58d2e0c18c\" width=\"400\"></br>\n",
"</p>"
]
},
@@ -164,16 +163,16 @@
" self.is_infer = is_infer\n",
"\n",
"        # Define a 3x3 convolution layer + BatchNorm\n",
" self.conv1 = paddle.nn.Conv2d(in_channels=IMAGE_SHAPE_C,\n",
" self.conv1 = paddle.nn.Conv2D(in_channels=IMAGE_SHAPE_C,\n",
" out_channels=32,\n",
" kernel_size=3)\n",
" self.bn1 = paddle.nn.BatchNorm2d(32)\n",
" self.bn1 = paddle.nn.BatchNorm2D(32)\n",
"        # Define a 3x3 convolution with stride 2 for downsampling + BatchNorm\n",
" self.conv2 = paddle.nn.Conv2d(in_channels=32,\n",
" self.conv2 = paddle.nn.Conv2D(in_channels=32,\n",
" out_channels=64,\n",
" kernel_size=3,\n",
" stride=2)\n",
" self.bn2 = paddle.nn.BatchNorm2d(64)\n",
" self.bn2 = paddle.nn.BatchNorm2D(64)\n",
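"        # Define a 1x1 convolution to compress the channel count; setting the output channels to a fixed value slightly larger than LABEL_MAX_LEN can give better results, though LABEL_MAX_LEN itself also works\n",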
" self.conv3 = paddle.nn.Conv2d(in_channels=64,\n",
" out_channels=LABEL_MAX_LEN + 4,\n",
@@ -215,6 +214,8 @@
"        if self.is_infer:\n",
"            # Output layer - Shape = (Batch Size, Max label len, Prob) \n",
"            x = paddle.nn.functional.softmax(x)\n",
"            # Convert to labels\n",
" x = paddle.tensor.argmax(x, axis=-1)\n",
" return x"
]
},
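
Review note: the added `argmax(x, axis=-1)` step turns per-timestep class probabilities into label indices. The sketch below illustrates the same idea in plain Python rather than with the Paddle API; the function names `softmax`, `argmax`, and `to_labels` are hypothetical and chosen only for this illustration. Because softmax is monotonic, taking the argmax of the softmax output selects the same index as taking the argmax of the raw scores.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def to_labels(batch_scores):
    """Map (batch, time, classes) scores to (batch, time) label indices."""
    return [[argmax(softmax(step)) for step in sample] for sample in batch_scores]
```

For a single sample with two timesteps, `to_labels([[[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]]])` picks class 1 at the first step and class 0 at the second.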
@@ -286,8 +287,8 @@
" super().__init__()\n",
"\n",
" def forward(self, ipt, label):\n",
" input_lengths = paddle.tensor.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN + 4)\n",
" label_lengths = paddle.tensor.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN)\n",
" input_lengths = paddle.tensor.creation.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN + 4)\n",
" label_lengths = paddle.tensor.creation.fill_constant([BATCH_SIZE, 1], \"int64\", LABEL_MAX_LEN)\n",
"        # Transpose the dim order as required by the documentation\n",
"        ipt = paddle.tensor.transpose(ipt, [1, 0, 2])\n",
"        # Compute the loss\n",
@@ -452,7 +453,7 @@
},
"outputs": [],
"source": [
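"# Directory to run inference on\n",
"# Directory to run inference on - you can pick 3 images from the test dataset and place them in this directory\n",
"INFER_DATA_PATH = \"./sample_img\"\n",
"# Checkpoint path after training - 10 means use the 10th checkpoint\n",
"CHECKPOINT_PATH = \"./output/10\"\n",
@@ -505,7 +506,7 @@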
{
"source": [
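"## Start prediction\n",
"> The Paddle 2.0 CTC Decoder APIs are being migrated, so this section temporarily uses a [third-party decoder](https://github.com/awni/speech/blob/072bcf9ff510d814fbfcaad43b2883ecf8f60806/speech/models/ctc_decoder.py) for decoding"
"> The Paddle 2.0 CTC Decoder APIs are being migrated, so this section temporarily uses a simplified decoder"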
],
"cell_type": "markdown",
"metadata": {
@@ -533,7 +534,22 @@
}
],
"source": [
"from ctc import decode\n",
"# Write a simplified decoder\n",
"def ctc_decode(text, blank=10):\n",
"    \"\"\"\n",
"    Simple CTC decoder\n",
"    :param text: data to decode\n",
"    :param blank: index of the blank (separator) token\n",
"    :return: decoded data\n",
"    \"\"\"\n",
" result = []\n",
" cache_idx = -1\n",
" for char in text:\n",
" if char != blank and char != cache_idx:\n",
" result.append(char)\n",
" cache_idx = char\n",
" return result\n",
"\n",
"\n",
"# Instantiate the inference model\n",
"model = paddle.Model(Net(is_infer=True), inputs=input_define)\n",
@@ -547,10 +563,10 @@
"img_names = infer_reader.get_names()\n",
"results = model.predict(infer_reader, batch_size=BATCH_SIZE)\n",
"index = 0\n",
"for result in results[0]:\n",
" for prob in result:\n",
" out, _ = decode(prob, blank=10)\n",
"        print(f\"File name: {img_names[index]}, prediction: {out}\")\n",
"for text_batch in results[0]:\n",
"    for prob in text_batch:\n",
"        out = ctc_decode(prob, blank=10)\n",
"        print(f\"File name: {img_names[index]}, inference result: {out}\")\n",
" index += 1"
]
}
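
Review note: in the `ctc_decode` helper added by this commit, `cache_idx` is only updated when a character is appended, so two identical labels separated by a blank (for example `1, blank, 1`) collapse into one. The sketch below shows the usual greedy CTC collapse rule for comparison: merge consecutive repeats, then drop blanks, with the previous index tracked on every step. The function name `ctc_greedy_decode` is hypothetical; `blank=10` follows the value used in the diff.

```python
def ctc_greedy_decode(indices, blank=10):
    """Collapse consecutive repeated indices, then drop blanks.

    `indices` is an iterable of per-timestep label indices, such as the
    argmax output of the network for one image.
    """
    result = []
    previous = None
    for idx in indices:
        # Keep an index only if it differs from the previous timestep
        # and is not the blank token.
        if idx != previous and idx != blank:
            result.append(idx)
        # Track every timestep, so a blank separates genuine repeats.
        previous = idx
    return result
```

For the input `[1, 1, 10, 1, 2, 2, 10]` this keeps both occurrences of `1`, whereas the helper in the diff would return only one of them. Whether that matters depends on whether labels in the dataset can contain adjacent repeated digits.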
Binary file added paddle2.0_docs/image_ocr/sample_img/9450.jpg
Binary file added paddle2.0_docs/image_ocr/sample_img/9451.jpg
Binary file added paddle2.0_docs/image_ocr/sample_img/9452.jpg
