-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathIntegerAdditionTrainer.cs
93 lines (79 loc) · 3.96 KB
/
IntegerAdditionTrainer.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
using System;
using System.Threading.Tasks;
using BrightData;
using BrightWire;
using BrightWire.Models;
using BrightWire.TrainingData.Artificial;
namespace ExampleCode.DataTableTrainers
{
/// <summary>
/// Trains a recurrent neural network to add pairs of integers that are presented
/// to the network as binary digit sequences (data produced by <c>BinaryIntegers.Addition</c>).
/// </summary>
internal class IntegerAdditionTrainer(IDataTable data, IDataTable training, IDataTable test) : DataTableTrainer(data, training, test)
{
    /// <summary>
    /// Trains a simple recurrent network on the binary-addition data set and, optionally,
    /// executes the best saved model against freshly generated (unseen) integer pairs,
    /// writing the expected vs predicted binary digits to the console.
    /// </summary>
    /// <param name="writeResults">True to execute the best model on unseen data and print the results</param>
    public async Task TrainRecurrentNeuralNetwork(bool writeResults = true)
    {
        var graph = _context.CreateGraphFactory();

        // binary classification rounds each output to either 0 or 1
        var errorMetric = graph.ErrorMetric.BinaryClassification;

        // configure the network properties
        graph.CurrentPropertySet
            .Use(graph.GradientDescent.Adam)
            .Use(graph.GaussianWeightInitialisation(false, 0.3f, GaussianVarianceCalibration.SquareRoot2N))
        ;

        // create the engine
        var trainingData = await graph.CreateDataSource(Training);
        var testData = trainingData.CloneWith(Test);
        var engine = graph.CreateTrainingEngine(trainingData, errorMetric, learningRate: 0.001f, batchSize: 16);

        // build the network: recurrent layer -> feed forward -> relu, trained with BPTT
        const int hiddenLayerSize = 20, trainingIterations = 50;
        graph.Connect(engine)
            .AddSimpleRecurrent(graph.ReluActivation(), hiddenLayerSize)
            .AddFeedForward(engine.DataSource.GetOutputSizeOrThrow())
            .Add(graph.ReluActivation())
            .AddBackpropagationThroughTime()
        ;

        // train the network for trainingIterations (50) iterations, saving the model on each improvement
        // (previous comment incorrectly said "twenty iterations")
        ExecutionGraphModel? bestGraph = null;
        await engine.Train(trainingIterations, testData, bn => bestGraph = bn.Graph);

        if (writeResults) {
            // number of unseen addition samples to generate — must match the index range
            // used below when regrouping the execution results
            const int sampleCount = 8;
            // rows read from the execution results per sample group
            // NOTE(review): presumably the binary sequence length — TODO confirm against
            // the shape of BinaryIntegers.Addition output
            const int groupSize = 32;

            // export the graph and verify it against some unseen integers on the best model
            // (fall back to the final graph if no improvement was ever recorded)
            var executionEngine = graph.CreateExecutionEngine(bestGraph ?? engine.Graph);
            var testData2 = await graph.CreateDataSource(await BinaryIntegers.Addition(_context, sampleCount));
            var results = await executionEngine.Execute(testData2, 128, null, true).ToListAsync();

            // regroup the per-row results so each entry holds all rows for one sample index
            var groupedResults = new (float[][] Input, float[][] Target, float[][] Output)[sampleCount];
            for (var i = 0; i < sampleCount; i++) {
                var input = new float[groupSize][];
                var target = new float[groupSize][];
                var output = new float[groupSize][];
                for (var j = 0; j < groupSize; j++) {
                    input[j] = results[j].Input![i].ToArray();
                    target[j] = results[j].Target![i].ToArray();
                    output[j] = results[j].Output[i].ToArray();
                }
                groupedResults[i] = (input, target, output);
            }

            // write the results: the two addends (input columns 0 and 1), then the
            // expected and predicted sums, each as a string of binary digits
            foreach (var (input, target, output) in groupedResults) {
                Console.Write("First: ");
                foreach (var item in input)
                    WriteAsBinary(item[0]);
                Console.WriteLine();

                Console.Write("Second: ");
                foreach (var item in input)
                    WriteAsBinary(item[1]);
                Console.WriteLine();
                Console.WriteLine(" --------------------------------");

                Console.Write("Expected: ");
                foreach (var item in target)
                    WriteAsBinary(item[0]);
                Console.WriteLine();

                Console.Write("Predicted: ");
                foreach (var item in output)
                    WriteAsBinary(item[0]);
                Console.WriteLine();
                Console.WriteLine();
            }
        }
    }

    /// <summary>
    /// Writes a single network output as a binary digit, thresholding at 0.5.
    /// </summary>
    static void WriteAsBinary(float value) => Console.Write(value >= 0.5 ? "1" : "0");
}
}