Skip to content

Commit f749d4e

Browse files
authored
[nn-refactor] Refactor and reimplementation of NeuralNetwork (#749)
* adds temp DiyNeuralNetwork - refactoring NeuralNetwork class * adds updates to refactor * refactoring nn-refactor * adds features for compile and add layers * rm console log * adds train interface * adds basic predict * adds blank functions in data class * update nn class * adds nn compile handling * updates function name * adds data loading functions // todo - clean up * add recursive findEntries function and data loading functions * adds formatRawData function * adds .addData() function * adds saveData function * adds handling for onehot and counting input and output units " " * adds code comments * adds concat to this.meta * changed name to createMetaDataFromData" * adds convertRawToTensors * adds functions for calculating stats * adds normalization and conversion to tensor handling * adds .summarizeData * adds data handling to index * updates summarizeData function to explicitly set meta * updates and adds functions * updates predict function * adds classify() with meta * adds metadata handling and data functions * adds loadData with options in init * adds major updates to initiation and defaults * adds boolean flags to check status to configure nn * adds addData function to index * adds support for auto labeling inputs and outputs for blank nn * code cleanup and function name change * flattens array in cnvertRawToTensors * flattens inputs * flatten array always * adds isOneHotEncodedOrNormalized * updates predict and classify functions and output format * updates param handling in predict and classify * code cleanup * adds save function * code cleanup * adds first pass at loading data * fixes missing isNormalized flag in meta * moves loading functions to respective class * moves files to NeuralNetwork * moves files to NeuralNetwork and rm diyNN * rms console.log * check if metadata and warmedup are true before normalization * adds unnormalize function to nn predict * return unNormalized value * adds loadData() and changes to loadDataFromUrl * adds saveData to index * adds modelUrl to constructor options in index * cleans up predict and classify * fix reference to unNormalizeValue * code cleanup * adds looping to format data for prediction and predictMultiple and classifyMultiple * adds layer handling for options * adds tfvis to index and ml5 root * adds debug flag in options * adds vis and fixes input formatting" " * adds model summary * adds comments and reorders code * refactoring functions with 3 datatypes in mind: number, string, array * adds data handling updates * adds handling tensors * adds process up to training * fixes breaking training * adds full working poc * fix addData check * adds updates to api and notes to fix with functions * adds createMetadata in index * adds image handling in classify functino * adds method to not exceed call stack of min and max * fixes loadData issue * adds first header name for min and max * code cleanup * removes unused functions * fixes setDataRaw * code clean up, organization, and adds method binding to constructor * adds methods to constructor, adds comments, and cleans up * adds methods to constructor for nndata * adds methods to constructor, code cleanup, and organization
1 parent 538ec86 commit f749d4e

File tree

9 files changed

+2053
-1323
lines changed

9 files changed

+2053
-1323
lines changed

src/NeuralNetwork/NeuralNetwork.js

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
import * as tf from '@tensorflow/tfjs';
2+
import callCallback from '../utils/callcallback';
3+
import { saveBlob } from '../utils/io';
4+
5+
class NeuralNetwork {
6+
constructor() {
7+
// flags
8+
this.isTrained = false;
9+
this.isCompiled = false;
10+
this.isLayered = false;
11+
// the model
12+
this.model = null;
13+
14+
// methods
15+
this.init = this.init.bind(this);
16+
this.createModel = this.createModel.bind(this);
17+
this.addLayer = this.addLayer.bind(this);
18+
this.compile = this.compile.bind(this);
19+
this.setOptimizerFunction = this.setOptimizerFunction.bind(this);
20+
this.train = this.train.bind(this);
21+
this.trainInternal = this.trainInternal.bind(this);
22+
this.predict = this.predict.bind(this);
23+
this.classify = this.classify.bind(this);
24+
this.save = this.save.bind(this);
25+
this.load = this.load.bind(this);
26+
27+
// initialize
28+
this.init();
29+
}
30+
31+
/**
32+
* initialize with create model
33+
*/
34+
init() {
35+
this.createModel();
36+
}
37+
38+
/**
39+
* creates a sequential model
40+
* uses switch/case for potential future where different formats are supported
41+
* @param {*} _type
42+
*/
43+
createModel(_type = 'sequential') {
44+
switch (_type.toLowerCase()) {
45+
case 'sequential':
46+
this.model = tf.sequential();
47+
return this.model;
48+
default:
49+
this.model = tf.sequential();
50+
return this.model;
51+
}
52+
}
53+
54+
/**
55+
* add layer to the model
56+
* if the model has 2 or more layers switch the isLayered flag
57+
* @param {*} _layerOptions
58+
*/
59+
addLayer(_layerOptions) {
60+
const LAYER_OPTIONS = _layerOptions || {};
61+
this.model.add(LAYER_OPTIONS);
62+
63+
// check if it has at least an input and output layer
64+
if (this.model.layers.length >= 2) {
65+
this.isLayered = true;
66+
}
67+
}
68+
69+
/**
70+
* Compile the model
71+
* if the model is compiled, set the isCompiled flag to true
72+
* @param {*} _modelOptions
73+
*/
74+
compile(_modelOptions) {
75+
this.model.compile(_modelOptions);
76+
this.isCompiled = true;
77+
}
78+
79+
/**
80+
* Set the optimizer function given the learning rate
81+
* as a paramter
82+
* @param {*} learningRate
83+
* @param {*} optimizer
84+
*/
85+
setOptimizerFunction(learningRate, optimizer) {
86+
return optimizer.call(this, learningRate);
87+
}
88+
89+
/**
90+
* Calls the trainInternal() and calls the callback when finished
91+
* @param {*} _options
92+
* @param {*} _cb
93+
*/
94+
train(_options, _cb) {
95+
return callCallback(this.trainInternal(_options), _cb);
96+
}
97+
98+
/**
99+
* Train the model
100+
* @param {*} _options
101+
*/
102+
async trainInternal(_options) {
103+
const TRAINING_OPTIONS = _options;
104+
105+
const xs = TRAINING_OPTIONS.inputs;
106+
const ys = TRAINING_OPTIONS.outputs;
107+
108+
const { batchSize, epochs, shuffle, validationSplit, whileTraining } = TRAINING_OPTIONS;
109+
110+
await this.model.fit(xs, ys, {
111+
batchSize,
112+
epochs,
113+
shuffle,
114+
validationSplit,
115+
callbacks: whileTraining,
116+
});
117+
118+
xs.dispose();
119+
ys.dispose();
120+
121+
this.isTrained = true;
122+
}
123+
124+
/**
125+
* returns the prediction as an array
126+
* @param {*} _inputs
127+
*/
128+
async predict(_inputs) {
129+
const output = tf.tidy(() => {
130+
return this.model.predict(_inputs);
131+
});
132+
const result = await output.array();
133+
134+
output.dispose();
135+
_inputs.dispose();
136+
137+
return result;
138+
}
139+
140+
/**
141+
* classify is the same as .predict()
142+
* @param {*} _inputs
143+
*/
144+
async classify(_inputs) {
145+
return this.predict(_inputs);
146+
}
147+
148+
// predictMultiple
149+
// classifyMultiple
150+
// are the same as .predict()
151+
152+
/**
153+
* save the model
154+
* @param {*} nameOrCb
155+
* @param {*} cb
156+
*/
157+
async save(nameOrCb, cb) {
158+
let modelName;
159+
let callback;
160+
161+
if (typeof nameOrCb === 'function') {
162+
modelName = 'model';
163+
callback = nameOrCb;
164+
} else if (typeof nameOrCb === 'string') {
165+
modelName = nameOrCb;
166+
167+
if (typeof cb === 'function') {
168+
callback = cb;
169+
}
170+
} else {
171+
modelName = 'model';
172+
}
173+
174+
this.model.save(
175+
tf.io.withSaveHandler(async data => {
176+
this.weightsManifest = {
177+
modelTopology: data.modelTopology,
178+
weightsManifest: [
179+
{
180+
paths: [`./${modelName}.weights.bin`],
181+
weights: data.weightSpecs,
182+
},
183+
],
184+
};
185+
186+
await saveBlob(data.weightData, `${modelName}.weights.bin`, 'application/octet-stream');
187+
await saveBlob(JSON.stringify(this.weightsManifest), `${modelName}.json`, 'text/plain');
188+
if (callback) {
189+
callback();
190+
}
191+
}),
192+
);
193+
}
194+
195+
/**
196+
* loads the model and weights
197+
* @param {*} filesOrPath
198+
* @param {*} callback
199+
*/
200+
async load(filesOrPath = null, callback) {
201+
if (filesOrPath instanceof FileList) {
202+
const files = await Promise.all(
203+
Array.from(filesOrPath).map(async file => {
204+
if (file.name.includes('.json') && !file.name.includes('_meta')) {
205+
return { name: 'model', file };
206+
} else if (file.name.includes('.json') && file.name.includes('_meta.json')) {
207+
const modelMetadata = await file.text();
208+
return { name: 'metadata', file: modelMetadata };
209+
} else if (file.name.includes('.bin')) {
210+
return { name: 'weights', file };
211+
}
212+
return { name: null, file: null };
213+
}),
214+
);
215+
216+
const model = files.find(item => item.name === 'model').file;
217+
const weights = files.find(item => item.name === 'weights').file;
218+
219+
// load the model
220+
this.model = await tf.loadLayersModel(tf.io.browserFiles([model, weights]));
221+
} else if (filesOrPath instanceof Object) {
222+
// filesOrPath = {model: URL, metadata: URL, weights: URL}
223+
224+
let modelJson = await fetch(filesOrPath.model);
225+
modelJson = await modelJson.text();
226+
const modelJsonFile = new File([modelJson], 'model.json', { type: 'application/json' });
227+
228+
let weightsBlob = await fetch(filesOrPath.weights);
229+
weightsBlob = await weightsBlob.blob();
230+
const weightsBlobFile = new File([weightsBlob], 'model.weights.bin', {
231+
type: 'application/macbinary',
232+
});
233+
234+
this.model = await tf.loadLayersModel(tf.io.browserFiles([modelJsonFile, weightsBlobFile]));
235+
} else {
236+
this.model = await tf.loadLayersModel(filesOrPath);
237+
}
238+
239+
this.isCompiled = true;
240+
this.isLayered = true;
241+
this.isTrained = true;
242+
243+
if (callback) {
244+
callback();
245+
}
246+
return this.model;
247+
}
248+
}
249+
export default NeuralNetwork;

0 commit comments

Comments
 (0)