Skip to content

Commit 747629f

Browse files
authored
Merge branch 'main' into base64-encode-inputs
2 parents 25bb97e + d09067c commit 747629f

File tree

11 files changed

+527
-15
lines changed

11 files changed

+527
-15
lines changed

.github/workflows/ci.yml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,8 @@ jobs:
4747
with:
4848
node-version: ${{ matrix.node-version }}
4949
cache: "npm"
50-
- run: npm --prefix integration/${{ matrix.suite }} ci --omit=dev
51-
- run: npm --prefix integration/${{ matrix.suite }} test
50+
# Build a production tarball and run the integration tests against it.
51+
- run: |
52+
PKG_TARBALL=$(npm --loglevel error pack)
53+
npm --prefix integration/${{ matrix.suite }} install "file:/./$PKG_TARBALL"
54+
npm --prefix integration/${{ matrix.suite }} test

README.md

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,14 @@ npm install replicate
2323
Create the client:
2424

2525
```js
26+
// CommonJS (default or using .cjs extension)
27+
const Replicate = require("replicate");
28+
29+
// ESM (where `"module": true` in package.json or using .mjs extension)
2630
import Replicate from "replicate";
31+
```
2732

33+
```
2834
const replicate = new Replicate({
2935
// get your token from https://replicate.com/account
3036
auth: "my api token", // defaults to process.env.REPLICATE_API_TOKEN
@@ -69,9 +75,11 @@ console.log(prediction.output);
6975

7076
To run a model that takes a file input, pass a URL to a publicly accessible file. Or, for smaller files (<10MB), you can pass the data directly.
7177

72-
7378
```js
74-
import { promises as fs } from "fs";
79+
const fs = require("node:fs/promises");
80+
81+
// Or when using ESM.
82+
// import fs from "node:fs/promises";
7583

