diff --git a/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb b/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb index 2e5ba0a5..84077559 100644 --- a/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb +++ b/paddle2.0_docs/image_classification/mnist_lenet_classification.ipynb @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 35, "metadata": {}, "outputs": [ { @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 36, "metadata": {}, "outputs": [ { @@ -74,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 37, "metadata": {}, "outputs": [ { @@ -117,7 +117,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -127,9 +127,9 @@ " def __init__(self):\n", " super(LeNet, self).__init__()\n", " self.conv1 = paddle.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)\n", - " self.max_pool1 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.max_pool1 = paddle.nn.MaxPool2d(kernel_size=2, stride=2)\n", " self.conv2 = paddle.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)\n", - " self.max_pool2 = paddle.nn.Pool2D(pool_size=2, pool_type='max', pool_stride=2)\n", + " self.max_pool2 = paddle.nn.MaxPool2d(kernel_size=2, stride=2)\n", " self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)\n", " self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)\n", " self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)\n", @@ -155,13 +155,141 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 训练方式一\n", - "通过`Model` 构建实例,快速完成模型训练" + "# 3.训练方式一\n", + "组网后,开始对模型进行训练,先构建`train_loader`,加载训练数据,然后定义`train`函数,设置好损失函数后,按batch加载数据,完成模型的训练。" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + 
# --- Cell: training ("训练方式一") ---------------------------------------
import paddle

# Serve the training set in mini-batches of 64 samples on the CPU.
train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64)


def train(model, epochs=2):
    """Train ``model`` on ``train_loader`` using Adam and cross-entropy loss.

    Args:
        model: the network to optimise; it is called as ``model(x_data)`` and
            must expose ``train()``, ``parameters()`` and ``clear_gradients()``.
        epochs: number of full passes over ``train_loader``. Defaults to 2,
            the value that was previously hard-coded.

    Every 100 batches the mean loss and accuracy of the current batch are
    printed. Note that the printed "acc" is computed with ``k=2`` (top-2
    accuracy per the ``paddle.metric.accuracy`` documentation), not top-1.
    """
    model.train()
    # Adam optimiser over all model parameters.
    optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
    for epoch in range(epochs):
        for batch_id, data in enumerate(train_loader()):
            x_data = data[0]
            y_data = data[1]
            predicts = model(x_data)
            # Per-sample loss, reduced to a scalar below.
            loss = paddle.nn.functional.cross_entropy(predicts, y_data)
            acc = paddle.metric.accuracy(predicts, y_data, k=2)
            avg_loss = paddle.mean(loss)
            avg_acc = paddle.mean(acc)
            avg_loss.backward()
            if batch_id % 100 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, avg_loss.numpy(), avg_acc.numpy()))
            optim.minimize(avg_loss)
            # Reset gradients so they do not accumulate into the next batch.
            model.clear_gradients()


model = LeNet()
train(model)


# --- Cell: validation ("对模型进行验证") ---------------------------------
import paddle

# Serve the held-out test set in mini-batches of 64 samples on the CPU.
test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)


def test(model):
    """Evaluate ``model`` on ``test_loader`` and print loss/accuracy.

    Args:
        model: the trained network; it is switched to evaluation mode via
            ``model.eval()`` before any batch is processed.

    Every 100 batches the mean loss and accuracy of the current batch are
    printed. As in ``train``, "acc" is computed with ``k=2``.
    """
    model.eval()
    # Iterate the *test* loader: the earlier version looped over
    # ``train_loader`` and therefore reported metrics on the training data.
    for batch_id, data in enumerate(test_loader()):
        x_data = data[0]
        y_data = data[1]
        # Forward pass only. No ``backward()`` here: evaluation must not
        # build up gradients on the model parameters.
        predicts = model(x_data)
        loss = paddle.nn.functional.cross_entropy(predicts, y_data)
        acc = paddle.metric.accuracy(predicts, y_data, k=2)
        avg_loss = paddle.mean(loss)
        avg_acc = paddle.mean(acc)
        if batch_id % 100 == 0:
            print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id, avg_loss.numpy(), avg_acc.numpy()))


