diff --git a/.changeset/smooth-parrots-speak.md b/.changeset/smooth-parrots-speak.md
new file mode 100644
index 0000000000..49a7b10c73
--- /dev/null
+++ b/.changeset/smooth-parrots-speak.md
@@ -0,0 +1,6 @@
+---
+'@firebase/ai': minor
+'firebase': minor
+---
+
+Add `inferenceSource` to the response from `generateContent` and `generateContentStream`. This property indicates whether on-device or in-cloud inference was used to generate the result.
diff --git a/common/api-review/ai.api.md b/common/api-review/ai.api.md
index debea0a854..f3113e0ac2 100644
--- a/common/api-review/ai.api.md
+++ b/common/api-review/ai.api.md
@@ -256,6 +256,8 @@ export { Date_2 as Date }
 // @public
 export interface EnhancedGenerateContentResponse extends GenerateContentResponse {
     functionCalls: () => FunctionCall[] | undefined;
+    // @beta
+    inferenceSource?: InferenceSource;
     inlineDataParts: () => InlineDataPart[] | undefined;
     text: () => string;
     thoughtSummary: () => string | undefined;
@@ -816,6 +818,15 @@ export const InferenceMode: {
 // @beta
 export type InferenceMode = (typeof InferenceMode)[keyof typeof InferenceMode];
 
+// @beta
+export const InferenceSource: {
+    readonly ON_DEVICE: "on_device";
+    readonly IN_CLOUD: "in_cloud";
+};
+
+// @beta
+export type InferenceSource = (typeof InferenceSource)[keyof typeof InferenceSource];
+
 // @public
 export interface InlineDataPart {
     // (undocumented)
diff --git a/docs-devsite/ai.enhancedgeneratecontentresponse.md b/docs-devsite/ai.enhancedgeneratecontentresponse.md
index 9e947add0c..609196d603 100644
--- a/docs-devsite/ai.enhancedgeneratecontentresponse.md
+++ b/docs-devsite/ai.enhancedgeneratecontentresponse.md
@@ -24,6 +24,7 @@ export interface EnhancedGenerateContentResponse extends GenerateContentResponse
 | Property | Type | Description |
 | --- | --- | --- |
 | [functionCalls](./ai.enhancedgeneratecontentresponse.md#enhancedgeneratecontentresponsefunctioncalls) | () => [FunctionCall](./ai.functioncall.md#functioncall_interface)\[\] \| undefined | Aggregates and returns every [FunctionCall](./ai.functioncall.md#functioncall_interface) from the first candidate of [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). |
+| [inferenceSource](./ai.enhancedgeneratecontentresponse.md#enhancedgeneratecontentresponseinferencesource) | [InferenceSource](./ai.md#inferencesource) | (Public Preview) Indicates whether inference happened on-device or in-cloud. |
 | [inlineDataParts](./ai.enhancedgeneratecontentresponse.md#enhancedgeneratecontentresponseinlinedataparts) | () => [InlineDataPart](./ai.inlinedatapart.md#inlinedatapart_interface)\[\] \| undefined | Aggregates and returns every [InlineDataPart](./ai.inlinedatapart.md#inlinedatapart_interface) from the first candidate of [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). |
 | [text](./ai.enhancedgeneratecontentresponse.md#enhancedgeneratecontentresponsetext) | () => string | Returns the text string from the response, if available. Throws if the prompt or candidate was blocked. |
 | [thoughtSummary](./ai.enhancedgeneratecontentresponse.md#enhancedgeneratecontentresponsethoughtsummary) | () => string \| undefined | Aggregates and returns every [TextPart](./ai.textpart.md#textpart_interface) with their thought property set to true from the first candidate of [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface). |
@@ -38,6 +39,19 @@ Aggregates and returns every [FunctionCall](./ai.functioncall.md#functioncall_in
 functionCalls: () => FunctionCall[] | undefined;
 ```
 
+## EnhancedGenerateContentResponse.inferenceSource
+
+> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment.
+>
+
+Indicates whether inference happened on-device or in-cloud.
+
+Signature:
+
+```typescript
+inferenceSource?: InferenceSource;
+```
+
 ## EnhancedGenerateContentResponse.inlineDataParts
 
 Aggregates and returns every [InlineDataPart](./ai.inlinedatapart.md#inlinedatapart_interface) from the first candidate of [GenerateContentResponse](./ai.generatecontentresponse.md#generatecontentresponse_interface).
diff --git a/docs-devsite/ai.md b/docs-devsite/ai.md
index db6148ee88..fabdbc5cc5 100644
--- a/docs-devsite/ai.md
+++ b/docs-devsite/ai.md
@@ -162,6 +162,7 @@ The Firebase AI Web SDK.
 | [ImagenPersonFilterLevel](./ai.md#imagenpersonfilterlevel) | A filter level controlling whether generation of images containing people or faces is allowed.See the personGeneration documentation for more details. |
 | [ImagenSafetyFilterLevel](./ai.md#imagensafetyfilterlevel) | A filter level controlling how aggressively to filter sensitive content.Text prompts provided as inputs and images (generated or uploaded) through Imagen on Vertex AI are assessed against a list of safety filters, which include 'harmful categories' (for example, violence, sexual, derogatory, and toxic). This filter level controls how aggressively to filter out potentially harmful content from responses. See the [documentation](http://firebase.google.com/docs/vertex-ai/generate-images) and the [Responsible AI and usage guidelines](https://cloud.google.com/vertex-ai/generative-ai/docs/image/responsible-ai-imagen#safety-filters) for more details. |
 | [InferenceMode](./ai.md#inferencemode) | (Public Preview) Determines whether inference happens on-device or in-cloud. |
+| [InferenceSource](./ai.md#inferencesource) | (Public Preview) Indicates whether inference happened on-device or in-cloud. |
 | [Language](./ai.md#language) | (Public Preview) The programming language of the code. |
 | [LiveResponseType](./ai.md#liveresponsetype) | (Public Preview) The types of responses that can be returned by [LiveSession.receive()](./ai.livesession.md#livesessionreceive). |
 | [Modality](./ai.md#modality) | Content part modality. |
@@ -189,6 +190,7 @@ The Firebase AI Web SDK.
 | [ImagenPersonFilterLevel](./ai.md#imagenpersonfilterlevel) | A filter level controlling whether generation of images containing people or faces is allowed.See the personGeneration documentation for more details. |
 | [ImagenSafetyFilterLevel](./ai.md#imagensafetyfilterlevel) | A filter level controlling how aggressively to filter sensitive content.Text prompts provided as inputs and images (generated or uploaded) through Imagen on Vertex AI are assessed against a list of safety filters, which include 'harmful categories' (for example, violence, sexual, derogatory, and toxic). This filter level controls how aggressively to filter out potentially harmful content from responses. See the [documentation](http://firebase.google.com/docs/vertex-ai/generate-images) and the [Responsible AI and usage guidelines](https://cloud.google.com/vertex-ai/generative-ai/docs/image/responsible-ai-imagen#safety-filters) for more details. |
 | [InferenceMode](./ai.md#inferencemode) | (Public Preview) Determines whether inference happens on-device or in-cloud. |
+| [InferenceSource](./ai.md#inferencesource) | (Public Preview) Indicates whether inference happened on-device or in-cloud. |
 | [Language](./ai.md#language) | (Public Preview) The programming language of the code. |
 | [LanguageModelMessageContentValue](./ai.md#languagemodelmessagecontentvalue) | (Public Preview) Content formats that can be provided as on-device message content. |
 | [LanguageModelMessageRole](./ai.md#languagemodelmessagerole) | (Public Preview) Allowable roles for on-device language model usage. |
@@ -643,6 +645,22 @@ InferenceMode: {
 }
 ```
 
+## InferenceSource
+
+> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment.
+>
+
+Indicates whether inference happened on-device or in-cloud.
+
+Signature:
+
+```typescript
+InferenceSource: {
+    readonly ON_DEVICE: "on_device";
+    readonly IN_CLOUD: "in_cloud";
+}
+```
+
 ## Language
 
 > This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment.
@@ -926,6 +944,19 @@ Determines whether inference happens on-device or in-cloud.
 export type InferenceMode = (typeof InferenceMode)[keyof typeof InferenceMode];
 ```
 
+## InferenceSource
+
+> This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment.
+>
+
+Indicates whether inference happened on-device or in-cloud.
+
+Signature:
+
+```typescript
+export type InferenceSource = (typeof InferenceSource)[keyof typeof InferenceSource];
+```
+
 ## Language
 
 > This API is provided as a preview for developers and may change based on feedback that we receive. Do not use this API in a production environment.
diff --git a/packages/ai/src/methods/generate-content.ts b/packages/ai/src/methods/generate-content.ts
index 0e65b47934..a2fb29e20d 100644
--- a/packages/ai/src/methods/generate-content.ts
+++ b/packages/ai/src/methods/generate-content.ts
@@ -57,14 +57,14 @@ export async function generateContentStream(
   chromeAdapter?: ChromeAdapter,
   requestOptions?: RequestOptions
 ): Promise<GenerateContentStreamResult> {
-  const response = await callCloudOrDevice(
+  const callResult = await callCloudOrDevice(
     params,
     chromeAdapter,
     () => chromeAdapter!.generateContentStream(params),
     () =>
       generateContentStreamOnCloud(apiSettings, model, params, requestOptions)
   );
-  return processStream(response, apiSettings); // TODO: Map streaming responses
+  return processStream(callResult.response, apiSettings); // TODO: Map streaming responses
 }
 
 async function generateContentOnCloud(
@@ -93,18 +93,19 @@ export async function generateContent(
   chromeAdapter?: ChromeAdapter,
   requestOptions?: RequestOptions
 ): Promise<GenerateContentResult> {
-  const response = await callCloudOrDevice(
+  const callResult = await callCloudOrDevice(
     params,
     chromeAdapter,
     () => chromeAdapter!.generateContent(params),
     () => generateContentOnCloud(apiSettings, model, params, requestOptions)
   );
   const generateContentResponse = await processGenerateContentResponse(
-    response,
+    callResult.response,
     apiSettings
   );
   const enhancedResponse = createEnhancedContentResponse(
-    generateContentResponse
+    generateContentResponse,
+    callResult.inferenceSource
   );
   return {
     response: enhancedResponse
diff --git a/packages/ai/src/requests/hybrid-helpers.test.ts b/packages/ai/src/requests/hybrid-helpers.test.ts
index a758f34ad2..33e83c0469 100644
--- a/packages/ai/src/requests/hybrid-helpers.test.ts
+++ b/packages/ai/src/requests/hybrid-helpers.test.ts
@@ -18,7 +18,12 @@
 import { use, expect } from 'chai';
 import { SinonStub, SinonStubbedInstance, restore, stub } from 'sinon';
 import { callCloudOrDevice } from './hybrid-helpers';
-import { GenerateContentRequest, InferenceMode, AIErrorCode } from '../types';
+import {
+  GenerateContentRequest,
+  InferenceMode,
+  AIErrorCode,
+  InferenceSource
+} from '../types';
 import { AIError } from '../errors';
 import sinonChai from 'sinon-chai';
 import chaiAsPromised from 'chai-as-promised';
@@ -58,7 +63,8 @@ describe('callCloudOrDevice', () => {
       onDeviceCall,
       inCloudCall
     );
-    expect(result).to.equal('in-cloud-response');
+    expect(result.response).to.equal('in-cloud-response');
+    expect(result.inferenceSource).to.equal(InferenceSource.IN_CLOUD);
     expect(inCloudCall).to.have.been.calledOnce;
     expect(onDeviceCall).to.not.have.been.called;
   });
@@ -76,7 +82,8 @@
       onDeviceCall,
       inCloudCall
     );
-    expect(result).to.equal('on-device-response');
+    expect(result.response).to.equal('on-device-response');
+    expect(result.inferenceSource).to.equal(InferenceSource.ON_DEVICE);
     expect(onDeviceCall).to.have.been.calledOnce;
     expect(inCloudCall).to.not.have.been.called;
   });
@@ -89,7 +96,8 @@
       onDeviceCall,
       inCloudCall
     );
-    expect(result).to.equal('in-cloud-response');
+    expect(result.response).to.equal('in-cloud-response');
+    expect(result.inferenceSource).to.equal(InferenceSource.IN_CLOUD);
     expect(inCloudCall).to.have.been.calledOnce;
     expect(onDeviceCall).to.not.have.been.called;
   });
@@ -108,7 +116,8 @@
       onDeviceCall,
       inCloudCall
     );
-    expect(result).to.equal('on-device-response');
+    expect(result.response).to.equal('on-device-response');
+    expect(result.inferenceSource).to.equal(InferenceSource.ON_DEVICE);
     expect(onDeviceCall).to.have.been.calledOnce;
     expect(inCloudCall).to.not.have.been.called;
   });
@@ -136,7 +145,8 @@
       onDeviceCall,
       inCloudCall
     );
-    expect(result).to.equal('in-cloud-response');
+    expect(result.response).to.equal('in-cloud-response');
+    expect(result.inferenceSource).to.equal(InferenceSource.IN_CLOUD);
     expect(inCloudCall).to.have.been.calledOnce;
     expect(onDeviceCall).to.not.have.been.called;
   });
@@ -154,7 +164,8 @@
       onDeviceCall,
       inCloudCall
     );
-    expect(result).to.equal('in-cloud-response');
+    expect(result.response).to.equal('in-cloud-response');
+    expect(result.inferenceSource).to.equal(InferenceSource.IN_CLOUD);
     expect(inCloudCall).to.have.been.calledOnce;
     expect(onDeviceCall).to.not.have.been.called;
   });
@@ -169,7 +180,8 @@
       onDeviceCall,
       inCloudCall
     );
-    expect(result).to.equal('on-device-response');
+    expect(result.response).to.equal('on-device-response');
+    expect(result.inferenceSource).to.equal(InferenceSource.ON_DEVICE);
     expect(inCloudCall).to.have.been.calledOnce;
     expect(onDeviceCall).to.have.been.calledOnce;
   });
diff --git a/packages/ai/src/requests/hybrid-helpers.ts b/packages/ai/src/requests/hybrid-helpers.ts
index 3140594c00..b37505bf93 100644
--- a/packages/ai/src/requests/hybrid-helpers.ts
+++ b/packages/ai/src/requests/hybrid-helpers.ts
@@ -20,7 +20,8 @@ import {
   GenerateContentRequest,
   InferenceMode,
   AIErrorCode,
-  ChromeAdapter
+  ChromeAdapter,
+  InferenceSource
 } from '../types';
 import { ChromeAdapterImpl } from '../methods/chrome-adapter';
 
@@ -33,6 +34,11 @@ const errorsCausingFallback: AIErrorCode[] = [
   AIErrorCode.API_NOT_ENABLED
 ];
 
+interface CallResult<Response> {
+  response: Response;
+  inferenceSource: InferenceSource;
+}
+
 /**
  * Dispatches a request to the appropriate backend (on-device or in-cloud)
  * based on the inference mode.
@@ -48,35 +54,56 @@ export async function callCloudOrDevice<Response>(
   request: GenerateContentRequest,
   chromeAdapter: ChromeAdapter | undefined,
   onDeviceCall: () => Promise<Response>,
   inCloudCall: () => Promise<Response>
-): Promise<Response> {
+): Promise<CallResult<Response>> {
   if (!chromeAdapter) {
-    return inCloudCall();
+    return {
+      response: await inCloudCall(),
+      inferenceSource: InferenceSource.IN_CLOUD
+    };
   }
   switch ((chromeAdapter as ChromeAdapterImpl).mode) {
     case InferenceMode.ONLY_ON_DEVICE:
       if (await chromeAdapter.isAvailable(request)) {
-        return onDeviceCall();
+        return {
+          response: await onDeviceCall(),
+          inferenceSource: InferenceSource.ON_DEVICE
+        };
       }
       throw new AIError(
        AIErrorCode.UNSUPPORTED,
        'Inference mode is ONLY_ON_DEVICE, but an on-device model is not available.'
      );
     case InferenceMode.ONLY_IN_CLOUD:
-      return inCloudCall();
+      return {
+        response: await inCloudCall(),
+        inferenceSource: InferenceSource.IN_CLOUD
+      };
     case InferenceMode.PREFER_IN_CLOUD:
       try {
-        return await inCloudCall();
+        return {
+          response: await inCloudCall(),
+          inferenceSource: InferenceSource.IN_CLOUD
+        };
       } catch (e) {
         if (e instanceof AIError && errorsCausingFallback.includes(e.code)) {
-          return onDeviceCall();
+          return {
+            response: await onDeviceCall(),
+            inferenceSource: InferenceSource.ON_DEVICE
+          };
         }
         throw e;
       }
     case InferenceMode.PREFER_ON_DEVICE:
       if (await chromeAdapter.isAvailable(request)) {
-        return onDeviceCall();
+        return {
+          response: await onDeviceCall(),
+          inferenceSource: InferenceSource.ON_DEVICE
+        };
       }
-      return inCloudCall();
+      return {
+        response: await inCloudCall(),
+        inferenceSource: InferenceSource.IN_CLOUD
+      };
     default:
       throw new AIError(
         AIErrorCode.ERROR,
diff --git a/packages/ai/src/requests/response-helpers.ts b/packages/ai/src/requests/response-helpers.ts
index 930bfabb2a..bb3748f6bc 100644
--- a/packages/ai/src/requests/response-helpers.ts
+++ b/packages/ai/src/requests/response-helpers.ts
@@ -25,7 +25,8 @@ import {
   ImagenInlineImage,
   AIErrorCode,
   InlineDataPart,
-  Part
+  Part,
+  InferenceSource
 } from '../types';
 import { AIError } from '../errors';
 import { logger } from '../logger';
@@ -66,7 +67,8 @@ function hasValidCandidates(response: GenerateContentResponse): boolean {
  * other modifications that improve usability.
  */
 export function createEnhancedContentResponse(
-  response: GenerateContentResponse
+  response: GenerateContentResponse,
+  inferenceSource: InferenceSource = InferenceSource.IN_CLOUD
 ): EnhancedGenerateContentResponse {
   /**
    * The Vertex AI backend omits default values.
@@ -79,6 +81,7 @@ export function createEnhancedContentResponse(
   }
 
   const responseWithHelpers = addHelpers(response);
+  responseWithHelpers.inferenceSource = inferenceSource;
   return responseWithHelpers;
 }
 
diff --git a/packages/ai/src/requests/stream-reader.test.ts b/packages/ai/src/requests/stream-reader.test.ts
index 2e50bbb3d3..ca3c2cdcfe 100644
--- a/packages/ai/src/requests/stream-reader.test.ts
+++ b/packages/ai/src/requests/stream-reader.test.ts
@@ -34,7 +34,8 @@ import {
   HarmCategory,
   HarmProbability,
   SafetyRating,
-  AIErrorCode
+  AIErrorCode,
+  InferenceSource
 } from '../types';
 import { AIError } from '../errors';
 import { ApiSettings } from '../types/internal';
@@ -61,6 +62,7 @@ describe('getResponseStream', () => {
       .map(v => JSON.stringify(v))
       .map(v => 'data: ' + v + '\r\n\r\n')
       .join('')
+      // @ts-ignore
     ).pipeThrough(new TextDecoderStream('utf8', { fatal: true }));
     const responseStream = getResponseStream<{ text: string }>(inputStream);
     const reader = responseStream.getReader();
@@ -88,9 +90,33 @@ describe('processStream', () => {
     const result = processStream(fakeResponse as Response, fakeApiSettings);
     for await (const response of result.stream) {
       expect(response.text()).to.not.be.empty;
+      expect(response.inferenceSource).to.equal(InferenceSource.IN_CLOUD);
     }
     const aggregatedResponse = await result.response;
     expect(aggregatedResponse.text()).to.include('Cheyenne');
+    expect(aggregatedResponse.inferenceSource).to.equal(
+      InferenceSource.IN_CLOUD
+    );
+  });
+  it('streaming response - short - on-device', async () => {
+    const fakeResponse = getMockResponseStreaming(
+      'vertexAI',
+      'streaming-success-basic-reply-short.txt'
+    );
+    const result = processStream(
+      fakeResponse as Response,
+      fakeApiSettings,
+      InferenceSource.ON_DEVICE
+    );
+    for await (const response of result.stream) {
+      expect(response.text()).to.not.be.empty;
+      expect(response.inferenceSource).to.equal(InferenceSource.ON_DEVICE);
+    }
+    const aggregatedResponse = await result.response;
+    expect(aggregatedResponse.text()).to.include('Cheyenne');
+    expect(aggregatedResponse.inferenceSource).to.equal(
+      InferenceSource.ON_DEVICE
+    );
   });
   it('streaming response - long', async () => {
     const fakeResponse = getMockResponseStreaming(
diff --git a/packages/ai/src/requests/stream-reader.ts b/packages/ai/src/requests/stream-reader.ts
index 042c052fa8..b4968969be 100644
--- a/packages/ai/src/requests/stream-reader.ts
+++ b/packages/ai/src/requests/stream-reader.ts
@@ -28,7 +28,11 @@ import { createEnhancedContentResponse } from './response-helpers';
 import * as GoogleAIMapper from '../googleai-mappers';
 import { GoogleAIGenerateContentResponse } from '../types/googleai';
 import { ApiSettings } from '../types/internal';
-import { BackendType, URLContextMetadata } from '../public-types';
+import {
+  BackendType,
+  InferenceSource,
+  URLContextMetadata
+} from '../public-types';
 
 const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/;
 
@@ -42,7 +46,8 @@
  */
 export function processStream(
   response: Response,
-  apiSettings: ApiSettings
+  apiSettings: ApiSettings,
+  inferenceSource?: InferenceSource
 ): GenerateContentStreamResult {
   const inputStream = response.body!.pipeThrough(
     new TextDecoderStream('utf8', { fatal: true })
   );
   const responseStream =
     getResponseStream<GenerateContentResponse>(inputStream);
   const [stream1, stream2] = responseStream.tee();
   return {
-    stream: generateResponseSequence(stream1, apiSettings),
-    response: getResponsePromise(stream2, apiSettings)
+    stream: generateResponseSequence(stream1, apiSettings, inferenceSource),
+    response: getResponsePromise(stream2, apiSettings, inferenceSource)
   };
 }
 
 async function getResponsePromise(
   stream: ReadableStream<GenerateContentResponse>,
-  apiSettings: ApiSettings
+  apiSettings: ApiSettings,
+  inferenceSource?: InferenceSource
 ): Promise<EnhancedGenerateContentResponse> {
   const allResponses: GenerateContentResponse[] = [];
   const reader = stream.getReader();
@@ -71,7 +77,10 @@ async function getResponsePromise(
           generateContentResponse as GoogleAIGenerateContentResponse
         );
       }
-      return createEnhancedContentResponse(generateContentResponse);
+      return createEnhancedContentResponse(
+        generateContentResponse,
+        inferenceSource
+      );
     }
     allResponses.push(value);
   }
@@ -80,7 +89,8 @@
 
 async function* generateResponseSequence(
   stream: ReadableStream<GenerateContentResponse>,
-  apiSettings: ApiSettings
+  apiSettings: ApiSettings,
+  inferenceSource?: InferenceSource
 ): AsyncGenerator<EnhancedGenerateContentResponse> {
   const reader = stream.getReader();
   while (true) {
@@ -94,10 +104,11 @@
       enhancedResponse = createEnhancedContentResponse(
         GoogleAIMapper.mapGenerateContentResponse(
           value as GoogleAIGenerateContentResponse
-        )
+        ),
+        inferenceSource
       );
     } else {
-      enhancedResponse = createEnhancedContentResponse(value);
+      enhancedResponse = createEnhancedContentResponse(value, inferenceSource);
     }
 
     const firstCandidate = enhancedResponse.candidates?.[0];
diff --git a/packages/ai/src/types/enums.ts b/packages/ai/src/types/enums.ts
index cd7029df3b..f7c55d5e4c 100644
--- a/packages/ai/src/types/enums.ts
+++ b/packages/ai/src/types/enums.ts
@@ -379,6 +379,24 @@ export const InferenceMode = {
  */
 export type InferenceMode = (typeof InferenceMode)[keyof typeof InferenceMode];
 
+/**
+ * Indicates whether inference happened on-device or in-cloud.
+ *
+ * @beta
+ */
+export const InferenceSource = {
+  'ON_DEVICE': 'on_device',
+  'IN_CLOUD': 'in_cloud'
+} as const;
+
+/**
+ * Indicates whether inference happened on-device or in-cloud.
+ *
+ * @beta
+ */
+export type InferenceSource =
+  (typeof InferenceSource)[keyof typeof InferenceSource];
+
 /**
  * Represents the result of the code execution.
  *
diff --git a/packages/ai/src/types/responses.ts b/packages/ai/src/types/responses.ts
index 8b8e135167..ec06592f90 100644
--- a/packages/ai/src/types/responses.ts
+++ b/packages/ai/src/types/responses.ts
@@ -22,6 +22,7 @@ import {
   HarmCategory,
   HarmProbability,
   HarmSeverity,
+  InferenceSource,
   Modality
 } from './enums';
 
@@ -88,6 +89,12 @@ export interface EnhancedGenerateContentResponse
    * set to `true`.
    */
   thoughtSummary: () => string | undefined;
+  /**
+   * Indicates whether inference happened on-device or in-cloud.
+   *
+   * @beta
+   */
+  inferenceSource?: InferenceSource;
 }
 
 /**
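
Usage sketch (not part of the patch): how a consumer reads the new field from a unary call. `getAI`, `getGenerativeModel`, `InferenceMode`, and `InferenceSource` are the SDK's exports (the last one added by this change), but the app config, hybrid mode choice, and prompt below are illustrative assumptions.

```typescript
import { initializeApp } from 'firebase/app';
import {
  getAI,
  getGenerativeModel,
  InferenceMode,
  InferenceSource
} from 'firebase/ai';

async function run(): Promise<void> {
  // Assumed project config; any hybrid-capable setup behaves the same way.
  const app = initializeApp({ apiKey: '...', projectId: '...' });
  const ai = getAI(app);
  const model = getGenerativeModel(ai, {
    mode: InferenceMode.PREFER_ON_DEVICE
  });

  const { response } = await model.generateContent('Tell me a joke.');
  // `inferenceSource` reports which backend actually served the request;
  // it defaults to IN_CLOUD when no on-device adapter is in play.
  if (response.inferenceSource === InferenceSource.ON_DEVICE) {
    console.log('on-device:', response.text());
  } else {
    console.log('in-cloud:', response.text());
  }
}
```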
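On the streaming side, `processStream` now accepts the source as an optional third parameter and threads it into each enhanced chunk as well as the aggregated response. Note that `generateContentStream` does not yet forward `callResult.inferenceSource` to `processStream` (the existing `// TODO: Map streaming responses` comment still applies), so as this diff stands, streamed responses fall back to the `IN_CLOUD` default in `createEnhancedContentResponse`. A sketch of the consumer-side read, reusing `model` from the sketch above:

```typescript
async function runStream(): Promise<void> {
  const result = await model.generateContentStream('Write a haiku.');
  for await (const chunk of result.stream) {
    // Each chunk is an EnhancedGenerateContentResponse, so the field is
    // readable per chunk as well as on the aggregated response below.
    console.log(chunk.inferenceSource, chunk.text());
  }
  const aggregated = await result.response;
  console.log('final source:', aggregated.inferenceSource);
}
```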