Skip to content

Commit 2a7840a

Browse files
committed
Add memory function
1 parent 7371a23 commit 2a7840a

File tree

3 files changed

+88
-38
lines changed

3 files changed

+88
-38
lines changed

src/neural-network-gpu.ts

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,37 @@ import {
2222
NeuralNetwork,
2323
} from './neural-network';
2424
import { release } from './utilities/kernel';
25-
import { LossFunction, LossFunctionInputs, LossFunctionState } from './utilities/loss';
25+
import { LossFunction, LossFunctionInputs, MemoryFunction, NeuralNetworkMemory } from './utilities/loss';
2626

2727
function loss(
28+
this: IKernelFunctionThis,
2829
actual: number,
2930
expected: number,
3031
inputs: LossFunctionInputs,
31-
state: LossFunctionState
32+
memory: NeuralNetworkMemory
3233
) {
3334
return expected - actual;
3435
}
3536

37+
function updateMemory(
38+
this: IKernelFunctionThis,
39+
actual: number,
40+
expected: number,
41+
inputs: LossFunctionInputs,
42+
memory: NeuralNetworkMemory,
43+
memorySize: number,
44+
loss: number
45+
) {
46+
const layer = this.thread.z;
47+
const neuron = this.thread.y;
48+
const signal = this.thread.x;
49+
50+
// Maintain the same signal magnitude.
51+
return memory[layer][neuron][signal];
52+
}
53+
3654
const DEFAULT_LOSS_FUNCTION = loss;
55+
const DEFAULT_MEMORY_FUNCTION = updateMemory;
3756

3857
export interface INeuralNetworkGPUDatumFormatted {
3958
input: KernelOutput;
@@ -402,7 +421,7 @@ export class NeuralNetworkGPU<
402421
};
403422

