|
| 1 | +import AutoEncoder from "./auto-encoder"; |
| 2 | + |
| 3 | +const trainingData = [ |
| 4 | + [0, 0, 0], |
| 5 | + [0, 1, 1], |
| 6 | + [1, 0, 1], |
| 7 | + [1, 1, 0] |
| 8 | +]; |
| 9 | + |
| 10 | +const xornet = new AutoEncoder<number[], number[]>( |
| 11 | + { |
| 12 | + decodedSize: 3, |
| 13 | + hiddenLayers: [ 5, 2, 5 ] |
| 14 | + } |
| 15 | +); |
| 16 | + |
| 17 | +const errorThresh = 0.011; |
| 18 | + |
| 19 | +const result = xornet.train( |
| 20 | + trainingData, { |
| 21 | + iterations: 100000, |
| 22 | + errorThresh |
| 23 | + } |
| 24 | +); |
| 25 | + |
| 26 | +test( |
| 27 | + "denoise a data sample", |
| 28 | + async () => { |
| 29 | + expect(result.error).toBeLessThanOrEqual(errorThresh); |
| 30 | + |
| 31 | + function xor(...args: number[]) { |
| 32 | + return Math.round(xornet.denoise(args)[2]); |
| 33 | + } |
| 34 | + |
| 35 | + const run1 = xor(0, 0, 0); |
| 36 | + const run2 = xor(0, 1, 1); |
| 37 | + const run3 = xor(1, 0, 1); |
| 38 | + const run4 = xor(1, 1, 0); |
| 39 | + |
| 40 | + expect(run1).toBe(0); |
| 41 | + expect(run2).toBe(1); |
| 42 | + expect(run3).toBe(1); |
| 43 | + expect(run4).toBe(0); |
| 44 | + } |
| 45 | +); |
| 46 | + |
| 47 | +test( |
| 48 | + "encode and decode a data sample", |
| 49 | + async () => { |
| 50 | + expect(result.error).toBeLessThanOrEqual(errorThresh); |
| 51 | + |
| 52 | + const run1$input = [0, 0, 0]; |
| 53 | + const run1$encoded = xornet.encode(run1$input); |
| 54 | + const run1$decoded = xornet.decode(run1$encoded); |
| 55 | + |
| 56 | + const run2$input = [0, 1, 1]; |
| 57 | + const run2$encoded = xornet.encode(run2$input); |
| 58 | + const run2$decoded = xornet.decode(run2$encoded); |
| 59 | + |
| 60 | + for (let i = 0; i < 3; i++) expect(Math.round(run1$decoded[i])).toBe(run1$input[i]); |
| 61 | + for (let i = 0; i < 3; i++) expect(Math.round(run2$decoded[i])).toBe(run2$input[i]); |
| 62 | + } |
| 63 | +); |
| 64 | + |
| 65 | +test( |
| 66 | + "test a data sample for anomalies", |
| 67 | + async () => { |
| 68 | + expect(result.error).toBeLessThanOrEqual(errorThresh); |
| 69 | + |
| 70 | + function includesAnomalies(...args: number[]) { |
| 71 | + expect(xornet.includesAnomalies(args)).toBe(false); |
| 72 | + } |
| 73 | + |
| 74 | + includesAnomalies(0, 0, 0); |
| 75 | + includesAnomalies(0, 1, 1); |
| 76 | + includesAnomalies(1, 0, 1); |
| 77 | + includesAnomalies(1, 1, 0); |
| 78 | + } |
| 79 | +); |
0 commit comments