Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
update hapi and image_segmentation (#931)
Browse files Browse the repository at this point in the history
  • Loading branch information
saxon-zh authored Oct 31, 2020
1 parent c65c0cb commit fb1b628
Show file tree
Hide file tree
Showing 2 changed files with 590 additions and 379 deletions.
31 changes: 27 additions & 4 deletions paddle2.0_docs/high_level_api/high_level_api.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -384,15 +384,23 @@
"metadata": {},
"outputs": [],
"source": [
"# 场景1:动态图模式\n",
"\n",
"# 使用GPU训练\n",
"paddle.set_device('gpu')\n",
"\n",
"# 模型封装\n",
"model = paddle.Model(mnist)\n",
"\n",
"## 场景1:动态图模式\n",
"## 1.1 为模型预测部署场景进行模型训练\n",
"## 需要添加input和label数据描述,否则会导致使用model.save(training=False)保存的预测模型在使用时出错\n",
"input = paddle.static.InputSpec([None, 1, 28, 28], dtype='float32')\n",
"label = paddle.static.InputSpec([None, 1], dtype='int8')\n",
"model = paddle.Model(mnist, input, label)\n",
"\n",
"## 1.2 面向实验而进行的模型训练\n",
"## 可以不传递input和label信息\n",
"# model = paddle.Model(mnist)\n",
"\n",
"# 场景2:静态图模式\n",
"## 场景2:静态图模式\n",
"# paddle.enable_static()\n",
"# paddle.set_device('gpu')\n",
"# input = paddle.static.InputSpec([None, 1, 28, 28], dtype='float32')\n",
Expand Down Expand Up @@ -584,6 +592,21 @@
" verbose=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 注:\n",
"\n",
"`fit()`的第一个参数不仅可以传递数据集`paddle.io.Dataset`,还可以传递DataLoader,如果想要实现某个自定义的数据集抽样等逻辑,可以在fit外自定义DataLoader,然后传递给fit函数。\n",
"\n",
"```python\n",
"train_dataloader = paddle.io.DataLoader(train_dataset)\n",
"...\n",
"model.fit(train_dataloader, ...)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
Loading

0 comments on commit fb1b628

Please sign in to comment.