test(model)
0.5500 - 16ms/step\n", - "step 50/938 - loss: 1.8628 - acc_top1: 0.5097 - acc_top2: 0.6028 - 16ms/step\n", - "step 60/938 - loss: 1.7139 - acc_top1: 0.5456 - acc_top2: 0.6409 - 16ms/step\n", - "step 70/938 - loss: 1.7296 - acc_top1: 0.5795 - acc_top2: 0.6719 - 15ms/step\n", - "step 80/938 - loss: 1.6302 - acc_top1: 0.6053 - acc_top2: 0.6949 - 15ms/step\n", - "step 90/938 - loss: 1.6688 - acc_top1: 0.6290 - acc_top2: 0.7158 - 15ms/step\n", - "step 100/938 - loss: 1.6401 - acc_top1: 0.6491 - acc_top2: 0.7327 - 15ms/step\n", - "step 110/938 - loss: 1.6357 - acc_top1: 0.6636 - acc_top2: 0.7440 - 15ms/step\n", - "step 120/938 - loss: 1.6309 - acc_top1: 0.6767 - acc_top2: 0.7539 - 15ms/step\n", - "step 130/938 - loss: 1.6445 - acc_top1: 0.6894 - acc_top2: 0.7638 - 15ms/step\n", - "step 140/938 - loss: 1.5961 - acc_top1: 0.7002 - acc_top2: 0.7728 - 15ms/step\n", - "step 150/938 - loss: 1.6822 - acc_top1: 0.7086 - acc_top2: 0.7794 - 15ms/step\n", - "step 160/938 - loss: 1.6243 - acc_top1: 0.7176 - acc_top2: 0.7858 - 15ms/step\n", - "step 170/938 - loss: 1.6159 - acc_top1: 0.7254 - acc_top2: 0.7915 - 15ms/step\n", - "step 180/938 - loss: 1.6820 - acc_top1: 0.7312 - acc_top2: 0.7962 - 15ms/step\n", - "step 190/938 - loss: 1.6733 - acc_top1: 0.7363 - acc_top2: 0.7999 - 15ms/step\n", - "step 200/938 - loss: 1.7717 - acc_top1: 0.7413 - acc_top2: 0.8039 - 15ms/step\n", - "step 210/938 - loss: 1.5468 - acc_top1: 0.7458 - acc_top2: 0.8072 - 15ms/step\n", - "step 220/938 - loss: 1.5654 - acc_top1: 0.7506 - acc_top2: 0.8111 - 15ms/step\n", - "step 230/938 - loss: 1.6129 - acc_top1: 0.7547 - acc_top2: 0.8143 - 15ms/step\n", - "step 240/938 - loss: 1.5937 - acc_top1: 0.7592 - acc_top2: 0.8180 - 15ms/step\n", - "step 250/938 - loss: 1.5457 - acc_top1: 0.7631 - acc_top2: 0.8214 - 15ms/step\n", - "step 260/938 - loss: 1.6041 - acc_top1: 0.7673 - acc_top2: 0.8249 - 15ms/step\n", - "step 270/938 - loss: 1.6049 - acc_top1: 0.7700 - acc_top2: 0.8271 - 15ms/step\n", - "step 280/938 - loss: 
1.5989 - acc_top1: 0.7735 - acc_top2: 0.8299 - 15ms/step\n", - "step 290/938 - loss: 1.6950 - acc_top1: 0.7752 - acc_top2: 0.8310 - 15ms/step\n", - "step 300/938 - loss: 1.5888 - acc_top1: 0.7781 - acc_top2: 0.8330 - 15ms/step\n", - "step 310/938 - loss: 1.5983 - acc_top1: 0.7808 - acc_top2: 0.8350 - 15ms/step\n", - "step 320/938 - loss: 1.5133 - acc_top1: 0.7840 - acc_top2: 0.8370 - 15ms/step\n", - "step 330/938 - loss: 1.5587 - acc_top1: 0.7866 - acc_top2: 0.8385 - 15ms/step\n", - "step 340/938 - loss: 1.6093 - acc_top1: 0.7882 - acc_top2: 0.8393 - 15ms/step\n", - "step 350/938 - loss: 1.6259 - acc_top1: 0.7902 - acc_top2: 0.8410 - 15ms/step\n", - "step 360/938 - loss: 1.6194 - acc_top1: 0.7918 - acc_top2: 0.8422 - 15ms/step\n", - "step 370/938 - loss: 1.6531 - acc_top1: 0.7941 - acc_top2: 0.8438 - 15ms/step\n", - "step 380/938 - loss: 1.6986 - acc_top1: 0.7957 - acc_top2: 0.8447 - 15ms/step\n", - "step 390/938 - loss: 1.5932 - acc_top1: 0.7974 - acc_top2: 0.8459 - 15ms/step\n", - "step 400/938 - loss: 1.6512 - acc_top1: 0.7993 - acc_top2: 0.8474 - 15ms/step\n", - "step 410/938 - loss: 1.5698 - acc_top1: 0.8012 - acc_top2: 0.8487 - 15ms/step\n", - "step 420/938 - loss: 1.5889 - acc_top1: 0.8025 - acc_top2: 0.8494 - 15ms/step\n", - "step 430/938 - loss: 1.5518 - acc_top1: 0.8036 - acc_top2: 0.8503 - 15ms/step\n", - "step 440/938 - loss: 1.6057 - acc_top1: 0.8048 - acc_top2: 0.8508 - 15ms/step\n", - "step 450/938 - loss: 1.6081 - acc_top1: 0.8064 - acc_top2: 0.8519 - 15ms/step\n", - "step 460/938 - loss: 1.5742 - acc_top1: 0.8079 - acc_top2: 0.8531 - 15ms/step\n", - "step 470/938 - loss: 1.5704 - acc_top1: 0.8095 - acc_top2: 0.8543 - 15ms/step\n", - "step 480/938 - loss: 1.6083 - acc_top1: 0.8110 - acc_top2: 0.8550 - 15ms/step\n", - "step 490/938 - loss: 1.6081 - acc_top1: 0.8120 - acc_top2: 0.8555 - 15ms/step\n", - "step 500/938 - loss: 1.5156 - acc_top1: 0.8133 - acc_top2: 0.8564 - 15ms/step\n", - "step 510/938 - loss: 1.5856 - acc_top1: 0.8148 - acc_top2: 0.8573 
- 15ms/step\n", - "step 520/938 - loss: 1.5275 - acc_top1: 0.8163 - acc_top2: 0.8582 - 15ms/step\n", - "step 530/938 - loss: 1.5345 - acc_top1: 0.8172 - acc_top2: 0.8591 - 15ms/step\n", - "step 540/938 - loss: 1.5387 - acc_top1: 0.8181 - acc_top2: 0.8596 - 15ms/step\n", - "step 550/938 - loss: 1.5753 - acc_top1: 0.8190 - acc_top2: 0.8601 - 15ms/step\n", - "step 560/938 - loss: 1.6103 - acc_top1: 0.8203 - acc_top2: 0.8610 - 15ms/step\n", - "step 570/938 - loss: 1.5571 - acc_top1: 0.8215 - acc_top2: 0.8618 - 15ms/step\n", - "step 580/938 - loss: 1.5575 - acc_top1: 0.8221 - acc_top2: 0.8622 - 15ms/step\n", - "step 590/938 - loss: 1.4821 - acc_top1: 0.8230 - acc_top2: 0.8627 - 15ms/step\n", - "step 600/938 - loss: 1.5644 - acc_top1: 0.8243 - acc_top2: 0.8636 - 15ms/step\n", - "step 610/938 - loss: 1.5317 - acc_top1: 0.8253 - acc_top2: 0.8644 - 15ms/step\n", - "step 620/938 - loss: 1.5849 - acc_top1: 0.8258 - acc_top2: 0.8647 - 15ms/step\n", - "step 630/938 - loss: 1.6087 - acc_top1: 0.8263 - acc_top2: 0.8649 - 15ms/step\n", - "step 640/938 - loss: 1.5617 - acc_top1: 0.8272 - acc_top2: 0.8655 - 15ms/step\n", - "step 650/938 - loss: 1.6376 - acc_top1: 0.8279 - acc_top2: 0.8660 - 15ms/step\n", - "step 660/938 - loss: 1.5428 - acc_top1: 0.8287 - acc_top2: 0.8665 - 15ms/step\n", - "step 670/938 - loss: 1.5797 - acc_top1: 0.8293 - acc_top2: 0.8668 - 15ms/step\n", - "step 680/938 - loss: 1.5210 - acc_top1: 0.8300 - acc_top2: 0.8674 - 15ms/step\n", - "step 690/938 - loss: 1.6159 - acc_top1: 0.8305 - acc_top2: 0.8677 - 15ms/step\n", - "step 700/938 - loss: 1.5592 - acc_top1: 0.8313 - acc_top2: 0.8682 - 15ms/step\n", - "step 710/938 - loss: 1.6400 - acc_top1: 0.8318 - acc_top2: 0.8685 - 15ms/step\n", - "step 720/938 - loss: 1.5638 - acc_top1: 0.8327 - acc_top2: 0.8691 - 15ms/step\n", - "step 730/938 - loss: 1.5691 - acc_top1: 0.8333 - acc_top2: 0.8693 - 15ms/step\n", - "step 740/938 - loss: 1.5848 - acc_top1: 0.8337 - acc_top2: 0.8695 - 15ms/step\n", - "step 750/938 - loss: 
1.6317 - acc_top1: 0.8344 - acc_top2: 0.8698 - 15ms/step\n", - "step 760/938 - loss: 1.5127 - acc_top1: 0.8352 - acc_top2: 0.8703 - 15ms/step\n", - "step 770/938 - loss: 1.5822 - acc_top1: 0.8359 - acc_top2: 0.8707 - 15ms/step\n", - "step 780/938 - loss: 1.6010 - acc_top1: 0.8366 - acc_top2: 0.8712 - 15ms/step\n", - "step 790/938 - loss: 1.5238 - acc_top1: 0.8373 - acc_top2: 0.8717 - 15ms/step\n", - "step 800/938 - loss: 1.5858 - acc_top1: 0.8377 - acc_top2: 0.8719 - 15ms/step\n", - "step 810/938 - loss: 1.5800 - acc_top1: 0.8384 - acc_top2: 0.8724 - 15ms/step\n", - "step 820/938 - loss: 1.6312 - acc_top1: 0.8390 - acc_top2: 0.8727 - 15ms/step\n", - "step 830/938 - loss: 1.5812 - acc_top1: 0.8398 - acc_top2: 0.8732 - 15ms/step\n", - "step 840/938 - loss: 1.5661 - acc_top1: 0.8402 - acc_top2: 0.8734 - 15ms/step\n", - "step 850/938 - loss: 1.5379 - acc_top1: 0.8409 - acc_top2: 0.8739 - 15ms/step\n", - "step 860/938 - loss: 1.5266 - acc_top1: 0.8413 - acc_top2: 0.8740 - 15ms/step\n", - "step 870/938 - loss: 1.5264 - acc_top1: 0.8420 - acc_top2: 0.8745 - 15ms/step\n", - "step 880/938 - loss: 1.5688 - acc_top1: 0.8425 - acc_top2: 0.8748 - 15ms/step\n", - "step 890/938 - loss: 1.5707 - acc_top1: 0.8429 - acc_top2: 0.8751 - 15ms/step\n", - "step 900/938 - loss: 1.5564 - acc_top1: 0.8432 - acc_top2: 0.8752 - 15ms/step\n", - "step 910/938 - loss: 1.4924 - acc_top1: 0.8438 - acc_top2: 0.8757 - 15ms/step\n", - "step 920/938 - loss: 1.5514 - acc_top1: 0.8443 - acc_top2: 0.8760 - 15ms/step\n", - "step 930/938 - loss: 1.5850 - acc_top1: 0.8446 - acc_top2: 0.8762 - 15ms/step\n", - "step 938/938 - loss: 1.4915 - acc_top1: 0.8448 - acc_top2: 0.8764 - 15ms/step\n", + "step 10/938 - loss: 2.2434 - acc_top1: 0.1344 - acc_top2: 0.3719 - 14ms/step\n", + "step 20/938 - loss: 2.0292 - acc_top1: 0.2836 - acc_top2: 0.4633 - 14ms/step\n", + "step 30/938 - loss: 1.9341 - acc_top1: 0.3755 - acc_top2: 0.5214 - 14ms/step\n", + "step 40/938 - loss: 1.8009 - acc_top1: 0.4469 - acc_top2: 0.5727 - 
14ms/step\n", + "step 50/938 - loss: 1.8000 - acc_top1: 0.4975 - acc_top2: 0.6125 - 13ms/step\n", + "step 60/938 - loss: 1.6335 - acc_top1: 0.5417 - acc_top2: 0.6438 - 14ms/step\n", + "step 70/938 - loss: 1.7931 - acc_top1: 0.5708 - acc_top2: 0.6643 - 13ms/step\n", + "step 80/938 - loss: 1.6699 - acc_top1: 0.5961 - acc_top2: 0.6846 - 13ms/step\n", + "step 90/938 - loss: 1.6832 - acc_top1: 0.6189 - acc_top2: 0.7069 - 13ms/step\n", + "step 100/938 - loss: 1.6336 - acc_top1: 0.6409 - acc_top2: 0.7245 - 14ms/step\n", + "step 110/938 - loss: 1.6598 - acc_top1: 0.6557 - acc_top2: 0.7376 - 13ms/step\n", + "step 120/938 - loss: 1.6348 - acc_top1: 0.6708 - acc_top2: 0.7488 - 13ms/step\n", + "step 130/938 - loss: 1.6223 - acc_top1: 0.6851 - acc_top2: 0.7601 - 13ms/step\n", + "step 140/938 - loss: 1.5622 - acc_top1: 0.6970 - acc_top2: 0.7694 - 13ms/step\n", + "step 150/938 - loss: 1.6455 - acc_top1: 0.7065 - acc_top2: 0.7767 - 14ms/step\n", + "step 160/938 - loss: 1.6127 - acc_top1: 0.7154 - acc_top2: 0.7837 - 14ms/step\n", + "step 170/938 - loss: 1.5963 - acc_top1: 0.7242 - acc_top2: 0.7898 - 14ms/step\n", + "step 180/938 - loss: 1.6485 - acc_top1: 0.7310 - acc_top2: 0.7948 - 14ms/step\n", + "step 190/938 - loss: 1.6666 - acc_top1: 0.7368 - acc_top2: 0.7992 - 14ms/step\n", + "step 200/938 - loss: 1.7862 - acc_top1: 0.7419 - acc_top2: 0.8030 - 14ms/step\n", + "step 210/938 - loss: 1.5479 - acc_top1: 0.7464 - acc_top2: 0.8064 - 14ms/step\n", + "step 220/938 - loss: 1.5650 - acc_top1: 0.7515 - acc_top2: 0.8106 - 14ms/step\n", + "step 230/938 - loss: 1.5822 - acc_top1: 0.7562 - acc_top2: 0.8141 - 14ms/step\n", + "step 240/938 - loss: 1.5966 - acc_top1: 0.7608 - acc_top2: 0.8179 - 14ms/step\n", + "step 250/938 - loss: 1.5551 - acc_top1: 0.7650 - acc_top2: 0.8213 - 14ms/step\n", + "step 260/938 - loss: 1.5584 - acc_top1: 0.7699 - acc_top2: 0.8249 - 14ms/step\n", + "step 270/938 - loss: 1.5933 - acc_top1: 0.7730 - acc_top2: 0.8273 - 14ms/step\n", + "step 280/938 - loss: 1.5589 - 
acc_top1: 0.7769 - acc_top2: 0.8301 - 14ms/step\n", + "step 290/938 - loss: 1.6513 - acc_top1: 0.7793 - acc_top2: 0.8315 - 14ms/step\n", + "step 300/938 - loss: 1.5929 - acc_top1: 0.7821 - acc_top2: 0.8337 - 14ms/step\n", + "step 310/938 - loss: 1.5672 - acc_top1: 0.7849 - acc_top2: 0.8360 - 14ms/step\n", + "step 320/938 - loss: 1.5147 - acc_top1: 0.7879 - acc_top2: 0.8381 - 14ms/step\n", + "step 330/938 - loss: 1.5697 - acc_top1: 0.7902 - acc_top2: 0.8397 - 14ms/step\n", + "step 340/938 - loss: 1.5697 - acc_top1: 0.7919 - acc_top2: 0.8406 - 14ms/step\n", + "step 350/938 - loss: 1.6122 - acc_top1: 0.7941 - acc_top2: 0.8423 - 14ms/step\n", + "step 360/938 - loss: 1.5934 - acc_top1: 0.7960 - acc_top2: 0.8435 - 14ms/step\n", + "step 370/938 - loss: 1.6258 - acc_top1: 0.7982 - acc_top2: 0.8451 - 14ms/step\n", + "step 380/938 - loss: 1.6805 - acc_top1: 0.7996 - acc_top2: 0.8463 - 14ms/step\n", + "step 390/938 - loss: 1.5997 - acc_top1: 0.8011 - acc_top2: 0.8475 - 14ms/step\n", + "step 400/938 - loss: 1.6151 - acc_top1: 0.8029 - acc_top2: 0.8488 - 14ms/step\n", + "step 410/938 - loss: 1.5800 - acc_top1: 0.8047 - acc_top2: 0.8499 - 14ms/step\n", + "step 420/938 - loss: 1.5950 - acc_top1: 0.8060 - acc_top2: 0.8508 - 14ms/step\n", + "step 430/938 - loss: 1.5533 - acc_top1: 0.8075 - acc_top2: 0.8517 - 14ms/step\n", + "step 440/938 - loss: 1.6171 - acc_top1: 0.8086 - acc_top2: 0.8521 - 14ms/step\n", + "step 450/938 - loss: 1.5756 - acc_top1: 0.8103 - acc_top2: 0.8533 - 14ms/step\n", + "step 460/938 - loss: 1.5655 - acc_top1: 0.8121 - acc_top2: 0.8544 - 14ms/step\n", + "step 470/938 - loss: 1.5816 - acc_top1: 0.8139 - acc_top2: 0.8555 - 14ms/step\n", + "step 480/938 - loss: 1.6202 - acc_top1: 0.8148 - acc_top2: 0.8562 - 14ms/step\n", + "step 490/938 - loss: 1.6223 - acc_top1: 0.8157 - acc_top2: 0.8567 - 14ms/step\n", + "step 500/938 - loss: 1.5198 - acc_top1: 0.8167 - acc_top2: 0.8574 - 14ms/step\n", + "step 510/938 - loss: 1.5853 - acc_top1: 0.8181 - acc_top2: 0.8583 - 
14ms/step\n", + "step 520/938 - loss: 1.5252 - acc_top1: 0.8196 - acc_top2: 0.8593 - 14ms/step\n", + "step 530/938 - loss: 1.5265 - acc_top1: 0.8207 - acc_top2: 0.8601 - 14ms/step\n", + "step 540/938 - loss: 1.5297 - acc_top1: 0.8217 - acc_top2: 0.8608 - 14ms/step\n", + "step 550/938 - loss: 1.5743 - acc_top1: 0.8226 - acc_top2: 0.8613 - 13ms/step\n", + "step 560/938 - loss: 1.6419 - acc_top1: 0.8237 - acc_top2: 0.8622 - 13ms/step\n", + "step 570/938 - loss: 1.5556 - acc_top1: 0.8247 - acc_top2: 0.8630 - 13ms/step\n", + "step 580/938 - loss: 1.5349 - acc_top1: 0.8254 - acc_top2: 0.8635 - 13ms/step\n", + "step 590/938 - loss: 1.4915 - acc_top1: 0.8263 - acc_top2: 0.8640 - 13ms/step\n", + "step 600/938 - loss: 1.5672 - acc_top1: 0.8277 - acc_top2: 0.8651 - 13ms/step\n", + "step 610/938 - loss: 1.5464 - acc_top1: 0.8288 - acc_top2: 0.8659 - 13ms/step\n", + "step 620/938 - loss: 1.6329 - acc_top1: 0.8292 - acc_top2: 0.8661 - 13ms/step\n", + "step 630/938 - loss: 1.6121 - acc_top1: 0.8296 - acc_top2: 0.8662 - 13ms/step\n", + "step 640/938 - loss: 1.5636 - acc_top1: 0.8305 - acc_top2: 0.8668 - 13ms/step\n", + "step 650/938 - loss: 1.6227 - acc_top1: 0.8311 - acc_top2: 0.8672 - 13ms/step\n", + "step 660/938 - loss: 1.5646 - acc_top1: 0.8319 - acc_top2: 0.8678 - 13ms/step\n", + "step 670/938 - loss: 1.5620 - acc_top1: 0.8325 - acc_top2: 0.8681 - 13ms/step\n", + "step 680/938 - loss: 1.4908 - acc_top1: 0.8333 - acc_top2: 0.8688 - 13ms/step\n", + "step 690/938 - loss: 1.6010 - acc_top1: 0.8339 - acc_top2: 0.8691 - 13ms/step\n", + "step 700/938 - loss: 1.5592 - acc_top1: 0.8346 - acc_top2: 0.8695 - 13ms/step\n", + "step 710/938 - loss: 1.6226 - acc_top1: 0.8352 - acc_top2: 0.8699 - 13ms/step\n", + "step 720/938 - loss: 1.5642 - acc_top1: 0.8362 - acc_top2: 0.8705 - 13ms/step\n", + "step 730/938 - loss: 1.5807 - acc_top1: 0.8367 - acc_top2: 0.8707 - 13ms/step\n", + "step 740/938 - loss: 1.5721 - acc_top1: 0.8371 - acc_top2: 0.8708 - 13ms/step\n", + "step 750/938 - loss: 1.6542 
- acc_top1: 0.8377 - acc_top2: 0.8711 - 13ms/step\n", + "step 760/938 - loss: 1.5128 - acc_top1: 0.8385 - acc_top2: 0.8716 - 13ms/step\n", + "step 770/938 - loss: 1.5711 - acc_top1: 0.8391 - acc_top2: 0.8721 - 14ms/step\n", + "step 780/938 - loss: 1.6095 - acc_top1: 0.8395 - acc_top2: 0.8725 - 14ms/step\n", + "step 790/938 - loss: 1.5348 - acc_top1: 0.8402 - acc_top2: 0.8730 - 14ms/step\n", + "step 800/938 - loss: 1.5715 - acc_top1: 0.8407 - acc_top2: 0.8732 - 14ms/step\n", + "step 810/938 - loss: 1.5880 - acc_top1: 0.8413 - acc_top2: 0.8737 - 14ms/step\n", + "step 820/938 - loss: 1.6160 - acc_top1: 0.8418 - acc_top2: 0.8740 - 14ms/step\n", + "step 830/938 - loss: 1.5585 - acc_top1: 0.8426 - acc_top2: 0.8746 - 14ms/step\n", + "step 840/938 - loss: 1.5829 - acc_top1: 0.8429 - acc_top2: 0.8748 - 14ms/step\n", + "step 850/938 - loss: 1.5348 - acc_top1: 0.8435 - acc_top2: 0.8753 - 14ms/step\n", + "step 860/938 - loss: 1.5448 - acc_top1: 0.8438 - acc_top2: 0.8754 - 14ms/step\n", + "step 870/938 - loss: 1.5463 - acc_top1: 0.8443 - acc_top2: 0.8759 - 14ms/step\n", + "step 880/938 - loss: 1.5763 - acc_top1: 0.8449 - acc_top2: 0.8762 - 14ms/step\n", + "step 890/938 - loss: 1.5699 - acc_top1: 0.8453 - acc_top2: 0.8764 - 14ms/step\n", + "step 900/938 - loss: 1.5616 - acc_top1: 0.8456 - acc_top2: 0.8766 - 14ms/step\n", + "step 910/938 - loss: 1.5026 - acc_top1: 0.8461 - acc_top2: 0.8771 - 14ms/step\n", + "step 920/938 - loss: 1.5380 - acc_top1: 0.8467 - acc_top2: 0.8774 - 14ms/step\n", + "step 930/938 - loss: 1.5993 - acc_top1: 0.8470 - acc_top2: 0.8777 - 14ms/step\n", + "step 938/938 - loss: 1.4942 - acc_top1: 0.8473 - acc_top2: 0.8778 - 14ms/step\n", "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/0\n", - "Eval begin...\n", - "step 10/157 - loss: 1.5984 - acc_top1: 0.8797 - acc_top2: 0.8953 - 5ms/step\n", - "step 20/157 - loss: 1.6266 - acc_top1: 0.8789 - acc_top2: 0.9000 - 5ms/step\n", - "step 30/157 - loss: 1.6475 - 
acc_top1: 0.8771 - acc_top2: 0.8984 - 5ms/step\n", - "step 40/157 - loss: 1.6329 - acc_top1: 0.8730 - acc_top2: 0.8957 - 5ms/step\n", - "step 50/157 - loss: 1.5399 - acc_top1: 0.8712 - acc_top2: 0.8934 - 5ms/step\n", - "step 60/157 - loss: 1.6322 - acc_top1: 0.8750 - acc_top2: 0.8961 - 5ms/step\n", - "step 70/157 - loss: 1.5818 - acc_top1: 0.8721 - acc_top2: 0.8931 - 5ms/step\n", - "step 80/157 - loss: 1.5522 - acc_top1: 0.8760 - acc_top2: 0.8979 - 5ms/step\n", - "step 90/157 - loss: 1.6085 - acc_top1: 0.8785 - acc_top2: 0.8984 - 5ms/step\n", - "step 100/157 - loss: 1.5661 - acc_top1: 0.8784 - acc_top2: 0.8980 - 5ms/step\n", - "step 110/157 - loss: 1.5694 - acc_top1: 0.8805 - acc_top2: 0.8996 - 5ms/step\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "step 120/157 - loss: 1.6012 - acc_top1: 0.8824 - acc_top2: 0.9003 - 5ms/step\n", - "step 130/157 - loss: 1.5378 - acc_top1: 0.8844 - acc_top2: 0.9017 - 5ms/step\n", - "step 140/157 - loss: 1.5068 - acc_top1: 0.8858 - acc_top2: 0.9022 - 5ms/step\n", - "step 150/157 - loss: 1.5424 - acc_top1: 0.8873 - acc_top2: 0.9029 - 5ms/step\n", - "step 157/157 - loss: 1.5862 - acc_top1: 0.8872 - acc_top2: 0.9035 - 5ms/step\n", - "Eval samples: 10000\n", "Epoch 2/2\n", - "step 10/938 - loss: 1.5988 - acc_top1: 0.8859 - acc_top2: 0.9016 - 15ms/step\n", - "step 20/938 - loss: 1.5702 - acc_top1: 0.8852 - acc_top2: 0.9047 - 15ms/step\n", - "step 30/938 - loss: 1.5999 - acc_top1: 0.8833 - acc_top2: 0.9021 - 15ms/step\n", - "step 40/938 - loss: 1.5652 - acc_top1: 0.8816 - acc_top2: 0.9000 - 15ms/step\n", - "step 50/938 - loss: 1.6163 - acc_top1: 0.8853 - acc_top2: 0.9047 - 15ms/step\n", - "step 60/938 - loss: 1.5307 - acc_top1: 0.8849 - acc_top2: 0.9049 - 15ms/step\n", - "step 70/938 - loss: 1.5542 - acc_top1: 0.8846 - acc_top2: 0.9029 - 15ms/step\n", - "step 80/938 - loss: 1.5694 - acc_top1: 0.8816 - acc_top2: 0.9008 - 15ms/step\n", - "step 90/938 - loss: 1.6030 - acc_top1: 0.8806 - acc_top2: 0.8991 - 
15ms/step\n", - "step 100/938 - loss: 1.5631 - acc_top1: 0.8814 - acc_top2: 0.8989 - 15ms/step\n", - "step 110/938 - loss: 1.5598 - acc_top1: 0.8804 - acc_top2: 0.8984 - 15ms/step\n", - "step 120/938 - loss: 1.5773 - acc_top1: 0.8803 - acc_top2: 0.8986 - 15ms/step\n", - "step 130/938 - loss: 1.5076 - acc_top1: 0.8815 - acc_top2: 0.8995 - 15ms/step\n", - "step 140/938 - loss: 1.6064 - acc_top1: 0.8809 - acc_top2: 0.8988 - 15ms/step\n", - "step 150/938 - loss: 1.5279 - acc_top1: 0.8815 - acc_top2: 0.8993 - 15ms/step\n", - "step 160/938 - loss: 1.6039 - acc_top1: 0.8820 - acc_top2: 0.8998 - 15ms/step\n", - "step 170/938 - loss: 1.5709 - acc_top1: 0.8814 - acc_top2: 0.8993 - 15ms/step\n", - "step 180/938 - loss: 1.6164 - acc_top1: 0.8806 - acc_top2: 0.8985 - 15ms/step\n", - "step 190/938 - loss: 1.5920 - acc_top1: 0.8802 - acc_top2: 0.8985 - 15ms/step\n", - "step 200/938 - loss: 1.6457 - acc_top1: 0.8793 - acc_top2: 0.8973 - 15ms/step\n", - "step 210/938 - loss: 1.6045 - acc_top1: 0.8794 - acc_top2: 0.8977 - 15ms/step\n", - "step 220/938 - loss: 1.6614 - acc_top1: 0.8795 - acc_top2: 0.8975 - 15ms/step\n", - "step 230/938 - loss: 1.5384 - acc_top1: 0.8789 - acc_top2: 0.8966 - 15ms/step\n", - "step 240/938 - loss: 1.5556 - acc_top1: 0.8785 - acc_top2: 0.8960 - 15ms/step\n", - "step 250/938 - loss: 1.6006 - acc_top1: 0.8782 - acc_top2: 0.8961 - 15ms/step\n", - "step 260/938 - loss: 1.5552 - acc_top1: 0.8790 - acc_top2: 0.8968 - 15ms/step\n", - "step 270/938 - loss: 1.5805 - acc_top1: 0.8791 - acc_top2: 0.8970 - 15ms/step\n", - "step 280/938 - loss: 1.5404 - acc_top1: 0.8787 - acc_top2: 0.8966 - 15ms/step\n", - "step 290/938 - loss: 1.6023 - acc_top1: 0.8789 - acc_top2: 0.8969 - 15ms/step\n", - "step 300/938 - loss: 1.5706 - acc_top1: 0.8788 - acc_top2: 0.8969 - 15ms/step\n", - "step 310/938 - loss: 1.5424 - acc_top1: 0.8790 - acc_top2: 0.8968 - 15ms/step\n", - "step 320/938 - loss: 1.5823 - acc_top1: 0.8798 - acc_top2: 0.8975 - 15ms/step\n", - "step 330/938 - loss: 1.5600 
- acc_top1: 0.8801 - acc_top2: 0.8977 - 15ms/step\n", - "step 340/938 - loss: 1.6258 - acc_top1: 0.8795 - acc_top2: 0.8970 - 15ms/step\n", - "step 350/938 - loss: 1.5093 - acc_top1: 0.8796 - acc_top2: 0.8972 - 15ms/step\n", - "step 360/938 - loss: 1.6030 - acc_top1: 0.8794 - acc_top2: 0.8967 - 15ms/step\n", - "step 370/938 - loss: 1.5732 - acc_top1: 0.8795 - acc_top2: 0.8969 - 15ms/step\n", - "step 380/938 - loss: 1.5980 - acc_top1: 0.8797 - acc_top2: 0.8972 - 15ms/step\n", - "step 390/938 - loss: 1.5902 - acc_top1: 0.8800 - acc_top2: 0.8974 - 15ms/step\n", - "step 400/938 - loss: 1.5395 - acc_top1: 0.8809 - acc_top2: 0.8983 - 15ms/step\n", - "step 410/938 - loss: 1.6623 - acc_top1: 0.8804 - acc_top2: 0.8978 - 15ms/step\n", - "step 420/938 - loss: 1.4987 - acc_top1: 0.8810 - acc_top2: 0.8983 - 15ms/step\n", - "step 430/938 - loss: 1.5989 - acc_top1: 0.8811 - acc_top2: 0.8983 - 15ms/step\n", - "step 440/938 - loss: 1.5722 - acc_top1: 0.8813 - acc_top2: 0.8984 - 15ms/step\n", - "step 450/938 - loss: 1.5549 - acc_top1: 0.8818 - acc_top2: 0.8986 - 15ms/step\n", - "step 460/938 - loss: 1.5536 - acc_top1: 0.8819 - acc_top2: 0.8986 - 15ms/step\n", - "step 470/938 - loss: 1.5247 - acc_top1: 0.8826 - acc_top2: 0.8992 - 15ms/step\n", - "step 480/938 - loss: 1.5520 - acc_top1: 0.8830 - acc_top2: 0.8995 - 15ms/step\n", - "step 490/938 - loss: 1.5518 - acc_top1: 0.8835 - acc_top2: 0.8998 - 15ms/step\n", - "step 500/938 - loss: 1.5227 - acc_top1: 0.8837 - acc_top2: 0.9000 - 15ms/step\n", - "step 510/938 - loss: 1.6014 - acc_top1: 0.8835 - acc_top2: 0.8998 - 15ms/step\n", - "step 520/938 - loss: 1.5526 - acc_top1: 0.8834 - acc_top2: 0.8998 - 15ms/step\n", - "step 530/938 - loss: 1.5849 - acc_top1: 0.8838 - acc_top2: 0.9001 - 15ms/step\n", - "step 540/938 - loss: 1.5607 - acc_top1: 0.8840 - acc_top2: 0.9006 - 15ms/step\n", - "step 550/938 - loss: 1.6438 - acc_top1: 0.8843 - acc_top2: 0.9010 - 15ms/step\n", - "step 560/938 - loss: 1.5229 - acc_top1: 0.8848 - acc_top2: 0.9014 - 
15ms/step\n", - "step 570/938 - loss: 1.5395 - acc_top1: 0.8846 - acc_top2: 0.9012 - 15ms/step\n", - "step 580/938 - loss: 1.5409 - acc_top1: 0.8848 - acc_top2: 0.9013 - 15ms/step\n", - "step 590/938 - loss: 1.5851 - acc_top1: 0.8848 - acc_top2: 0.9013 - 15ms/step\n", - "step 600/938 - loss: 1.5383 - acc_top1: 0.8849 - acc_top2: 0.9013 - 15ms/step\n", - "step 610/938 - loss: 1.5969 - acc_top1: 0.8853 - acc_top2: 0.9016 - 15ms/step\n", - "step 620/938 - loss: 1.5634 - acc_top1: 0.8854 - acc_top2: 0.9017 - 15ms/step\n", - "step 630/938 - loss: 1.6308 - acc_top1: 0.8857 - acc_top2: 0.9019 - 15ms/step\n", - "step 640/938 - loss: 1.6413 - acc_top1: 0.8859 - acc_top2: 0.9021 - 15ms/step\n", - "step 650/938 - loss: 1.5954 - acc_top1: 0.8856 - acc_top2: 0.9020 - 15ms/step\n", - "step 660/938 - loss: 1.5278 - acc_top1: 0.8859 - acc_top2: 0.9023 - 15ms/step\n", - "step 670/938 - loss: 1.5144 - acc_top1: 0.8869 - acc_top2: 0.9035 - 15ms/step\n", - "step 680/938 - loss: 1.4612 - acc_top1: 0.8879 - acc_top2: 0.9048 - 15ms/step\n", - "step 690/938 - loss: 1.4820 - acc_top1: 0.8891 - acc_top2: 0.9060 - 15ms/step\n", - "step 700/938 - loss: 1.4766 - acc_top1: 0.8901 - acc_top2: 0.9073 - 15ms/step\n", - "step 710/938 - loss: 1.5245 - acc_top1: 0.8911 - acc_top2: 0.9083 - 15ms/step\n", - "step 720/938 - loss: 1.5183 - acc_top1: 0.8922 - acc_top2: 0.9095 - 15ms/step\n", - "step 730/938 - loss: 1.4971 - acc_top1: 0.8932 - acc_top2: 0.9106 - 15ms/step\n", - "step 740/938 - loss: 1.4744 - acc_top1: 0.8944 - acc_top2: 0.9117 - 15ms/step\n", - "step 750/938 - loss: 1.4789 - acc_top1: 0.8952 - acc_top2: 0.9127 - 15ms/step\n", - "step 760/938 - loss: 1.5114 - acc_top1: 0.8959 - acc_top2: 0.9137 - 15ms/step\n", - "step 770/938 - loss: 1.5035 - acc_top1: 0.8970 - acc_top2: 0.9147 - 15ms/step\n", - "step 780/938 - loss: 1.4668 - acc_top1: 0.8978 - acc_top2: 0.9157 - 15ms/step\n", - "step 790/938 - loss: 1.4850 - acc_top1: 0.8986 - acc_top2: 0.9166 - 15ms/step\n", - "step 800/938 - loss: 1.4777 
- acc_top1: 0.8996 - acc_top2: 0.9176 - 15ms/step\n", - "step 810/938 - loss: 1.4783 - acc_top1: 0.9005 - acc_top2: 0.9186 - 15ms/step\n", - "step 820/938 - loss: 1.5256 - acc_top1: 0.9011 - acc_top2: 0.9194 - 15ms/step\n", - "step 830/938 - loss: 1.4801 - acc_top1: 0.9019 - acc_top2: 0.9202 - 15ms/step\n", - "step 840/938 - loss: 1.4873 - acc_top1: 0.9026 - acc_top2: 0.9211 - 15ms/step\n", - "step 850/938 - loss: 1.5093 - acc_top1: 0.9034 - acc_top2: 0.9219 - 15ms/step\n", - "step 860/938 - loss: 1.4727 - acc_top1: 0.9042 - acc_top2: 0.9227 - 15ms/step\n", - "step 870/938 - loss: 1.4917 - acc_top1: 0.9050 - acc_top2: 0.9235 - 15ms/step\n", - "step 880/938 - loss: 1.4792 - acc_top1: 0.9058 - acc_top2: 0.9243 - 15ms/step\n", - "step 890/938 - loss: 1.4854 - acc_top1: 0.9066 - acc_top2: 0.9251 - 15ms/step\n", - "step 900/938 - loss: 1.4616 - acc_top1: 0.9074 - acc_top2: 0.9258 - 15ms/step\n", - "step 910/938 - loss: 1.4954 - acc_top1: 0.9081 - acc_top2: 0.9265 - 15ms/step\n", - "step 920/938 - loss: 1.4875 - acc_top1: 0.9087 - acc_top2: 0.9272 - 15ms/step\n", - "step 930/938 - loss: 1.5037 - acc_top1: 0.9094 - acc_top2: 0.9279 - 15ms/step\n", - "step 938/938 - loss: 1.4964 - acc_top1: 0.9099 - acc_top2: 0.9284 - 15ms/step\n", - "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/1\n", - "Eval begin...\n", - "step 10/157 - loss: 1.5196 - acc_top1: 0.9719 - acc_top2: 0.9969 - 5ms/step\n", - "step 20/157 - loss: 1.5393 - acc_top1: 0.9672 - acc_top2: 0.9945 - 6ms/step\n", - "step 30/157 - loss: 1.4928 - acc_top1: 0.9630 - acc_top2: 0.9906 - 5ms/step\n", - "step 40/157 - loss: 1.4765 - acc_top1: 0.9617 - acc_top2: 0.9902 - 5ms/step\n", - "step 50/157 - loss: 1.4646 - acc_top1: 0.9631 - acc_top2: 0.9903 - 5ms/step\n" + "step 10/938 - loss: 1.5919 - acc_top1: 0.8875 - acc_top2: 0.9047 - 14ms/step\n", + "step 20/938 - loss: 1.5900 - acc_top1: 0.8875 - acc_top2: 0.9062 - 14ms/step\n", + "step 30/938 - loss: 1.5929 - 
acc_top1: 0.8891 - acc_top2: 0.9036 - 13ms/step\n", + "step 40/938 - loss: 1.5855 - acc_top1: 0.8883 - acc_top2: 0.9027 - 13ms/step\n", + "step 50/938 - loss: 1.6197 - acc_top1: 0.8916 - acc_top2: 0.9072 - 13ms/step\n", + "step 60/938 - loss: 1.5084 - acc_top1: 0.8914 - acc_top2: 0.9078 - 13ms/step\n", + "step 70/938 - loss: 1.5552 - acc_top1: 0.8904 - acc_top2: 0.9067 - 13ms/step\n", + "step 80/938 - loss: 1.5700 - acc_top1: 0.8887 - acc_top2: 0.9049 - 13ms/step\n", + "step 90/938 - loss: 1.6073 - acc_top1: 0.8866 - acc_top2: 0.9030 - 13ms/step\n", + "step 100/938 - loss: 1.5754 - acc_top1: 0.8859 - acc_top2: 0.9022 - 13ms/step\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "step 60/157 - loss: 1.5646 - acc_top1: 0.9641 - acc_top2: 0.9906 - 5ms/step\n", - "step 70/157 - loss: 1.5167 - acc_top1: 0.9618 - acc_top2: 0.9900 - 5ms/step\n", - "step 80/157 - loss: 1.4728 - acc_top1: 0.9635 - acc_top2: 0.9906 - 5ms/step\n", - "step 90/157 - loss: 1.5030 - acc_top1: 0.9668 - acc_top2: 0.9917 - 5ms/step\n", - "step 100/157 - loss: 1.4612 - acc_top1: 0.9677 - acc_top2: 0.9914 - 5ms/step\n", - "step 110/157 - loss: 1.4612 - acc_top1: 0.9689 - acc_top2: 0.9913 - 5ms/step\n", - "step 120/157 - loss: 1.4612 - acc_top1: 0.9707 - acc_top2: 0.9919 - 5ms/step\n", - "step 130/157 - loss: 1.4621 - acc_top1: 0.9719 - acc_top2: 0.9923 - 5ms/step\n", - "step 140/157 - loss: 1.4612 - acc_top1: 0.9734 - acc_top2: 0.9929 - 5ms/step\n", - "step 150/157 - loss: 1.4660 - acc_top1: 0.9748 - acc_top2: 0.9933 - 5ms/step\n", - "step 157/157 - loss: 1.5215 - acc_top1: 0.9731 - acc_top2: 0.9930 - 5ms/step\n", - "Eval samples: 10000\n", + "step 110/938 - loss: 1.5484 - acc_top1: 0.8848 - acc_top2: 0.9017 - 14ms/step\n", + "step 120/938 - loss: 1.5904 - acc_top1: 0.8840 - acc_top2: 0.9020 - 14ms/step\n", + "step 130/938 - loss: 1.5108 - acc_top1: 0.8852 - acc_top2: 0.9025 - 14ms/step\n", + "step 140/938 - loss: 1.6199 - acc_top1: 0.8840 - acc_top2: 0.9016 - 14ms/step\n", + "step 
150/938 - loss: 1.5337 - acc_top1: 0.8842 - acc_top2: 0.9019 - 13ms/step\n", + "step 160/938 - loss: 1.6094 - acc_top1: 0.8846 - acc_top2: 0.9023 - 13ms/step\n", + "step 170/938 - loss: 1.5653 - acc_top1: 0.8843 - acc_top2: 0.9019 - 13ms/step\n", + "step 180/938 - loss: 1.5978 - acc_top1: 0.8835 - acc_top2: 0.9011 - 13ms/step\n", + "step 190/938 - loss: 1.5950 - acc_top1: 0.8833 - acc_top2: 0.9012 - 13ms/step\n", + "step 200/938 - loss: 1.6422 - acc_top1: 0.8828 - acc_top2: 0.9002 - 13ms/step\n", + "step 210/938 - loss: 1.5752 - acc_top1: 0.8831 - acc_top2: 0.9004 - 13ms/step\n", + "step 220/938 - loss: 1.6635 - acc_top1: 0.8832 - acc_top2: 0.9001 - 13ms/step\n", + "step 230/938 - loss: 1.5726 - acc_top1: 0.8823 - acc_top2: 0.8991 - 13ms/step\n", + "step 240/938 - loss: 1.5702 - acc_top1: 0.8814 - acc_top2: 0.8981 - 13ms/step\n", + "step 250/938 - loss: 1.5748 - acc_top1: 0.8814 - acc_top2: 0.8981 - 14ms/step\n", + "step 260/938 - loss: 1.5589 - acc_top1: 0.8822 - acc_top2: 0.8988 - 14ms/step\n", + "step 270/938 - loss: 1.5902 - acc_top1: 0.8823 - acc_top2: 0.8988 - 14ms/step\n", + "step 280/938 - loss: 1.5646 - acc_top1: 0.8817 - acc_top2: 0.8982 - 14ms/step\n", + "step 290/938 - loss: 1.6280 - acc_top1: 0.8819 - acc_top2: 0.8985 - 14ms/step\n", + "step 300/938 - loss: 1.5697 - acc_top1: 0.8815 - acc_top2: 0.8982 - 14ms/step\n", + "step 310/938 - loss: 1.5540 - acc_top1: 0.8814 - acc_top2: 0.8981 - 14ms/step\n", + "step 320/938 - loss: 1.5598 - acc_top1: 0.8821 - acc_top2: 0.8988 - 14ms/step\n", + "step 330/938 - loss: 1.5498 - acc_top1: 0.8824 - acc_top2: 0.8991 - 14ms/step\n", + "step 340/938 - loss: 1.6276 - acc_top1: 0.8818 - acc_top2: 0.8984 - 14ms/step\n", + "step 350/938 - loss: 1.5129 - acc_top1: 0.8821 - acc_top2: 0.8988 - 14ms/step\n", + "step 360/938 - loss: 1.6158 - acc_top1: 0.8818 - acc_top2: 0.8984 - 14ms/step\n", + "step 370/938 - loss: 1.5300 - acc_top1: 0.8820 - acc_top2: 0.8986 - 14ms/step\n", + "step 380/938 - loss: 1.5718 - acc_top1: 0.8822 - 
acc_top2: 0.8988 - 14ms/step\n", + "step 390/938 - loss: 1.5898 - acc_top1: 0.8825 - acc_top2: 0.8990 - 14ms/step\n", + "step 400/938 - loss: 1.5177 - acc_top1: 0.8834 - acc_top2: 0.9000 - 14ms/step\n", + "step 410/938 - loss: 1.6493 - acc_top1: 0.8831 - acc_top2: 0.8997 - 14ms/step\n", + "step 420/938 - loss: 1.5071 - acc_top1: 0.8838 - acc_top2: 0.9002 - 14ms/step\n", + "step 430/938 - loss: 1.5982 - acc_top1: 0.8840 - acc_top2: 0.9002 - 14ms/step\n", + "step 440/938 - loss: 1.5649 - acc_top1: 0.8841 - acc_top2: 0.9003 - 14ms/step\n", + "step 450/938 - loss: 1.5555 - acc_top1: 0.8844 - acc_top2: 0.9005 - 14ms/step\n", + "step 460/938 - loss: 1.5536 - acc_top1: 0.8845 - acc_top2: 0.9005 - 14ms/step\n", + "step 470/938 - loss: 1.5401 - acc_top1: 0.8851 - acc_top2: 0.9011 - 14ms/step\n", + "step 480/938 - loss: 1.5549 - acc_top1: 0.8854 - acc_top2: 0.9013 - 14ms/step\n", + "step 490/938 - loss: 1.5596 - acc_top1: 0.8858 - acc_top2: 0.9017 - 14ms/step\n", + "step 500/938 - loss: 1.5059 - acc_top1: 0.8860 - acc_top2: 0.9018 - 14ms/step\n", + "step 510/938 - loss: 1.6073 - acc_top1: 0.8858 - acc_top2: 0.9017 - 14ms/step\n", + "step 520/938 - loss: 1.5588 - acc_top1: 0.8857 - acc_top2: 0.9016 - 14ms/step\n", + "step 530/938 - loss: 1.6165 - acc_top1: 0.8859 - acc_top2: 0.9019 - 14ms/step\n", + "step 540/938 - loss: 1.5884 - acc_top1: 0.8862 - acc_top2: 0.9023 - 14ms/step\n", + "step 550/938 - loss: 1.6552 - acc_top1: 0.8863 - acc_top2: 0.9027 - 14ms/step\n", + "step 560/938 - loss: 1.5529 - acc_top1: 0.8867 - acc_top2: 0.9030 - 14ms/step\n", + "step 570/938 - loss: 1.5441 - acc_top1: 0.8866 - acc_top2: 0.9029 - 14ms/step\n", + "step 580/938 - loss: 1.5438 - acc_top1: 0.8867 - acc_top2: 0.9029 - 14ms/step\n", + "step 590/938 - loss: 1.5761 - acc_top1: 0.8868 - acc_top2: 0.9029 - 14ms/step\n", + "step 600/938 - loss: 1.5384 - acc_top1: 0.8867 - acc_top2: 0.9029 - 14ms/step\n", + "step 610/938 - loss: 1.5858 - acc_top1: 0.8871 - acc_top2: 0.9032 - 14ms/step\n", + "step 
620/938 - loss: 1.5524 - acc_top1: 0.8872 - acc_top2: 0.9034 - 14ms/step\n", + "step 630/938 - loss: 1.6182 - acc_top1: 0.8875 - acc_top2: 0.9035 - 14ms/step\n", + "step 640/938 - loss: 1.6326 - acc_top1: 0.8877 - acc_top2: 0.9037 - 14ms/step\n", + "step 650/938 - loss: 1.5871 - acc_top1: 0.8877 - acc_top2: 0.9035 - 14ms/step\n", + "step 660/938 - loss: 1.5403 - acc_top1: 0.8877 - acc_top2: 0.9034 - 14ms/step\n", + "step 670/938 - loss: 1.5539 - acc_top1: 0.8879 - acc_top2: 0.9035 - 14ms/step\n", + "step 680/938 - loss: 1.4918 - acc_top1: 0.8881 - acc_top2: 0.9036 - 14ms/step\n", + "step 690/938 - loss: 1.6007 - acc_top1: 0.8882 - acc_top2: 0.9036 - 14ms/step\n", + "step 700/938 - loss: 1.5539 - acc_top1: 0.8883 - acc_top2: 0.9037 - 14ms/step\n", + "step 710/938 - loss: 1.6036 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n", + "step 720/938 - loss: 1.5943 - acc_top1: 0.8881 - acc_top2: 0.9035 - 14ms/step\n", + "step 730/938 - loss: 1.5714 - acc_top1: 0.8881 - acc_top2: 0.9035 - 14ms/step\n", + "step 740/938 - loss: 1.5095 - acc_top1: 0.8881 - acc_top2: 0.9035 - 14ms/step\n", + "step 750/938 - loss: 1.5069 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n", + "step 760/938 - loss: 1.5816 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n", + "step 770/938 - loss: 1.5855 - acc_top1: 0.8880 - acc_top2: 0.9033 - 14ms/step\n", + "step 780/938 - loss: 1.5599 - acc_top1: 0.8881 - acc_top2: 0.9034 - 14ms/step\n", + "step 790/938 - loss: 1.6029 - acc_top1: 0.8879 - acc_top2: 0.9032 - 14ms/step\n", + "step 800/938 - loss: 1.5839 - acc_top1: 0.8880 - acc_top2: 0.9033 - 14ms/step\n", + "step 810/938 - loss: 1.5545 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n", + "step 820/938 - loss: 1.5458 - acc_top1: 0.8881 - acc_top2: 0.9036 - 14ms/step\n", + "step 830/938 - loss: 1.5911 - acc_top1: 0.8879 - acc_top2: 0.9033 - 14ms/step\n", + "step 840/938 - loss: 1.5845 - acc_top1: 0.8881 - acc_top2: 0.9035 - 14ms/step\n", + "step 850/938 - loss: 1.5628 - acc_top1: 0.8880 - 
acc_top2: 0.9035 - 14ms/step\n", + "step 860/938 - loss: 1.5596 - acc_top1: 0.8880 - acc_top2: 0.9035 - 14ms/step\n", + "step 870/938 - loss: 1.5843 - acc_top1: 0.8882 - acc_top2: 0.9036 - 14ms/step\n", + "step 880/938 - loss: 1.5393 - acc_top1: 0.8883 - acc_top2: 0.9036 - 14ms/step\n", + "step 890/938 - loss: 1.5382 - acc_top1: 0.8882 - acc_top2: 0.9035 - 14ms/step\n", + "step 900/938 - loss: 1.5910 - acc_top1: 0.8884 - acc_top2: 0.9036 - 14ms/step\n", + "step 910/938 - loss: 1.5682 - acc_top1: 0.8886 - acc_top2: 0.9038 - 14ms/step\n", + "step 920/938 - loss: 1.5736 - acc_top1: 0.8889 - acc_top2: 0.9039 - 14ms/step\n", + "step 930/938 - loss: 1.5283 - acc_top1: 0.8888 - acc_top2: 0.9038 - 14ms/step\n", + "step 938/938 - loss: 1.5582 - acc_top1: 0.8888 - acc_top2: 0.9038 - 14ms/step\n", + "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/1\n", "save checkpoint at /Users/chenlong/online_repo/book/paddle2.0_docs/image_classification/mnist_checkpoint/final\n" ] } ], "source": [ "model.fit(train_dataset,\n", - " test_dataset,\n", " epochs=2,\n", " batch_size=64,\n", " save_dir='mnist_checkpoint')" @@ -467,131 +537,59 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### 训练方式1结束\n", - "以上就是训练方式1,可以非常快速的完成网络模型训练。此外,paddle还可以用下面的方式来完成模型的训练。" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 3.训练方式2\n", - "方式1可以快速便捷的完成训练,隐藏了训练时的细节。而方式2则可以用最基本的方式,完成模型的训练。具体如下。" + "### 使用model.evaluate来预测模型" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "epoch: 0, batch_id: 0, loss is: [2.300888], acc is: [0.28125]\n", - "epoch: 0, batch_id: 100, loss is: [1.6948285], acc is: [0.8125]\n", - "epoch: 0, batch_id: 200, loss is: [1.5282547], acc is: [0.96875]\n", - "epoch: 0, batch_id: 300, loss is: [1.509404], acc is: [0.96875]\n", - "epoch: 0, batch_id: 400, loss is: [1.4973292], acc is: 
[1.]\n", - "epoch: 0, batch_id: 500, loss is: [1.5063374], acc is: [0.984375]\n", - "epoch: 0, batch_id: 600, loss is: [1.490077], acc is: [0.984375]\n", - "epoch: 0, batch_id: 700, loss is: [1.5206413], acc is: [0.984375]\n", - "epoch: 0, batch_id: 800, loss is: [1.5104291], acc is: [1.]\n", - "epoch: 0, batch_id: 900, loss is: [1.5216607], acc is: [0.96875]\n", - "epoch: 1, batch_id: 0, loss is: [1.4949667], acc is: [0.984375]\n", - "epoch: 1, batch_id: 100, loss is: [1.4923338], acc is: [0.96875]\n", - "epoch: 1, batch_id: 200, loss is: [1.5026703], acc is: [1.]\n", - "epoch: 1, batch_id: 300, loss is: [1.4965419], acc is: [0.984375]\n", - "epoch: 1, batch_id: 400, loss is: [1.5270758], acc is: [1.]\n", - "epoch: 1, batch_id: 500, loss is: [1.4774603], acc is: [1.]\n", - "epoch: 1, batch_id: 600, loss is: [1.4762554], acc is: [0.984375]\n", - "epoch: 1, batch_id: 700, loss is: [1.4773959], acc is: [0.984375]\n", - "epoch: 1, batch_id: 800, loss is: [1.5044193], acc is: [1.]\n", - "epoch: 1, batch_id: 900, loss is: [1.4986757], acc is: [0.96875]\n" + "Eval begin...\n", + "step 10/157 - loss: 1.5447 - acc_top1: 0.8953 - acc_top2: 0.9078 - 5ms/step\n", + "step 20/157 - loss: 1.6185 - acc_top1: 0.8930 - acc_top2: 0.9078 - 5ms/step\n", + "step 30/157 - loss: 1.6497 - acc_top1: 0.8917 - acc_top2: 0.9057 - 5ms/step\n", + "step 40/157 - loss: 1.6318 - acc_top1: 0.8902 - acc_top2: 0.9055 - 5ms/step\n", + "step 50/157 - loss: 1.5533 - acc_top1: 0.8856 - acc_top2: 0.9012 - 5ms/step\n", + "step 60/157 - loss: 1.6212 - acc_top1: 0.8878 - acc_top2: 0.9036 - 5ms/step\n", + "step 70/157 - loss: 1.5674 - acc_top1: 0.8839 - acc_top2: 0.9002 - 5ms/step\n", + "step 80/157 - loss: 1.5409 - acc_top1: 0.8891 - acc_top2: 0.9043 - 5ms/step\n", + "step 90/157 - loss: 1.6133 - acc_top1: 0.8903 - acc_top2: 0.9045 - 5ms/step\n", + "step 100/157 - loss: 1.5535 - acc_top1: 0.8909 - acc_top2: 0.9044 - 5ms/step\n", + "step 110/157 - loss: 1.5690 - acc_top1: 0.8916 - acc_top2: 0.9054 - 
5ms/step\n", + "step 120/157 - loss: 1.6147 - acc_top1: 0.8926 - acc_top2: 0.9055 - 5ms/step\n", + "step 130/157 - loss: 1.5203 - acc_top1: 0.8944 - acc_top2: 0.9066 - 5ms/step\n", + "step 140/157 - loss: 1.5066 - acc_top1: 0.8952 - acc_top2: 0.9068 - 5ms/step\n", + "step 150/157 - loss: 1.5536 - acc_top1: 0.8958 - acc_top2: 0.9072 - 5ms/step\n", + "step 157/157 - loss: 1.5855 - acc_top1: 0.8956 - acc_top2: 0.9076 - 5ms/step\n", + "Eval samples: 10000\n" ] - } - ], - "source": [ - "import paddle\n", - "train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), batch_size=64)\n", - "def train(model):\n", - " model.train()\n", - " epochs = 2\n", - " batch_size = 64\n", - " optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())\n", - " for epoch in range(epochs):\n", - " for batch_id, data in enumerate(train_loader()):\n", - " x_data = data[0]\n", - " y_data = data[1]\n", - " predicts = model(x_data)\n", - " loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", - " acc = paddle.metric.accuracy(predicts, y_data, k=2)\n", - " avg_loss = paddle.mean(loss)\n", - " avg_acc = paddle.mean(acc)\n", - " avg_loss.backward()\n", - " if batch_id % 100 == 0:\n", - " print(\"epoch: {}, batch_id: {}, loss is: {}, acc is: {}\".format(epoch, batch_id, avg_loss.numpy(), avg_acc.numpy()))\n", - " optim.minimize(avg_loss)\n", - " model.clear_gradients()\n", - "model = LeNet()\n", - "train(model)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 对模型进行验证" - ] - }, - { - "cell_type": "code", - "execution_count": 46, - "metadata": {}, - "outputs": [ + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "batch_id: 0, loss is: [1.5017498], acc is: [1.]\n", - "batch_id: 100, loss is: [1.4783669], acc is: [0.984375]\n", - "batch_id: 200, loss is: [1.4958509], acc is: [1.]\n", - "batch_id: 300, loss is: [1.4924574], acc is: [1.]\n", - "batch_id: 400, loss is: [1.4762049], acc is: [1.]\n", - 
"batch_id: 500, loss is: [1.4817208], acc is: [0.984375]\n", - "batch_id: 600, loss is: [1.4763825], acc is: [0.984375]\n", - "batch_id: 700, loss is: [1.4954926], acc is: [1.]\n", - "batch_id: 800, loss is: [1.5220823], acc is: [0.984375]\n", - "batch_id: 900, loss is: [1.4945463], acc is: [0.984375]\n" - ] + "data": { + "text/plain": [ + "{'loss': [1.585474], 'acc_top1': 0.8956, 'acc_top2': 0.9076}" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "import paddle\n", - "test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)\n", - "def test(model):\n", - " model.eval()\n", - " batch_size = 64\n", - " for batch_id, data in enumerate(train_loader()):\n", - " x_data = data[0]\n", - " y_data = data[1]\n", - " predicts = model(x_data)\n", - " loss = paddle.nn.functional.cross_entropy(predicts, y_data)\n", - " acc = paddle.metric.accuracy(predicts, y_data, k=2)\n", - " avg_loss = paddle.mean(loss)\n", - " avg_acc = paddle.mean(acc)\n", - " avg_loss.backward()\n", - " if batch_id % 100 == 0:\n", - " print(\"batch_id: {}, loss is: {}, acc is: {}\".format(batch_id, avg_loss.numpy(), avg_acc.numpy()))\n", - "test(model)" + "model.evaluate(test_dataset, batch_size=64)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### 训练方式2结束\n", - "以上就是训练方式2,通过这种方式,可以清楚的看到训练和测试中的每一步过程。" + "### 训练方式二结束\n", + "以上就是训练方式二,可以快速、高效的完成网络模型训练与预测。" ] }, {