diff --git a/.changeset/fast-mangos-chew.md b/.changeset/fast-mangos-chew.md new file mode 100644 index 00000000000..c5d2e4b4d1f --- /dev/null +++ b/.changeset/fast-mangos-chew.md @@ -0,0 +1,5 @@ +--- +'@firebase/vertexai': patch +--- + +Pass `GenerativeModel`'s `BaseParams` to created chat sessions. This fixes an issue where `GenerationConfig` would not be inherited from `ChatSession`. diff --git a/packages/vertexai/src/models/generative-model.test.ts b/packages/vertexai/src/models/generative-model.test.ts index 987f9b115e2..51ea8aafead 100644 --- a/packages/vertexai/src/models/generative-model.test.ts +++ b/packages/vertexai/src/models/generative-model.test.ts @@ -172,7 +172,10 @@ describe('GenerativeModel', () => { { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] } ], toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, - systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] } + systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, + generationConfig: { + topK: 1 + } }); expect(genModel.tools?.length).to.equal(1); expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal( @@ -196,7 +199,8 @@ describe('GenerativeModel', () => { return ( value.includes('myfunc') && value.includes(FunctionCallingMode.NONE) && - value.includes('be friendly') + value.includes('be friendly') && + value.includes('topK') ); }), {} @@ -236,7 +240,10 @@ describe('GenerativeModel', () => { { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] } ], toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, - systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] } + systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }, + generationConfig: { + responseMimeType: 'image/jpeg' + } }); expect(genModel.tools?.length).to.equal(1); expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal( @@ -262,7 +269,10 @@ describe('GenerativeModel', () => { toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.AUTO } }, - systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] } + systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }, + generationConfig: { + responseMimeType: 'image/png' + } }) .sendMessage('hello'); expect(makeRequestStub).to.be.calledWith( @@ -274,7 +284,9 @@ describe('GenerativeModel', () => { return ( value.includes('otherfunc') && value.includes(FunctionCallingMode.AUTO) && - value.includes('be formal') + value.includes('be formal') && + value.includes('image/png') && + !value.includes('image/jpeg') ); }), {} diff --git a/packages/vertexai/src/models/generative-model.ts b/packages/vertexai/src/models/generative-model.ts index 983118bf6ff..1af1ee700d5 100644 --- a/packages/vertexai/src/models/generative-model.ts +++ b/packages/vertexai/src/models/generative-model.ts @@ -132,6 +132,13 @@ export class GenerativeModel extends VertexAIModel { tools: this.tools, toolConfig: this.toolConfig, systemInstruction: this.systemInstruction, + generationConfig: this.generationConfig, + safetySettings: this.safetySettings, + /** + * Overrides params inherited from GenerativeModel with those explicitly set in the + * StartChatParams. For example, if startChatParams.generationConfig is set, it'll override + * this.generationConfig. + */ ...startChatParams }, this.requestOptions