diff --git a/lib/nn.js b/lib/nn.js index f678930..c6e3ad2 100644 --- a/lib/nn.js +++ b/lib/nn.js @@ -26,6 +26,10 @@ class NeuralNetwork { this.bias_h.randomize(); this.bias_o.randomize(); this.setLearningRate(); + + this.setActivationFunction(); + this.setDActivationFunction(); + } predict(input_array) { @@ -35,12 +39,12 @@ class NeuralNetwork { let hidden = Matrix.multiply(this.weights_ih, inputs); hidden.add(this.bias_h); // activation function! - hidden.map(sigmoid); + hidden.map(this.activation_function); // Generating the output's output! let output = Matrix.multiply(this.weights_ho, hidden); output.add(this.bias_o); - output.map(sigmoid); + output.map(this.activation_function); // Sending back to the caller! return output.toArray(); @@ -49,6 +53,14 @@ class NeuralNetwork { setLearningRate(learning_rate = 0.1) { this.learning_rate = learning_rate; } + + setActivationFunction(Fun = sigmoid) { + this.activation_function = Fun; + } + + setDActivationFunction(dFun = dsigmoid) { + this.d_activation_function = dFun; + } train(input_array, target_array) { // Generating the Hidden Outputs @@ -56,12 +68,12 @@ class NeuralNetwork { let hidden = Matrix.multiply(this.weights_ih, inputs); hidden.add(this.bias_h); // activation function! - hidden.map(sigmoid); + hidden.map(this.activation_function); // Generating the output's output! let outputs = Matrix.multiply(this.weights_ho, hidden); outputs.add(this.bias_o); - outputs.map(sigmoid); + outputs.map(this.activation_function); // Convert array to matrix object let targets = Matrix.fromArray(target_array); @@ -72,7 +84,7 @@ class NeuralNetwork { // let gradient = outputs * (1 - outputs); // Calculate gradient - let gradients = Matrix.map(outputs, dsigmoid); + let gradients = Matrix.map(outputs, this.d_activation_function); gradients.multiply(output_errors); gradients.multiply(this.learning_rate); @@ -91,7 +103,7 @@ class NeuralNetwork { let hidden_errors = Matrix.multiply(who_t, output_errors); // Calculate hidden gradient - let hidden_gradient = Matrix.map(hidden, dsigmoid); + let hidden_gradient = Matrix.map(hidden, this.d_activation_function); hidden_gradient.multiply(hidden_errors); hidden_gradient.multiply(this.learning_rate);