Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

update hapi and image_segmentation #931

Merged
merged 1 commit into from
Oct 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions paddle2.0_docs/high_level_api/high_level_api.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -384,15 +384,23 @@
"metadata": {},
"outputs": [],
"source": [
"# 场景1:动态图模式\n",
"\n",
"# 使用GPU训练\n",
"paddle.set_device('gpu')\n",
"\n",
"# 模型封装\n",
"model = paddle.Model(mnist)\n",
"\n",
"## 场景1:动态图模式\n",
"## 1.1 为模型预测部署场景进行模型训练\n",
"## 需要添加input和label数据描述,否则会导致使用model.save(training=False)保存的预测模型在使用时出错\n",
"input = paddle.static.InputSpec([None, 1, 28, 28], dtype='float32')\n",
"label = paddle.static.InputSpec([None, 1], dtype='int8')\n",
"model = paddle.Model(mnist, input, label)\n",
"\n",
"## 1.2 面向实验而进行的模型训练\n",
"## 可以不传递input和label信息\n",
"# model = paddle.Model(mnist)\n",
"\n",
"# 场景2:静态图模式\n",
"## 场景2:静态图模式\n",
"# paddle.enable_static()\n",
"# paddle.set_device('gpu')\n",
"# input = paddle.static.InputSpec([None, 1, 28, 28], dtype='float32')\n",
Expand Down Expand Up @@ -584,6 +592,21 @@
" verbose=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 注:\n",
"\n",
"`fit()`的第一个参数不仅可以传递数据集`paddle.io.Dataset`,还可以传递DataLoader,如果想要实现某个自定义的数据集抽样等逻辑,可以在fit外自定义DataLoader,然后传递给fit函数。\n",
"\n",
"```python\n",
"train_dataloader = paddle.io.DataLoader(train_dataset)\n",
"...\n",
"model.fit(train_dataloader, ...)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
Loading