Updates CIFAR-10 example to use CNN instead of MLP.
jkeeley-MW committed Jan 13, 2025
1 parent 0c7b986 commit 00f1e3c
Showing 7 changed files with 105 additions and 58 deletions.
14 changes: 7 additions & 7 deletions conslearn/buildConstrainedNetwork.m
@@ -11,9 +11,8 @@
% The network includes either a featureInputLayer or an imageInputLayer,
% depending on INPUTSIZE:
%
% - If INPUTSIZE is a scalar, then the network has a featureInputLayer.
%
% - If INPUTSIZE is a vector with three elements, then the network has an
% - If INPUTSIZE is a scalar, then the network has a featureInputLayer. -
% If INPUTSIZE is a vector with three elements, then the network has an
% imageInputLayer.
%
% NUMHIDDENUNITS is a vector of integers that corresponds to the sizes
@@ -30,7 +29,8 @@
% ConvexNonDecreasingActivation - Convex, non-decreasing
% ("fully-convex") activation functions.
% ("partially-convex") The options are "softplus" or
% "relu". The default is "softplus".
% "relu".
% The default is "softplus".
% Activation - Network activation function.
% ("partially-convex") The options are "tanh", "relu" or
% "fullsort". The default is "tanh".
@@ -80,9 +80,9 @@
% "fullsort". The default is
% "fullsort".
% UpperBoundLipschitzConstant - Upper bound on the Lipschitz
% constant for the network, as a
% positive real number. The default
% value is 1.
% constant
% for the network, as a positive real
% number. The default value is 1.
% pNorm - p-norm value for measuring
% distance with respect to the
% Lipschitz continuity definition.
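
For reference, here is a minimal sketch of how the name-value options documented in this file might be combined. The fully convex call mirrors the usage shown later in the CIFAR-10 example; the "lipschitz" constraint string and the specific option values are illustrative assumptions, not confirmed defaults.

```matlab
% Hedged sketch: option names are taken from the documentation above; the
% "lipschitz" constraint string and chosen values are illustrative assumptions.
inputSize = 10;                 % scalar input size -> featureInputLayer
numHiddenUnits = [32 16 1];     % hidden layer sizes, ending in a single output

% Fully input convex network with the default convex activation.
ficnnet = buildConstrainedNetwork("fully-convex",inputSize,numHiddenUnits, ...
    ConvexNonDecreasingActivation="softplus");

% Lipschitz-constrained network with an upper bound of 1 on the constant.
lipnet = buildConstrainedNetwork("lipschitz",inputSize,numHiddenUnits, ...
    Activation="fullsort", ...
    UpperBoundLipschitzConstant=1);
```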
67 changes: 59 additions & 8 deletions conslearn/trainConstrainedNetwork.m
@@ -31,6 +31,12 @@
% iteration, specified as: "mse", "mae", or
% "crossentropy".
% The default is "mse".
% L2Regularization - Factor for L2 regularization (weight decay).
% The default is 0.
% ValidationData - Data to use for validation during training,
% specified as a minibatchqueue object.
% ValidationFrequency - Frequency of validation in number of
% iterations. The default is 50.
% TrainingMonitor - Flag to display the training progress monitor
% showing the training data loss.
% The default is true.
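
A hedged sketch of how the options added in this commit (L2Regularization, ValidationData, ValidationFrequency) might be passed. The positional call shape and the variable names mbqTrain and mbqValidation are assumptions for illustration, not confirmed usage.

```matlab
% Hedged sketch, assuming a call of the form
% trainConstrainedNetwork(constraint,net,mbq,Name=Value,...).
% mbqTrain and mbqValidation are minibatchqueue objects prepared elsewhere.
trainedNet = trainConstrainedNetwork("fully-convex",net,mbqTrain, ...
    LossMetric="crossentropy", ...
    L2Regularization=1e-4, ...          % weight decay factor (default 0)
    ValidationData=mbqValidation, ...   % validation minibatchqueue
    ValidationFrequency=50);            % validate every 50 iterations
```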
@@ -73,6 +79,9 @@
trainingOptions.LossMetric {...
mustBeTextScalar, ...
mustBeMember(trainingOptions.LossMetric,["mse","mae","crossentropy"])} = "mse";
trainingOptions.L2Regularization (1,1) {mustBeNumeric, mustBeNonnegative} = 0
trainingOptions.ValidationData minibatchqueue {mustBeScalarOrEmpty} = minibatchqueue.empty
trainingOptions.ValidationFrequency (1,1) {mustBeNumeric, mustBePositive, mustBeInteger} = 50
trainingOptions.TrainingMonitor (1,1) logical = true;
trainingOptions.TrainingMonitorLogScale (1,1) logical = true;
trainingOptions.ShuffleMinibatches (1,1) logical = false;
@@ -84,26 +93,39 @@
% Set up the training progress monitor
if trainingOptions.TrainingMonitor
monitor = trainingProgressMonitor;

% Track progress information
monitor.Info = ["LearningRate","Epoch","Iteration"];
monitor.Metrics = "TrainingLoss";

% Plot the training and validation metrics on the same plot
monitor.Metrics = ["TrainingLoss", "ValidationLoss"];
groupSubPlot(monitor, "Loss", ["TrainingLoss", "ValidationLoss"]);

% Apply loss log scale
if trainingOptions.TrainingMonitorLogScale
yscale(monitor,"TrainingLoss","log");
yscale(monitor,"Loss","log");
end

% Specify the horizontal axis label for the training plot.
monitor.XLabel = "Iteration";

% Start the monitor
monitor.Status = "Running";
stopButton = @() ~monitor.Stop;
else
% Let training run without a monitor by setting stop to false
stopButton = @() 1;
end

% Prepare the generic hyperparameters
maxEpochs = trainingOptions.MaxEpochs;
initialLearnRate = trainingOptions.InitialLearnRate;
decay = trainingOptions.Decay;
metric = trainingOptions.LossMetric;
shuffleMinibatches = trainingOptions.ShuffleMinibatches;
l2Regularization = trainingOptions.L2Regularization;
validationData = trainingOptions.ValidationData;
validationFrequency = trainingOptions.ValidationFrequency;

% Specify ADAM options
avgG = [];
@@ -147,7 +169,7 @@

% Evaluate the model gradients, and loss using dlfeval and the
% modelLoss function and update the network state.
[lossTrain,gradients,state] = dlfeval(@iModelLoss,net,X,T,metric);
[lossTrain,gradients,state] = dlfeval(dlaccelerate(@iModelLoss),net,X,T,metric,l2Regularization);
net.State = state;

% Gradient Update
@@ -162,10 +184,33 @@
LearningRate=learnRate, ...
Epoch=string(epoch) + " of " + string(maxEpochs), ...
Iteration=string(iteration));

recordMetrics(monitor,iteration, ...
TrainingLoss=lossTrain);

monitor.Progress = 100*epoch/maxEpochs;
end

% Record validation loss, if requested
if ~isempty(validationData)
if (iteration == 1) || (mod(iteration, validationFrequency) == 0)

% Reset the validation data
if ~hasdata(validationData)
reset(validationData);
end

% Compute the validation loss
[X, T] = next(validationData);
lossValidation = iModelLoss(net, X, T, metric, l2Regularization);

% Update the training monitor
if trainingOptions.TrainingMonitor
recordMetrics(monitor,iteration, ...
ValidationLoss=lossValidation);
end
end
end
end
end

@@ -181,23 +226,29 @@
end

%% Helpers
function [loss,gradients,state] = iModelLoss(net,X,T,metric)
function [loss,gradients,state] = iModelLoss(net,X,T,metric,l2Regularization)

% Make a forward pass
[Y,state] = forward(net,X);
[Y, state] = forward(net,X);

% Compute the loss
switch metric
case "mse"
loss = mse(Y,T);
case "mae"
loss = mean(abs(Y-T));
loss = mean(abs(Y-T), 'all');
case "crossentropy"
loss = crossentropy(softmax(Y),T);
end

% Compute the gradient of the loss with respect to the learnabless
gradients = dlgradient(loss,net.Learnables);
if nargout > 1
% Compute the gradient of the loss with respect to the learnables
gradients = dlgradient(loss,net.Learnables);

% Apply L2 regularization
idxWeights = net.Learnables.Parameter == "Weights";
gradients(idxWeights,:) = dlupdate(@(g,w) g + l2Regularization*w, gradients(idxWeights, :), net.Learnables(idxWeights, :));
end
end
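
The dlupdate call above adds the L2 (weight decay) term to the gradients of the weight parameters only: each weight gradient becomes g + l2Regularization*w. A small, self-contained sketch of the same elementwise update on toy dlarray values (illustrative data, not taken from the example):

```matlab
% Toy illustration of the L2 term applied to weight gradients.
lambda = 1e-4;
W = dlarray([1 -2; 3 0.5]);          % example weights
G = dlarray([0.1 0.2; -0.3 0.4]);    % example gradients
Greg = dlupdate(@(g,w) g + lambda*w, G, W);  % same result as G + lambda*W
```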

function proximalOp = iSetupProximalOperator(constraint,trainingOptions)
82 changes: 39 additions & 43 deletions examples/convex/classificationCIFAR10/TrainICNNOnCIFAR10Example.md
@@ -1,11 +1,11 @@

# <span style="color:rgb(213,80,0)">Train Fully Convex Neural Network for Image Classification</span>

This example shows how to create a fully input convex neural network and train it on CIFAR\-10 data. This example uses fully connected convex networks, rather than the more typical convolutional networks, which are known to give higher accuracy on the training and test data sets. The aim of this example is to demonstrate the expressive capabilities of convex constrained networks by classifying natural images and achieving high accuracy on the training set. Further discussion of the expressive capabilities of convex networks for tasks including image classification can be found in \[1\].
This example shows how to create a fully input convex convolutional neural network and train it on CIFAR\-10 data \[1\].

# Prepare Data

Download the CIFAR\-10 data set \[1\]. The data set contains 60,000 images. Each image is 32\-by\-32 pixels in size and has three color channels (RGB). The size of the data set is 175 MB. Depending on your internet connection, the download process can take time.
Download the CIFAR\-10 data set \[2\]. The data set contains 60,000 images. Each image is 32\-by\-32 pixels in size and has three color channels (RGB). The size of the data set is 175 MB. Depending on your internet connection, the download process can take time.

```matlab
datadir = ".";
@@ -18,16 +18,6 @@ Load the CIFAR\-10 training and test images as 4\-D arrays. The training set con
[XTrain,TTrain,XTest,TTest] = loadCIFARData(datadir);
```

For illustration in this example, subsample this data set evenly in each class. You can increase the number of samples by moving the slider to smaller values.

```matlab
subSampleFrequency = 10;
XTrain = XTrain(:,:,:,1:subSampleFrequency:end);
XTest = XTest(:,:,:,1:subSampleFrequency:end);
TTrain = TTrain(1:subSampleFrequency:end);
TTest = TTest(1:subSampleFrequency:end);
```

You can display a random sample of the training images using the following code.

<pre>
@@ -37,42 +27,45 @@
imshow(im)
</pre>

# Define FICNN Network Architecture
# Define FICCNN Network Architecture

Use the <samp>buildConstrainedNetwork</samp> function to create a fully input convex neural network suitable for this data set.
Use the <samp>buildConvexCNN</samp> function to create a fully input convex convolutional neural network suitable for this data set.

- The CIFAR\-10 images are 32\-by\-32 pixels. Therefore, create a fully convex network specifying the <samp>inputSize=[32 32 3]</samp>.
- Specify a vector of hidden unit sizes of decreasing value in <samp>numHiddenUnits</samp>. The final number of outputs of the network must be equal to the number of classes, which in this example is 10.
- The CIFAR\-10 images are 32\-by\-32 pixels and belong to one of ten classes. Therefore, create a fully convex network, specifying <samp>inputSize=[32 32 3]</samp> and <samp>numClasses=10</samp>.
- For each convolutional layer, specify the filter size in <samp>filterSize</samp>, the number of filters in <samp>numFilters</samp>, and the stride size in <samp>stride</samp>.
```matlab
inputSize = [32 32 3];
numHiddenUnits = [512 128 32 10];
numClasses = 10;
filterSize = [3; 3; 3; 3; 3; 1; 1];
numFilters = [96; 96; 192; 192; 192; 192; 10];
stride = [1; 2; 1; 2; 1; 1; 1];
```

Seed the network initialization for reproducibility.
Seed the network initialization for reproducibility.

```matlab
rng(0);
ficnnet = buildConstrainedNetwork("fully-convex",inputSize,numHiddenUnits)
ficnnet = buildConvexCNN(inputSize, numClasses, filterSize, numFilters, Stride=stride)
```

```matlabTextOutput
ficnnet =
dlnetwork with properties:
Layers: [15x1 nnet.cnn.layer.Layer]
Connections: [17x2 table]
Learnables: [14x3 table]
State: [0x3 table]
InputNames: {'image_input'}
OutputNames: {'add_4'}
Layers: [24x1 nnet.cnn.layer.Layer]
Connections: [23x2 table]
Learnables: [30x3 table]
State: [14x3 table]
InputNames: {'input'}
OutputNames: {'fc_+_end'}
Initialized: 1
View summary with summary.
```

```matlab
plot(ficnnet)
plot(ficnnet);
```

<figure>
@@ -83,14 +76,15 @@

# Specify Training Options

Train for a specified number of epochs with a mini\-batch size of 256. To attain high training accuracy, you may need to train for a larger number of epochs, for example <samp>numEpochs=8000</samp>, which could take several hours.
Train for a specified number of epochs with a mini\-batch size of 256. To attain high training accuracy, you may need to train for a larger number of epochs, for example <samp>numEpochs=400</samp>, which could take several hours.

```matlab
numEpochs = 8000;
numEpochs = 400;
miniBatchSize = 256;
initialLearnRate = 0.1;
decay = 0.005;
initialLearnRate = 0.0025;
decay = eps;
lossMetric = "crossentropy";
l2Regularization = 1e-4;
```

Create a <samp>minibatchqueue</samp> object that processes and manages mini\-batches of images during training. For each mini\-batch:
@@ -103,6 +97,7 @@ Create a <samp>minibatchqueue</samp> object that processes and manages mini\-bat
xds = arrayDatastore(XTrain,IterationDimension=4);
tds = arrayDatastore(TTrain,IterationDimension=1);
cds = combine(xds,tds);
mbqTrain = minibatchqueue(cds,...
MiniBatchSize=miniBatchSize,...
MiniBatchFcn=@preprocessMiniBatch,...
@@ -160,7 +155,7 @@ disp("Training accuracy: " + (1-trainError)*100 + "%")
```

```matlabTextOutput
Training accuracy: 90.4848%
Training accuracy: 70.2123%
```

Compute the accuracy on the test set.
@@ -173,7 +168,7 @@ disp("Test accuracy: " + (1-testError)*100 + "%")
```

```matlabTextOutput
Test accuracy: 27.4554%
Test accuracy: 66.266%
```

The network's output has been constrained to be convex in every pixel and every colour channel. Even with this level of restriction, the network is able to fit reasonably well to the training data. You can see lower accuracy on the test data set but, as discussed at the start of the example, a fully input convex network comprising fully connected operations is not expected to generalize well to natural image classification.
@@ -197,14 +192,14 @@ cm.RowSummary = "row-normalized";

To summarise, the fully input convex network is able to fit the training data set, which consists of labelled natural images. The training can take a considerable amount of time owing to the weight projection onto the constrained set after each gradient update, which slows down training convergence. Nevertheless, this example illustrates the flexibility and expressivity that convex neural networks have for correctly classifying natural images.

# Supporting Functions
## Mini Batch Preprocessing Function
# Supporting Functions
## Mini\-Batch Preprocessing Function

The <samp>preprocessMiniBatch</samp> function preprocesses a mini\-batch of predictors and labels using the following steps:
The <samp>preprocessMiniBatch</samp> function preprocesses a mini\-batch of predictors and labels using the following steps:

1. Preprocess the images using the <samp>preprocessMiniBatchPredictors</samp> function.
2. Extract the label data from the incoming cell array and concatenate into a categorical array along the second dimension.
3. One\-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.
3. One\-hot encode the categorical labels into numeric arrays. Encoding in the first dimension produces an encoded array that matches the shape of the network output.
```matlab
function [X,T] = preprocessMiniBatch(dataX,dataT)
@@ -219,19 +214,20 @@ T = onehotencode(T,1);
end
```
## Mini\-Batch Predictors Preprocessing Function
## Mini\-Batch Predictors Preprocessing Function

The <samp>preprocessMiniBatchPredictors</samp> function preprocesses a mini\-batch of predictors by extracting the image data from the input cell array and concatenate into a numeric array. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image, to use as a singleton channel dimension. You divide by 255 to normalize the pixels to <samp>[0,1]</samp> range.
The <samp>preprocessMiniBatchPredictors</samp> function preprocesses a mini\-batch of predictors by extracting the image data from the input cell array and concatenating it into a numeric array. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image, to use as a singleton channel dimension. You divide by 255 to normalize the pixels to the <samp>[0,1]</samp> range, and then rescale them to the <samp>[-1,1]</samp> range.

```matlab
function X = preprocessMiniBatchPredictors(dataX)
X = single(cat(4,dataX{1:end}))/255;
X = (single(cat(4,dataX{1:end}))/255); % Normalizes to [0, 1]
X = 2*X - 1; % Normalizes to [-1, 1].
end
```
# References

\[1\] Amos, Brandon, et al. Input Convex Neural Networks. arXiv:1609.07152, arXiv, 14 June 2017. arXiv.org, https://doi.org/10.48550/arXiv.1609.07152.

# References
\[1\] Amos, Brandon, et al. "Input Convex Neural Networks." (2017). https://doi.org/10.48550/arXiv.1609.07152.

*Copyright 2024 The MathWorks, Inc.*
\[2\] Krizhevsky, Alex. "Learning multiple layers of features from tiny images." (2009). https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf

*Copyright 2024-2025 The MathWorks, Inc.*
Binary file not shown.
Binary file modified examples/convex/classificationCIFAR10/figures/TrainICNN_Fig1.png
