From 5276df8d42b849a5e3ae7e7915f57e4f7f8a5453 Mon Sep 17 00:00:00 2001 From: HaiYing Wang Date: Fri, 3 Dec 2021 17:20:55 -0500 Subject: [PATCH] Use MLDatasets.jl to load data; Flux's datasets are deprecated. --- tutorials/_posts/2021-02-07-convnet.md | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tutorials/_posts/2021-02-07-convnet.md b/tutorials/_posts/2021-02-07-convnet.md index ee51e6c7..1c956daf 100644 --- a/tutorials/_posts/2021-02-07-convnet.md +++ b/tutorials/_posts/2021-02-07-convnet.md @@ -12,8 +12,9 @@ This example writes out the saved model to the file `mnist_conv.bson`. Also, it To run this example, we need the following packages: ```julia -using Flux, Flux.Data.MNIST, Statistics +using Flux, MLDatasets, Statistics using Flux: onehotbatch, onecold, logitcrossentropy +using MLDatasets: MNIST using Base.Iterators: partition using Printf, BSON using Parameters: @with_kw @@ -38,9 +39,9 @@ To train our model, we need to bundle images together with their labels and grou ```julia function make_minibatch(X, Y, idxs) - X_batch = Array{Float32}(undef, size(X[1])..., 1, length(idxs)) + X_batch = Array{Float32}(undef, size(X)[1:end-1]..., 1, length(idxs)) for i in 1:length(idxs) - X_batch[:, :, :, i] = Float32.(X[idxs[i]]) + X_batch[:, :, :, i] = Float32.(X[:,:,idxs[i]]) end Y_batch = onehotbatch(Y[idxs], 0:9) return (X_batch, Y_batch) @@ -62,15 +63,13 @@ end ```julia function get_processed_data(args) # Load labels and images - train_labels = MNIST.labels() - train_imgs = MNIST.images() - mb_idxs = partition(1:length(train_imgs), args.batch_size) + train_imgs, train_labels = MNIST.traindata() + mb_idxs = partition(1:length(train_labels), args.batch_size) train_set = [make_minibatch(train_imgs, train_labels, i) for i in mb_idxs] # Prepare test set as one giant minibatch: - test_imgs = MNIST.images(:test) - test_labels = MNIST.labels(:test) - test_set = make_minibatch(test_imgs, test_labels, 1:length(test_imgs)) + test_imgs, test_labels = MNIST.testdata() + test_set = make_minibatch(test_imgs, test_labels, 1:length(test_labels)) return train_set, test_set