diff --git a/index.d.ts b/index.d.ts index b6e343d..e46f564 100644 --- a/index.d.ts +++ b/index.d.ts @@ -1,3 +1,5 @@ +import type { AxiosRequestConfig } from "axios"; + type Identifier = `${string}/${string}:${string}`; declare module "replicate" { @@ -5,6 +7,7 @@ declare module "replicate" { auth: string; userAgent: string; baseUrl?: string; + axiosConfig?: Partial>; } export interface Collection { diff --git a/index.js b/index.js index 9872294..cc58b81 100644 --- a/index.js +++ b/index.js @@ -38,6 +38,7 @@ class Replicate { options.userAgent || `replicate-javascript/${packageJSON.version}`; this.baseUrl = options.baseUrl || 'https://api.replicate.com/v1'; this.instance = axios.create({ + ...options.axiosConfig, baseURL: this.baseUrl, headers: { Authorization: `Token ${this.auth}`, diff --git a/index.test.js b/index.test.js index 2770dd5..0da4602 100644 --- a/index.test.js +++ b/index.test.js @@ -12,6 +12,16 @@ describe('Replicate client', () => { expect(clientWithoutBaseUrl.baseUrl).toBe('https://api.replicate.com/v1'); }); + test('Constructor passes through axios options', async () => { + const adapter = jest.fn(() => Promise.resolve({})); + const clientWithAxiosOptions = new Replicate({ + axiosConfig: { adapter }, + }); + + await clientWithAxiosOptions.collections.get('text-to-image'); + expect(adapter).toHaveBeenCalled(); + }); + describe('collections.get', () => { test('Calls the correct API route', async () => { client.request = jest.fn();