404423
buildCalculateDeltas(): void {
405-
let calcDeltas: GPUFunction<[number, number, LossFunctionInputs, LossFunctionState]>;
424+
let calcDeltas: GPUFunction<[number, number, LossFunctionInputs, NeuralNetworkMemory]>;
406425
switch (this.trainOpts.activation) {
407426
case 'sigmoid':
408427
calcDeltas = calcDeltasSigmoid;
@@ -423,6 +442,7 @@ export class NeuralNetworkGPU<
423442
}
424443

425444
const loss: LossFunction = this._lossFunction ?? DEFAULT_LOSS_FUNCTION;
445+
const updateMemory: MemoryFunction = this._memoryFunction ?? DEFAULT_MEMORY_FUNCTION;
426446

427447
calcDeltas = alias(
428448
utils.getMinifySafeName(() => calcDeltas),
@@ -436,14 +456,14 @@ export class NeuralNetworkGPU<
436456
// @ts-expect-error
437457
this.backwardPropagate[this.outputLayer] = this.gpu.createKernelMap(
438458
{
439-
error: calcErrorOutput,
459+
error: calcErrorOutput
440460
},
441461
function (
442462
this: IKernelFunctionThis,
443463
outputs: number[],
444464
targets: number[],
445465
inputs: LossFunctionInputs,
446-
state: LossFunctionState
466+
state: NeuralNetworkMemory
447467
): number {
448468
const output = outputs[this.thread.x];
449469
const target = targets[this.thread.x];
@@ -503,9 +523,8 @@ export class NeuralNetworkGPU<
503523

504524
let output;
505525
if (layer === this.outputLayer) {
506-
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
507-
// @ts-expect-error
508-
output = this.backwardPropagate[layer](this.outputs[layer], target, this.outputs[0], this.lossState);
526+
// @ts-ignore
527+
output = this.backwardPropagate[layer](this.outputs[layer], target, this.outputs[0], this.memory);
509528
} else {
510529
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
511530
// @ts-expect-error
@@ -731,11 +750,17 @@ export class NeuralNetworkGPU<
731750
: (layerBiases as Float32Array)
732751
)
733752
);
753+
const jsonLayerMemory = this.memory.map((layerMemory, layerIndex) =>
754+
layerMemory.map(nodeMemory =>
755+
Array.from(nodeMemory)
756+
)
757+
);
734758
const jsonLayers: IJSONLayer[] = [];
735759
for (let i = 0; i <= this.outputLayer; i++) {
736760
jsonLayers.push({
737761
weights: jsonLayerWeights[i] ?? [],
738762
biases: jsonLayerBiases[i] ?? [],
763+
memory: jsonLayerMemory[i] ?? []
739764
});
740765
}
741766
return {

src/neural-network.ts

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import { max } from './utilities/max';
1212
import { mse } from './utilities/mse';
1313
import { randos } from './utilities/randos';
1414
import { zeros } from './utilities/zeros';
15-
import { LossFunction, LossFunctionInputs, LossFunctionState } from './utilities/loss';
15+
import { LossFunction, LossFunctionInputs, MemoryFunction, NeuralNetworkMemory } from './utilities/loss';
1616
type NeuralNetworkFormatter =
1717
| ((v: INumberHash) => Float32Array)
1818
| ((v: number[]) => Float32Array);
@@ -44,7 +44,7 @@ function loss(
4444
actual: number,
4545
expected: number,
4646
inputs: LossFunctionInputs,
47-
state: LossFunctionState
47+
state: NeuralNetworkMemory
4848
) {
4949
return expected - actual;
5050
}
@@ -58,6 +58,7 @@ export type NeuralNetworkActivation =
5858
export interface IJSONLayer {
5959
biases: number[];
6060
weights: number[][];
61+
memory: number[][];
6162
}
6263

6364
export interface INeuralNetworkJSON {
@@ -77,15 +78,15 @@ export interface INeuralNetworkOptions {
7778
outputSize: number;
7879
binaryThresh: number;
7980
hiddenLayers?: number[];
80-
lossStateSize: number;
81+
memorySize: number;
8182
}
8283

8384
export function defaults(): INeuralNetworkOptions {
8485
return {
8586
inputSize: 0,
8687
outputSize: 0,
8788
binaryThresh: 0.5,
88-
lossStateSize: 1
89+
memorySize: 1
8990
};
9091
}
9192

@@ -119,8 +120,8 @@ export interface INeuralNetworkTrainOptions {
119120
log: boolean | ((status: INeuralNetworkState) => void);
120121
logPeriod: number;
121122
loss?: LossFunction;
122-
lossState?: LossFunctionState;
123-
lossStateSize: number;
123+
memory?: MemoryFunction;
124+
memorySize: number;
124125
leakyReluAlpha: number;
125126
learningRate: number;
126127
momentum: number;
@@ -141,7 +142,7 @@ export function trainDefaults(): INeuralNetworkTrainOptions {
141142
log: false, // true to use console.log, when a function is supplied it is used
142143
logPeriod: 10, // iterations between logging out
143144
loss,
144-
lossStateSize: 1,
145+
memorySize: 1,
145146
leakyReluAlpha: 0.01,
146147
learningRate: 0.3, // multiply's against the input and the delta then adds to momentum
147148
momentum: 0.1, // multiply's against the specified "change" then adds to learning rate for change
@@ -192,7 +193,7 @@ export class NeuralNetwork<
192193
_formatInput: NeuralNetworkFormatter | null = null;
193194
_formatOutput: NeuralNetworkFormatter | null = null;
194195

195-
_lossState: LossFunctionState;
196+
_memory: NeuralNetworkMemory;
196197

197198
runInput: (input: Float32Array) => Float32Array = (input: Float32Array) => {
198199
this.setActivation();
@@ -208,6 +209,7 @@ export class NeuralNetwork<
208209
};
209210

210211
_lossFunction?: LossFunction;
212+
_memoryFunction?: MemoryFunction;
211213

212214
// adam
213215
biasChangesLow: Float32Array[] = [];
@@ -227,8 +229,9 @@ export class NeuralNetwork<
227229
this.sizes = [inputSize].concat(hiddenLayers ?? []).concat([outputSize]);
228230
}
229231

230-
const { lossStateSize } = this.options ?? 0;
231-
this._lossState = this.trainOpts.lossState ?? this.replaceLossState(lossStateSize);
232+
// Initialize memory matrix
233+
const { memorySize } = this.options ?? 0;
234+
this._memory = this.replaceMemory(memorySize);
232235
}
233236

234237
/**
@@ -305,8 +308,8 @@ export class NeuralNetwork<
305308
return this.sizes.length > 0;
306309
}
307310

308-
public get lossState(): LossFunctionState {
309-
return this._lossState;
311+
public get memory(): NeuralNetworkMemory {
312+
return this._memory;
310313
}
311314

312315
run(input: Partial<InputType>): OutputType {
@@ -772,7 +775,7 @@ export class NeuralNetwork<
772775
const kernelFunctionThis = { thread: { x: node, y: layer, z: 0 } };
773776
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
774777
// @ts-ignore
775-
error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.lossState);
778+
error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.memory);
776779
}
777780
else error = target[node] - output;
778781
} else {
@@ -805,7 +808,7 @@ export class NeuralNetwork<
805808
const kernelFunctionThis = { thread: { x: node, y: layer, z: 0 } };
806809
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
807810
// @ts-ignore
808-
error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.lossState);
811+
error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.memory);
809812
}
810813
else error = target[node] - output;
811814
} else {
@@ -838,7 +841,7 @@ export class NeuralNetwork<
838841
const kernelFunctionThis = { thread: { x: node, y: layer, z: 0 } };
839842
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
840843
// @ts-ignore
841-
error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.lossState);
844+
error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.memory);
842845
}
843846
else error = target[node] - output;
844847
} else {
@@ -870,7 +873,7 @@ export class NeuralNetwork<
870873
const kernelFunctionThis = { thread: { x: node, y: layer, z: 0 } };
871874
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
872875
// @ts-ignore
873-
error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.lossState);
876+
error = this._lossFunction.call(kernelFunctionThis, output, target[node], input, this.memory);
874877
}
875878
else error = target[node] - output;
876879
} else {
@@ -1231,12 +1234,18 @@ export class NeuralNetwork<
12311234
const jsonLayerBiases = this.biases.map((layerBiases) =>
12321235
Array.from(layerBiases)
12331236
);
1237+
const jsonLayerMemory = this.memory.map(layerMemory =>
1238+
layerMemory.map(
1239+
nodeMemory => Array.from(nodeMemory)
1240+
)
1241+
);
12341242
const jsonLayers: IJSONLayer[] = [];
12351243
const outputLength = this.sizes.length - 1;
12361244
for (let i = 0; i <= outputLength; i++) {
12371245
jsonLayers.push({
12381246
weights: jsonLayerWeights[i] ?? [],
12391247
biases: jsonLayerBiases[i] ?? [],
1248+
memory: jsonLayerMemory[i] ?? []
12401249
});
12411250
}
12421251
return {
@@ -1281,9 +1290,15 @@ export class NeuralNetwork<
12811290
const layerBiases = this.biases.map((layerBiases, layerIndex) =>
12821291
Float32Array.from(jsonLayers[layerIndex].biases)
12831292
);
1293+
const layerMemory = this.memory.map((memory, layerIndex) =>
1294+
Array.from(jsonLayers[layerIndex].memory).map(nodeMemory =>
1295+
Float32Array.from(nodeMemory)
1296+
)
1297+
);
12841298
for (let i = 0; i <= this.outputLayer; i++) {
12851299
this.weights[i] = layerWeights[i] || [];
12861300
this.biases[i] = layerBiases[i] || [];
1301+
this.memory[i] = layerMemory[i] || [];
12871302
}
12881303
return this;
12891304
}
@@ -1387,23 +1402,23 @@ export class NeuralNetwork<
13871402
) => OutputType;
13881403
}
13891404

1390-
private createLossState(
1391-
lossStateSize: number
1392-
): LossFunctionState {
1393-
const lossState: LossFunctionState = [];
1405+
private createMemory(
1406+
memorySize: number
1407+
): NeuralNetworkMemory {
1408+
const memory: NeuralNetworkMemory = [];
13941409
for (let layer = 0; layer < this.sizes.length; layer++) {
1395-
lossState[layer] = [];
1410+
memory[layer] = [];
13961411
for (let neuron = 0; neuron < this.sizes.length; neuron++) {
1397-
lossState[layer][neuron] = new Float32Array(lossStateSize);
1412+
memory[layer][neuron] = new Float32Array(memorySize);
13981413
}
13991414
}
1400-
return lossState;
1415+
return memory;
14011416
}
14021417

1403-
private replaceLossState(
1404-
lossState: number | LossFunctionState
1405-
): LossFunctionState {
1406-
if (typeof lossState === "number") return this._lossState = this.createLossState(lossState);
1407-
return this._lossState = lossState;
1418+
private replaceMemory(
1419+
memory: number | NeuralNetworkMemory
1420+
): NeuralNetworkMemory {
1421+
if (typeof memory === "number") return this._memory = this.createMemory(memory);
1422+
return this._memory = memory;
14081423
}
14091424
}

src/utilities/loss.ts

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,22 @@ import { IKernelFunctionThis } from "gpu.js";
22

33
export type LossFunctionInputs = number[] | number[][] | number[][][] | Float32Array | Float32Array[] | Float32Array[][];
44

5-
export type LossFunctionState = number[][][] | Float32Array[][];
5+
export type NeuralNetworkMemory = Float32Array[][];
66

77
export type LossFunction = (
88
this: IKernelFunctionThis,
99
actual: number,
1010
expected: number,
1111
inputs: LossFunctionInputs,
12-
state: LossFunctionState
12+
memory: NeuralNetworkMemory
13+
) => number;
14+
15+
export type MemoryFunction = (
16+
this: IKernelFunctionThis,
17+
actual: number,
18+
expected: number,
19+
inputs: LossFunctionInputs,
20+
memory: NeuralNetworkMemory,
21+
memorySize: number,
22+
loss: number
1323
) => number;

0 commit comments

Comments
 (0)