Skip to content

Commit 550ac02

Browse files
authored
Added UpdateModel, publishModel,and unpublishModel functionality + tests (#791)
* Added UpdateModel, publishModel,and unpublishModel functionality plus tests
1 parent 27f4fb2 commit 550ac02

File tree

5 files changed

+741
-51
lines changed

5 files changed

+741
-51
lines changed

src/machine-learning/machine-learning-api-client.ts

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@ export interface StatusErrorResponse {
3131
readonly message: string;
3232
}
3333

34+
/**
35+
* A Firebase ML Model input object
36+
*/
37+
export interface ModelOptions {
38+
displayName?: string;
39+
tags?: string[];
40+
41+
tfliteModel?: { gcsTfliteUri: string; };
42+
}
43+
44+
export interface ModelUpdateOptions extends ModelOptions {
45+
state?: { published?: boolean; };
46+
}
47+
3448
export interface ModelContent {
3549
readonly displayName?: string;
3650
readonly tags?: string[];
@@ -80,7 +94,7 @@ export class MachineLearningApiClient {
8094
this.httpClient = new AuthorizedHttpClient(app);
8195
}
8296

83-
public createModel(model: ModelContent): Promise<OperationResponse> {
97+
public createModel(model: ModelOptions): Promise<OperationResponse> {
8498
if (!validator.isNonNullObject(model) ||
8599
!validator.isNonEmptyString(model.displayName)) {
86100
const err = new FirebaseMachineLearningError('invalid-argument', 'Invalid model content.');
@@ -97,6 +111,24 @@ export class MachineLearningApiClient {
97111
});
98112
}
99113

114+
public updateModel(modelId: string, model: ModelUpdateOptions, updateMask: string[]): Promise<OperationResponse> {
115+
if (!validator.isNonEmptyString(modelId) ||
116+
!validator.isNonNullObject(model) ||
117+
!validator.isNonEmptyArray(updateMask)) {
118+
const err = new FirebaseMachineLearningError('invalid-argument', 'Invalid model or mask content.');
119+
return Promise.reject(err);
120+
}
121+
return this.getUrl()
122+
.then((url) => {
123+
const request: HttpRequestConfig = {
124+
method: 'PATCH',
125+
url: `${url}/models/${modelId}?updateMask=${updateMask.join()}`,
126+
data: model,
127+
};
128+
return this.sendRequest<OperationResponse>(request);
129+
});
130+
}
131+
100132

101133
public getModel(modelId: string): Promise<ModelResponse> {
102134
return Promise.resolve()

src/machine-learning/machine-learning.ts

Lines changed: 27 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616

1717
import {FirebaseApp} from '../firebase-app';
1818
import {FirebaseServiceInterface, FirebaseServiceInternalsInterface} from '../firebase-service';
19-
import {MachineLearningApiClient, ModelResponse, OperationResponse, ModelContent} from './machine-learning-api-client';
19+
import {MachineLearningApiClient, ModelResponse, OperationResponse,
20+
ModelOptions, ModelUpdateOptions} from './machine-learning-api-client';
2021
import {FirebaseError} from '../utils/error';
2122

2223
import * as validator from '../utils/validator';
2324
import {FirebaseMachineLearningError} from './machine-learning-utils';
2425
import { deepCopy } from '../utils/deep-copy';
26+
import * as utils from '../utils';
2527

2628
/**
2729
* Internals of an ML instance.
@@ -95,7 +97,7 @@ export class MachineLearning implements FirebaseServiceInterface {
9597
* @return {Promise<Model>} A Promise fulfilled with the created model.
9698
*/
9799
public createModel(model: ModelOptions): Promise<Model> {
98-
return this.convertOptionstoContent(model, true)
100+
return this.signUrlIfPresent(model)
99101
.then((modelContent) => this.client.createModel(modelContent))
100102
.then((operation) => handleOperation(operation));
101103
}
@@ -109,8 +111,11 @@ export class MachineLearning implements FirebaseServiceInterface {
109111
* @return {Promise<Model>} A Promise fulfilled with the updated model.
110112
*/
111113
public updateModel(modelId: string, model: ModelOptions): Promise<Model> {
112-
throw new Error('NotImplemented');
113-
}
114+
const updateMask = utils.generateUpdateMask(model);
115+
return this.signUrlIfPresent(model)
116+
.then((modelContent) => this.client.updateModel(modelId, modelContent, updateMask))
117+
.then((operation) => handleOperation(operation));
118+
}
114119

115120
/**
116121
* Publishes a model in Firebase ML.
@@ -120,7 +125,7 @@ export class MachineLearning implements FirebaseServiceInterface {
120125
* @return {Promise<Model>} A Promise fulfilled with the published model.
121126
*/
122127
public publishModel(modelId: string): Promise<Model> {
123-
throw new Error('NotImplemented');
128+
return this.setPublishStatus(modelId, true);
124129
}
125130

126131
/**
@@ -131,7 +136,7 @@ export class MachineLearning implements FirebaseServiceInterface {
131136
* @return {Promise<Model>} A Promise fulfilled with the unpublished model.
132137
*/
133138
public unpublishModel(modelId: string): Promise<Model> {
134-
throw new Error('NotImplemented');
139+
return this.setPublishStatus(modelId, false);
135140
}
136141

137142
/**
@@ -143,9 +148,7 @@ export class MachineLearning implements FirebaseServiceInterface {
143148
*/
144149
public getModel(modelId: string): Promise<Model> {
145150
return this.client.getModel(modelId)
146-
.then((modelResponse) => {
147-
return new Model(modelResponse);
148-
});
151+
.then((modelResponse) => new Model(modelResponse));
149152
}
150153

151154
/**
@@ -171,23 +174,28 @@ export class MachineLearning implements FirebaseServiceInterface {
171174
return this.client.deleteModel(modelId);
172175
}
173176

174-
private convertOptionstoContent(options: ModelOptions, forUpload?: boolean): Promise<ModelContent> {
175-
const modelContent = deepCopy(options);
177+
private setPublishStatus(modelId: string, publish: boolean): Promise<Model> {
178+
const updateMask = ['state.published'];
179+
const options: ModelUpdateOptions = {state: {published: publish}};
180+
return this.client.updateModel(modelId, options, updateMask)
181+
.then((operation) => handleOperation(operation));
182+
}
176183

177-
if (forUpload && modelContent.tfliteModel?.gcsTfliteUri) {
178-
return this.signUrl(modelContent.tfliteModel.gcsTfliteUri)
184+
private signUrlIfPresent(options: ModelOptions): Promise<ModelOptions> {
185+
const modelOptions = deepCopy(options);
186+
if (modelOptions.tfliteModel?.gcsTfliteUri) {
187+
return this.signUrl(modelOptions.tfliteModel.gcsTfliteUri)
179188
.then ((uri: string) => {
180-
modelContent.tfliteModel!.gcsTfliteUri = uri;
181-
return modelContent;
189+
modelOptions.tfliteModel!.gcsTfliteUri = uri;
190+
return modelOptions;
182191
})
183192
.catch((err: Error) => {
184193
throw new FirebaseMachineLearningError(
185194
'internal-error',
186195
`Error during signing upload url: ${err.message}`);
187-
}) as Promise<ModelContent>;
196+
});
188197
}
189-
190-
return Promise.resolve(modelContent) as Promise<ModelContent>;
198+
return Promise.resolve(modelOptions);
191199
}
192200

193201
private signUrl(unsignedUrl: string): Promise<string> {
@@ -208,9 +216,7 @@ export class MachineLearning implements FirebaseServiceInterface {
208216
return blob.getSignedUrl({
209217
action: 'read',
210218
expires: Date.now() + URL_VALID_DURATION,
211-
}).then((x) => {
212-
return x[0];
213-
});
219+
}).then((signUrl) => signUrl[0]);
214220
}
215221
}
216222

@@ -287,23 +293,10 @@ export interface TFLiteModel {
287293
readonly gcsTfliteUri: string;
288294
}
289295

290-
291-
/**
292-
* A Firebase ML Model input object
293-
*/
294-
export class ModelOptions {
295-
public displayName?: string;
296-
public tags?: string[];
297-
298-
public tfliteModel?: { gcsTfliteUri: string; };
299-
}
300-
301-
302296
function extractModelId(resourceName: string): string {
303297
return resourceName.split('/').pop()!;
304298
}
305299

306-
307300
function handleOperation(op: OperationResponse): Model {
308301
// Backend currently does not return operations that are not done.
309302
if (op.done) {

test/integration/machine-learning.spec.ts

Lines changed: 177 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,178 @@ describe('admin.machineLearning', () => {
127127
});
128128
});
129129

130+
describe('updateModel()', () => {
131+
132+
const UPDATE_NAME: admin.machineLearning.ModelOptions = {
133+
displayName: 'update-model-new-name',
134+
};
135+
136+
it('rejects with not-found when the Model does not exist', () => {
137+
const nonExistingId = '00000000';
138+
return admin.machineLearning().updateModel(nonExistingId, UPDATE_NAME)
139+
.should.eventually.be.rejected.and.have.property(
140+
'code', 'machine-learning/not-found');
141+
});
142+
143+
it('rejects with invalid-argument when the ModelId is invalid', () => {
144+
return admin.machineLearning().updateModel('invalid-model-id', UPDATE_NAME)
145+
.should.eventually.be.rejected.and.have.property(
146+
'code', 'machine-learning/invalid-argument');
147+
});
148+
149+
it ('rejects with invalid-argument when modelOptions are invalid', () => {
150+
const modelOptions: admin.machineLearning.ModelOptions = {
151+
displayName: 'Invalid Name#*^!',
152+
};
153+
return createTemporaryModel({displayName: 'node-integration-invalid-argument'})
154+
.then((model) => admin.machineLearning().updateModel(model.modelId, modelOptions)
155+
.should.eventually.be.rejected.and.have.property(
156+
'code', 'machine-learning/invalid-argument'));
157+
});
158+
159+
it('updates the displayName', () => {
160+
const DISPLAY_NAME = 'node-integration-test-update-1b';
161+
return createTemporaryModel({displayName: 'node-integration-test-update-1a'})
162+
.then((model) => {
163+
const modelOptions: admin.machineLearning.ModelOptions = {
164+
displayName: DISPLAY_NAME,
165+
};
166+
return admin.machineLearning().updateModel(model.modelId, modelOptions)
167+
.then((updatedModel) => {
168+
verifyModel(updatedModel, modelOptions);
169+
});
170+
});
171+
});
172+
173+
it('sets tags for a model', () => {
174+
// TODO(ifielker): Uncomment & replace when BE change lands.
175+
// const ORIGINAL_TAGS = ['tag-node-update-1'];
176+
const ORIGINAL_TAGS: string[] = [];
177+
const NEW_TAGS = ['tag-node-update-2', 'tag-node-update-3'];
178+
179+
return createTemporaryModel({
180+
displayName: 'node-integration-test-update-2',
181+
tags: ORIGINAL_TAGS,
182+
}).then((expectedModel) => {
183+
const modelOptions: admin.machineLearning.ModelOptions = {
184+
tags: NEW_TAGS,
185+
};
186+
return admin.machineLearning().updateModel(expectedModel.modelId, modelOptions)
187+
.then((actualModel) => {
188+
expect(actualModel.tags!.length).to.equal(2);
189+
expect(actualModel.tags).to.have.same.members(NEW_TAGS);
190+
});
191+
});
192+
});
193+
194+
it('updates the tflite file', () => {
195+
Promise.all([
196+
createTemporaryModel(),
197+
uploadModelToGcs('model1.tflite', 'valid_model.tflite')])
198+
.then(([model, fileName]) => {
199+
const modelOptions: admin.machineLearning.ModelOptions = {
200+
tfliteModel: {gcsTfliteUri: fileName},
201+
};
202+
return admin.machineLearning().updateModel(model.modelId, modelOptions)
203+
.then((updatedModel) => {
204+
verifyModel(updatedModel, modelOptions);
205+
});
206+
});
207+
});
208+
209+
it('can update more than 1 field', () => {
210+
const DISPLAY_NAME = 'node-integration-test-update-3b';
211+
const TAGS = ['node-integration-tag-1', 'node-integration-tag-2'];
212+
return createTemporaryModel({displayName: 'node-integration-test-update-3a'})
213+
.then((model) => {
214+
const modelOptions: admin.machineLearning.ModelOptions = {
215+
displayName: DISPLAY_NAME,
216+
tags: TAGS,
217+
};
218+
return admin.machineLearning().updateModel(model.modelId, modelOptions)
219+
.then((updatedModel) => {
220+
expect(updatedModel.displayName).to.equal(DISPLAY_NAME);
221+
expect(updatedModel.tags).to.have.same.members(TAGS);
222+
});
223+
});
224+
});
225+
});
226+
227+
describe('publishModel()', () => {
228+
it('should reject when model does not exist', () => {
229+
const nonExistingName = '00000000';
230+
return admin.machineLearning().publishModel(nonExistingName)
231+
.should.eventually.be.rejected.and.have.property(
232+
'code', 'machine-learning/not-found');
233+
});
234+
235+
it('rejects with invalid-argument when the ModelId is invalid', () => {
236+
return admin.machineLearning().publishModel('invalid-model-id')
237+
.should.eventually.be.rejected.and.have.property(
238+
'code', 'machine-learning/invalid-argument');
239+
});
240+
241+
it('publishes the model successfully', () => {
242+
const modelOptions: admin.machineLearning.ModelOptions = {
243+
displayName: 'node-integration-test-publish-1',
244+
tfliteModel: {gcsTfliteUri: 'this will be replaced below'},
245+
};
246+
return uploadModelToGcs('model1.tflite', 'valid_model.tflite')
247+
.then((fileName: string) => {
248+
modelOptions.tfliteModel!.gcsTfliteUri = fileName;
249+
createTemporaryModel(modelOptions)
250+
.then((createdModel) => {
251+
expect(createdModel.validationError).to.be.empty;
252+
expect(createdModel.published).to.be.false;
253+
admin.machineLearning().publishModel(createdModel.modelId)
254+
.then((publishedModel) => {
255+
expect(publishedModel.published).to.be.true;
256+
});
257+
});
258+
});
259+
});
260+
});
261+
262+
describe('unpublishModel()', () => {
263+
it('should reject when model does not exist', () => {
264+
const nonExistingName = '00000000';
265+
return admin.machineLearning().unpublishModel(nonExistingName)
266+
.should.eventually.be.rejected.and.have.property(
267+
'code', 'machine-learning/not-found');
268+
});
269+
270+
it('rejects with invalid-argument when the ModelId is invalid', () => {
271+
return admin.machineLearning().unpublishModel('invalid-model-id')
272+
.should.eventually.be.rejected.and.have.property(
273+
'code', 'machine-learning/invalid-argument');
274+
});
275+
276+
it('unpublishes the model successfully', () => {
277+
const modelOptions: admin.machineLearning.ModelOptions = {
278+
displayName: 'node-integration-test-unpublish-1',
279+
tfliteModel: {gcsTfliteUri: 'this will be replaced below'},
280+
};
281+
return uploadModelToGcs('model1.tflite', 'valid_model.tflite')
282+
.then((fileName: string) => {
283+
modelOptions.tfliteModel!.gcsTfliteUri = fileName;
284+
createTemporaryModel(modelOptions)
285+
.then((createdModel) => {
286+
expect(createdModel.validationError).to.be.empty;
287+
expect(createdModel.published).to.be.false;
288+
admin.machineLearning().publishModel(createdModel.modelId)
289+
.then((publishedModel) => {
290+
expect(publishedModel.published).to.be.true;
291+
admin.machineLearning().unpublishModel(publishedModel.modelId)
292+
.then((unpublishedModel) => {
293+
expect(unpublishedModel.published).to.be.false;
294+
});
295+
});
296+
});
297+
});
298+
});
299+
});
300+
301+
130302
describe('getModel()', () => {
131303
it('rejects with not-found when the Model does not exist', () => {
132304
const nonExistingName = '00000000';
@@ -181,7 +353,11 @@ describe('admin.machineLearning', () => {
181353
});
182354

183355
function verifyModel(model: admin.machineLearning.Model, expectedOptions: admin.machineLearning.ModelOptions) {
184-
expect(model.displayName).to.equal(expectedOptions.displayName);
356+
if (expectedOptions.displayName) {
357+
expect(model.displayName).to.equal(expectedOptions.displayName);
358+
} else {
359+
expect(model.displayName).not.to.be.empty;
360+
}
185361
expect(model.createTime).to.not.be.empty;
186362
expect(model.updateTime).to.not.be.empty;
187363
expect(model.etag).to.not.be.empty;

0 commit comments

Comments
 (0)