diff --git a/conslearn/trainConstrainedNetwork.m b/conslearn/trainConstrainedNetwork.m index 15c7e92..61ae9ab 100644 --- a/conslearn/trainConstrainedNetwork.m +++ b/conslearn/trainConstrainedNetwork.m @@ -147,7 +147,8 @@ % Evaluate the model gradients, and loss using dlfeval and the % modelLoss function and update the network state. - [lossTrain,gradients] = dlfeval(@iModelLoss,net,X,T,metric); + [lossTrain,gradients,state] = dlfeval(@iModelLoss,net,X,T,metric); + net.State = state; % Gradient Update [net,avgG,avgSqG] = adamupdate(net,gradients,avgG,avgSqG,epoch,learnRate); @@ -180,8 +181,12 @@ end %% Helpers -function [loss,gradients] = iModelLoss(net,X,T,metric) -Y = forward(net,X); +function [loss,gradients,state] = iModelLoss(net,X,T,metric) + +% Make a forward pass +[Y,state] = forward(net,X); + +% Compute the loss switch metric case "mse" loss = mse(Y,T); @@ -190,6 +195,8 @@ case "crossentropy" loss = crossentropy(softmax(Y),T); end + +% Compute the gradient of the loss with respect to the learnables gradients = dlgradient(loss,net.Learnables); end