7684
const model = "nightmareai/real-esrgan:42fed1c4974146d4d2414e2be2c5277c7fcf05fcc3a73abf41610695738c1d7b";
7785
const input = {
@@ -81,6 +89,10 @@ const output = await replicate.run(model, { input });
8189
// ['https://replicate.delivery/mgxm/e7b0e122-9daa-410e-8cde-006c7308ff4d/output.png']
8290
```
8391

92+
## TypeScript
93+
94+
Currently in order to support the module format used by `replicate` you'll need to set `esModuleInterop` to `true` in your tsconfig.json.
95+
8496
## API
8597

8698
### Constructor
@@ -112,8 +124,12 @@ you can install a fetch function from an external package like
112124
and pass it to the `fetch` option in the constructor.
113125

114126
```js
115-
import Replicate from "replicate";
116-
import fetch from "cross-fetch";
127+
const Replicate = require("replicate");
128+
const fetch = require("fetch");
129+
130+
// Using ESM:
131+
// import Replicate from "replicate";
132+
// import fetch from "cross-fetch";
117133

118134
const replicate = new Replicate({ fetch });
119135
```
@@ -188,9 +204,22 @@ Returns `AsyncGenerator<ServerSentEvent>` which yields the events of running the
188204
Example:
189205

190206
```js
191-
for await (const event of replicate.stream("meta/llama-2-70b-chat")) {
192-
process.stdout.write(`${event}`);
207+
const model = "meta/llama-2-70b-chat";
208+
const options = {
209+
input: {
210+
prompt: "Write a poem about machine learning in the style of Mary Oliver.",
211+
},
212+
// webhook: "https://smee.io/dMUlmOMkzeyRGjW" // optional
213+
};
214+
const output = [];
215+
216+
for await (const { event, data } of replicate.stream(model, options)) {
217+
if (event === "output") {
218+
output.push(data);
219+
}
193220
}
221+
222+
console.log(output.join("").trim());
194223
```
195224

196225
### Server-sent events

index.d.ts

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,37 @@ declare module "replicate" {
88
response: Response;
99
}
1010

11+
export interface Account {
12+
type: "user" | "organization";
13+
username: string;
14+
name: string;
15+
github_url?: string;
16+
}
17+
1118
export interface Collection {
1219
name: string;
1320
slug: string;
1421
description: string;
1522
models?: Model[];
1623
}
1724

25+
export interface Deployment {
26+
owner: string;
27+
name: string;
28+
current_release: {
29+
number: number;
30+
model: string;
31+
version: string;
32+
created_at: string;
33+
created_by: Account;
34+
configuration: {
35+
hardware: string;
36+
min_instances: number;
37+
max_instances: number;
38+
};
39+
};
40+
}
41+
1842
export interface Hardware {
1943
sku: string;
2044
name: string;
@@ -82,6 +106,10 @@ declare module "replicate" {
82106
retry?: number;
83107
}
84108

109+
export interface WebhookSecret {
110+
key: string;
111+
}
112+
85113
export default class Replicate {
86114
constructor(options?: {
87115
auth?: string;
@@ -140,6 +168,10 @@ declare module "replicate" {
140168
stop?: (prediction: Prediction) => Promise<boolean>
141169
): Promise<Prediction>;
142170

171+
accounts: {
172+
current(): Promise<Account>;
173+
};
174+
143175
collections: {
144176
list(): Promise<Page<Collection>>;
145177
get(collection_slug: string): Promise<Collection>;
@@ -158,6 +190,10 @@ declare module "replicate" {
158190
}
159191
): Promise<Prediction>;
160192
};
193+
get(
194+
deployment_owner: string,
195+
deployment_name: string
196+
): Promise<Deployment>;
161197
};
162198

163199
hardware: {
@@ -222,5 +258,26 @@ declare module "replicate" {
222258
cancel(training_id: string): Promise<Training>;
223259
list(): Promise<Page<Training>>;
224260
};
261+
262+
webhooks: {
263+
default: {
264+
secret: {
265+
get(): Promise<WebhookSecret>;
266+
};
267+
};
268+
};
225269
}
270+
271+
export function validateWebhook(
272+
requestData:
273+
| Request
274+
| {
275+
id?: string;
276+
timestamp?: string;
277+
body: string;
278+
secret?: string;
279+
signature?: string;
280+
},
281+
secret: string
282+
): boolean;
226283
}

index.js

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
const ApiError = require("./lib/error");
22
const ModelVersionIdentifier = require("./lib/identifier");
33
const { Stream } = require("./lib/stream");
4-
const { withAutomaticRetries } = require("./lib/util");
4+
const { withAutomaticRetries, validateWebhook } = require("./lib/util");
55

6+
const accounts = require("./lib/accounts");
67
const collections = require("./lib/collections");
78
const deployments = require("./lib/deployments");
89
const hardware = require("./lib/hardware");
910
const models = require("./lib/models");
1011
const predictions = require("./lib/predictions");
1112
const trainings = require("./lib/trainings");
13+
const webhooks = require("./lib/webhooks");
1214

1315
const packageJSON = require("./package.json");
1416

@@ -47,12 +49,17 @@ class Replicate {
4749
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
4850
this.fetch = options.fetch || globalThis.fetch;
4951

52+
this.accounts = {
53+
current: accounts.current.bind(this),
54+
};
55+
5056
this.collections = {
5157
list: collections.list.bind(this),
5258
get: collections.get.bind(this),
5359
};
5460

5561
this.deployments = {
62+
get: deployments.get.bind(this),
5663
predictions: {
5764
create: deployments.predictions.create.bind(this),
5865
},
@@ -85,6 +92,14 @@ class Replicate {
8592
cancel: trainings.cancel.bind(this),
8693
list: trainings.list.bind(this),
8794
};
95+
96+
this.webhooks = {
97+
default: {
98+
secret: {
99+
get: webhooks.default.secret.get.bind(this),
100+
},
101+
},
102+
};
88103
}
89104

90105
/**
@@ -359,3 +374,4 @@ class Replicate {
359374
}
360375

361376
module.exports = Replicate;
377+
module.exports.validateWebhook = validateWebhook;

index.test.ts

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
import { expect, jest, test } from "@jest/globals";
2-
import Replicate, { ApiError, Model, Prediction } from "replicate";
2+
import Replicate, {
3+
ApiError,
4+
Model,
5+
Prediction,
6+
validateWebhook,
7+
} from "replicate";
38
import nock from "nock";
49
import fetch from "cross-fetch";
510

@@ -67,6 +72,22 @@ describe("Replicate client", () => {
6772
});
6873
});
6974

75+
describe("account.get", () => {
76+
test("Calls the correct API route", async () => {
77+
nock(BASE_URL).get("/account").reply(200, {
78+
type: "organization",
79+
username: "replicate",
80+
name: "Replicate",
81+
github_url: "https://github.com/replicate",
82+
});
83+
84+
const account = await client.accounts.current();
85+
expect(account.type).toBe("organization");
86+
expect(account.username).toBe("replicate");
87+
});
88+
// Add more tests for error handling, edge cases, etc.
89+
});
90+
7091
describe("collections.list", () => {
7192
test("Calls the correct API route", async () => {
7293
nock(BASE_URL)
@@ -741,6 +762,47 @@ describe("Replicate client", () => {
741762
// Add more tests for error handling, edge cases, etc.
742763
});
743764

765+
describe("deployments.get", () => {
766+
test("Calls the correct API route with the correct payload", async () => {
767+
nock(BASE_URL)
768+
.get("/deployments/acme/my-app-image-generator")
769+
.reply(200, {
770+
owner: "acme",
771+
name: "my-app-image-generator",
772+
current_release: {
773+
number: 1,
774+
model: "stability-ai/sdxl",
775+
version:
776+
"da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
777+
created_at: "2024-02-15T16:32:57.018467Z",
778+
created_by: {
779+
type: "organization",
780+
username: "acme",
781+
name: "Acme Corp, Inc.",
782+
github_url: "https://github.com/acme",
783+
},
784+
configuration: {
785+
hardware: "gpu-t4",
786+
scaling: {
787+
min_instances: 1,
788+
max_instances: 5,
789+
},
790+
},
791+
},
792+
});
793+
794+
const deployment = await client.deployments.get(
795+
"acme",
796+
"my-app-image-generator"
797+
);
798+
799+
expect(deployment.owner).toBe("acme");
800+
expect(deployment.name).toBe("my-app-image-generator");
801+
expect(deployment.current_release.model).toBe("stability-ai/sdxl");
802+
});
803+
// Add more tests for error handling, edge cases, etc.
804+
});
805+
744806
describe("predictions.create with model", () => {
745807
test("Calls the correct API route with the correct payload", async () => {
746808
nock(BASE_URL)
@@ -1021,5 +1083,39 @@ describe("Replicate client", () => {
10211083
});
10221084
});
10231085

1086+
describe("webhooks.default.secret.get", () => {
1087+
test("Calls the correct API route", async () => {
1088+
nock(BASE_URL).get("/webhooks/default/secret").reply(200, {
1089+
key: "whsec_5WbX5kEWLlfzsGNjH64I8lOOqUB6e8FH",
1090+
});
1091+
1092+
const secret = await client.webhooks.default.secret.get();
1093+
expect(secret.key).toBe("whsec_5WbX5kEWLlfzsGNjH64I8lOOqUB6e8FH");
1094+
});
1095+
1096+
test("Can be used to validate webhook", async () => {
1097+
// Test case from https://github.com/svix/svix-webhooks/blob/b41728cd98a7e7004a6407a623f43977b82fcba4/javascript/src/webhook.test.ts#L190-L200
1098+
const request = new Request("http://test.host/webhook", {
1099+
method: "POST",
1100+
headers: {
1101+
"Content-Type": "application/json",
1102+
"Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek",
1103+
"Webhook-Timestamp": "1614265330",
1104+
"Webhook-Signature":
1105+
"v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=",
1106+
},
1107+
body: `{"test": 2432232314}`,
1108+
});
1109+
1110+
// This is a test secret and should not be used in production
1111+
const secret = "whsec_MfKQ9r8GKYqrTwjUPD8ILPZIo2LaLaSw";
1112+
1113+
const isValid = await validateWebhook(request, secret);
1114+
expect(isValid).toBe(true);
1115+
});
1116+
1117+
// Add more tests for error handling, edge cases, etc.
1118+
});
1119+
10241120
// Continue with tests for other methods
10251121
});

lib/accounts.js

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
/**
2+
* Get the current account
3+
*
4+
* @returns {Promise<object>} Resolves with the current account
5+
*/
6+
async function getCurrentAccount() {
7+
const response = await this.request("/account", {
8+
method: "GET",
9+
});
10+
11+
return response.json();
12+
}
13+
14+
module.exports = {
15+
current: getCurrentAccount,
16+
};

0 commit comments

Comments
 (0)