Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use MLDatasets.jl to load data; Flux's datasets are deprecated. #117

Merged
merged 1 commit into from
Dec 4, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions tutorials/_posts/2021-02-07-convnet.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ This example writes out the saved model to the file `mnist_conv.bson`. Also, it
To run this example, we need the following packages:

```julia
using Flux, Flux.Data.MNIST, Statistics
using Flux, MLDatasets, Statistics
using Flux: onehotbatch, onecold, logitcrossentropy
using MLDatasets: MNIST
using Base.Iterators: partition
using Printf, BSON
using Parameters: @with_kw
Expand All @@ -38,9 +39,9 @@ To train our model, we need to bundle images together with their labels and grou

```julia
function make_minibatch(X, Y, idxs)
X_batch = Array{Float32}(undef, size(X[1])..., 1, length(idxs))
X_batch = Array{Float32}(undef, size(X)[1:end-1]..., 1, length(idxs))
for i in 1:length(idxs)
X_batch[:, :, :, i] = Float32.(X[idxs[i]])
X_batch[:, :, :, i] = Float32.(X[:,:,idxs[i]])
end
Y_batch = onehotbatch(Y[idxs], 0:9)
return (X_batch, Y_batch)
Expand All @@ -62,15 +63,13 @@ end
```julia
function get_processed_data(args)
# Load labels and images
train_labels = MNIST.labels()
train_imgs = MNIST.images()
mb_idxs = partition(1:length(train_imgs), args.batch_size)
train_imgs, train_labels = MNIST.traindata()
mb_idxs = partition(1:length(train_labels), args.batch_size)
train_set = [make_minibatch(train_imgs, train_labels, i) for i in mb_idxs]

# Prepare test set as one giant minibatch:
test_imgs = MNIST.images(:test)
test_labels = MNIST.labels(:test)
test_set = make_minibatch(test_imgs, test_labels, 1:length(test_imgs))
test_imgs, test_labels = MNIST.testdata()
test_set = make_minibatch(test_imgs, test_labels, 1:length(test_labels))

return train_set, test_set

Expand Down