Closed
Description
I'm trying to write some better examples for RNN and LSTM, but ran into some snags and decided to start simple. In short, my LSTM networks were taking forever to train in my projects, so I went back to basics and created this benchmark:
```js
const brain = require('brain.js');

// XOR truth table shared by all three networks
const trainingData = [
  { input: [0, 0], output: [0] },
  { input: [0, 1], output: [1] },
  { input: [1, 0], output: [1] },
  { input: [1, 1], output: [0] },
];

// NN xor (feed-forward); run() returns an array here, so round its single value
const net = new brain.NeuralNetwork();
let now = Date.now();
let output = net.train(trainingData);
console.log('NN: trained output:', output);
console.log(`in ${Date.now() - now} ms`);
trainingData.forEach(({ input, output: expected }) => {
  console.log(`NN test [${input}]:[${expected}]`, Math.round(net.run(input)[0]));
});

// RNN xor (timer starts after construction, same as for the other networks)
const rnn = new brain.recurrent.RNN();
now = Date.now();
output = rnn.train(trainingData);
console.log('RNN: trained output:', output);
console.log(`in ${Date.now() - now} ms`);
trainingData.forEach(({ input, output: expected }) => {
  console.log(`RNN test [${input}]:[${expected}]`, rnn.run(input));
});

// LSTM xor
const lstm = new brain.recurrent.LSTM();
now = Date.now();
output = lstm.train(trainingData);
console.log('LSTM: trained output:', output);
console.log(`in ${Date.now() - now} ms`);
trainingData.forEach(({ input, output: expected }) => {
  console.log(`LSTM test [${input}]:[${expected}]`, lstm.run(input));
});
```
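The benchmark above repeats the `now = Date.now()` / `Date.now() - now` pattern for each network. `Date.now()` is wall-clock time with millisecond resolution; a small helper built on Node's monotonic `process.hrtime.bigint()` keeps the timing in one place. This is a minimal sketch, assuming Node.js 10.7 or later; `timeTrain` is a hypothetical helper, not part of brain.js, and a stand-in training function is used so it runs without brain.js installed.

```javascript
// Hypothetical helper (not part of brain.js): time one training call using a
// monotonic clock, print it in the same format as the benchmark, and return both.
// Assumes Node.js >= 10.7, where process.hrtime.bigint() is available.
function timeTrain(label, trainFn) {
  const start = process.hrtime.bigint();
  const result = trainFn(); // e.g. () => net.train(trainingData)
  const elapsedMs = Number(process.hrtime.bigint() - start) / 1e6; // ns -> ms
  console.log(`${label}: trained output:`, result);
  console.log(`in ${elapsedMs.toFixed(1)} ms`);
  return { result, elapsedMs };
}

// Stand-in training function so this sketch runs on its own.
const demo = timeTrain('demo', () => ({ error: 0.004, iterations: 10 }));
```

With brain.js, the NN section would become `timeTrain('NN', () => net.train(trainingData))`, and likewise for `rnn` and `lstm`.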
The output is interesting (2015 MacBook Pro):

```
NN: trained output: { error: 0.004995326394090512, iterations: 4116 }
in 23 ms
NN test [0,0]:[0] 0
NN test [0,1]:[1] 1
NN test [1,0]:[1] 1
NN test [1,1]:[0] 0
RNN: trained output: { error: 1.6177478953504192, iterations: 20000 }
in 5994 ms
RNN test [0,0]:[0] 0
RNN test [0,1]:[1] 1
RNN test [1,0]:[1] 1
RNN test [1,1]:[0] 1
LSTM: trained output: { error: 1.4159965869501594, iterations: 20000 }
in 26377 ms
LSTM test [0,0]:[0] 0
LSTM test [0,1]:[1] 1
LSTM test [1,0]:[1] 1
LSTM test [1,1]:[0] 0
```
- RNN sometimes gets all four cases right, but often does not.
- Both RNN and LSTM have what look to me like high error values.
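Because a single training run can succeed or fail (as the RNN does here), one way to quantify "sometimes" is to train a fresh model several times and count how often it gets every XOR case right. This is a minimal sketch under stated assumptions: `successRate` and `allCorrect` are hypothetical helpers, the caller supplies a `trainModel` function that returns a predictor, and the predictor's result is coerced with `Number()` (which handles a plain number, a numeric string, or a one-element array). A stand-in "model" that simply computes XOR is used so the sketch runs without brain.js.

```javascript
// XOR truth table, same shape as the benchmark's trainingData.
const xorData = [
  { input: [0, 0], output: [0] },
  { input: [0, 1], output: [1] },
  { input: [1, 0], output: [1] },
  { input: [1, 1], output: [0] },
];

// True if the predictor's rounded output matches the expected bit for every case.
function allCorrect(predict, data) {
  return data.every(({ input, output }) =>
    Math.round(Number(predict(input))) === output[0]);
}

// Hypothetical harness: train `trials` fresh models and return the fraction
// that answer all cases correctly.
function successRate(trainModel, data, trials) {
  let successes = 0;
  for (let t = 0; t < trials; t++) {
    const predict = trainModel(data); // a new model for every trial
    if (allCorrect(predict, data)) successes++;
  }
  return successes / trials;
}

// Stand-in "training" that returns an exact XOR function, so this runs alone.
const perfect = () => (input) => input[0] ^ input[1];
console.log('stand-in success rate:', successRate(perfect, xorData, 5));

// With brain.js (not run here), a recurrent model could be plugged in like:
// successRate((data) => {
//   const rnn = new brain.recurrent.RNN();
//   rnn.train(data);
//   return (input) => rnn.run(input);
// }, xorData, 10);
```

Training a new model inside each trial matters: reusing one instance would measure a single initialization repeatedly rather than the spread across runs.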
Any idea what I might be doing wrong, or what I could do to improve their accuracy? I realize not much can be done about performance (in Node) at the moment, and that is fine. I don't mind waiting for them to train if I can get reasonable error values out of them and, of course, trust that I'm doing it right.
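The `error` values above come from each network's own `train()` call, and the issue does not establish that the feed-forward and recurrent networks compute that number the same way, so comparing 0.005 against 1.6 directly may be misleading. A sketch of a like-for-like check, assuming only that each model can be wrapped in a predictor function: compute mean squared error of the predictions on the same XOR table. `meanSquaredError` is a hypothetical helper; for a model whose output is a discrete 0 or 1, the result equals the fraction of cases answered wrong.

```javascript
// XOR truth table, same shape as the benchmark's trainingData.
const xorTable = [
  { input: [0, 0], output: [0] },
  { input: [0, 1], output: [1] },
  { input: [1, 0], output: [1] },
  { input: [1, 1], output: [0] },
];

// Hypothetical helper: mean squared error of a predictor over the table.
// Number() coerces a number, numeric string, or one-element array.
function meanSquaredError(predict, data) {
  let sum = 0;
  for (const { input, output } of data) {
    const diff = Number(predict(input)) - output[0];
    sum += diff * diff;
  }
  return sum / data.length;
}

// Stand-in predictors so the sketch runs without brain.js.
console.log('always 0.5:', meanSquaredError(() => 0.5, xorTable));
console.log('exact XOR :', meanSquaredError((i) => i[0] ^ i[1], xorTable));
console.log('always 0  :', meanSquaredError(() => 0, xorTable));
```

Wrapping the benchmark's models as `(input) => net.run(input)`, `(input) => rnn.run(input)`, and `(input) => lstm.run(input)` would put all three on this one scale.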