Skip to content

Commit

Permalink
Adds state update for training with batch normalization.
Browse files Browse the repository at this point in the history
  • Loading branch information
jkeeley-MW committed Nov 11, 2024
1 parent fc585b1 commit 0c7b986
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions conslearn/trainConstrainedNetwork.m
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@

% Evaluate the model gradients, and loss using dlfeval and the
% modelLoss function and update the network state.
[lossTrain,gradients] = dlfeval(@iModelLoss,net,X,T,metric);
[lossTrain,gradients,state] = dlfeval(@iModelLoss,net,X,T,metric);
net.State = state;

% Gradient Update
[net,avgG,avgSqG] = adamupdate(net,gradients,avgG,avgSqG,epoch,learnRate);
Expand Down Expand Up @@ -180,8 +181,12 @@
end

%% Helpers
function [loss,gradients] = iModelLoss(net,X,T,metric)
Y = forward(net,X);
function [loss,gradients,state] = iModelLoss(net,X,T,metric)

% Make a forward pass
[Y,state] = forward(net,X);

% Compute the loss
switch metric
case "mse"
loss = mse(Y,T);
Expand All @@ -190,6 +195,8 @@
case "crossentropy"
loss = crossentropy(softmax(Y),T);
end

% Compute the gradient of the loss with respect to the learnables
gradients = dlgradient(loss,net.Learnables);
end

Expand Down

0 comments on commit 0c7b986

Please sign in to comment.