From 911f2f811375349adc0db46b8a248c4f894c2d59 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 21 Dec 2023 04:51:43 -0800 Subject: [PATCH] Add support for accounts.current endpoint --- index.d.ts | 11 +++++++++++ index.js | 5 +++++ index.test.ts | 16 ++++++++++++++++ lib/accounts.js | 16 ++++++++++++++++ 4 files changed, 48 insertions(+) create mode 100644 lib/accounts.js diff --git a/index.d.ts b/index.d.ts index 5620f3b..76df9d4 100644 --- a/index.d.ts +++ b/index.d.ts @@ -8,6 +8,13 @@ declare module "replicate" { response: Response; } + export interface Account { + type: "user" | "organization"; + username: string; + name: string; + github_url?: string; + } + export interface Collection { name: string; slug: string; @@ -140,6 +147,10 @@ declare module "replicate" { stop?: (prediction: Prediction) => Promise ): Promise; + accounts: { + current(): Promise; + }; + collections: { list(): Promise>; get(collection_slug: string): Promise; diff --git a/index.js b/index.js index ce407f9..a85ea4e 100644 --- a/index.js +++ b/index.js @@ -3,6 +3,7 @@ const ModelVersionIdentifier = require("./lib/identifier"); const { Stream } = require("./lib/stream"); const { withAutomaticRetries } = require("./lib/util"); +const accounts = require("./lib/accounts"); const collections = require("./lib/collections"); const deployments = require("./lib/deployments"); const hardware = require("./lib/hardware"); @@ -47,6 +48,10 @@ class Replicate { this.baseUrl = options.baseUrl || "https://api.replicate.com/v1"; this.fetch = options.fetch || globalThis.fetch; + this.accounts = { + current: accounts.current.bind(this), + }; + this.collections = { list: collections.list.bind(this), get: collections.get.bind(this), diff --git a/index.test.ts b/index.test.ts index 5b5a1dd..d50ccb4 100644 --- a/index.test.ts +++ b/index.test.ts @@ -67,6 +67,22 @@ describe("Replicate client", () => { }); }); + describe("accounts.current", () => { + test("Calls the correct API route", async () => { + nock(BASE_URL).get("/account").reply(200, { + type: "organization", + username: "replicate", + name: "Replicate", + github_url: "https://github.com/replicate", + }); + + const account = await client.accounts.current(); + expect(account.type).toBe("organization"); + expect(account.username).toBe("replicate"); + }); + // Add more tests for error handling, edge cases, etc. + }); + describe("collections.list", () => { test("Calls the correct API route", async () => { nock(BASE_URL) diff --git a/lib/accounts.js b/lib/accounts.js new file mode 100644 index 0000000..b3bbd9f --- /dev/null +++ b/lib/accounts.js @@ -0,0 +1,16 @@ +/** + * Get the current account + * + * @returns {Promise} Resolves with the current account + */ +async function getCurrentAccount() { + const response = await this.request("/account", { + method: "GET", + }); + + return response.json(); +} + +module.exports = { + current: getCurrentAccount, +};