Commit 2823d83

Automatically transform binary inputs into data uris
1 parent 8210d23 commit 2823d83

File tree

4 files changed (+57, -16 lines)


README.md

Lines changed: 2 additions & 11 deletions
@@ -67,24 +67,15 @@ console.log(prediction.output);
 // ['https://replicate.delivery/pbxt/RoaxeXqhL0xaYyLm6w3bpGwF5RaNBjADukfFnMbhOyeoWBdhA/out-0.png']
 ```
 
-To run a model that takes a file input, pass a URL to a publicly accessible file. Or, for smaller files (<10MB), you can convert file data into a base64-encoded data URI and pass that directly:
+To run a model that takes a file input, pass a URL to a publicly accessible file. Or, for smaller files (<10MB), you can pass the data directly.
 
 
 ```js
 import { promises as fs } from "fs";
 
-// Read the file into a buffer
-const data = await fs.readFile("path/to/image.png");
-// Convert the buffer into a base64-encoded string
-const base64 = data.toString("base64");
-// Set MIME type for PNG image
-const mimeType = "image/png";
-// Create the data URI
-const dataURI = `data:${mimeType};base64,${base64}`;
-
 const model = "nightmareai/real-esrgan:42fed1c4974146d4d2414e2be2c5277c7fcf05fcc3a73abf41610695738c1d7b";
 const input = {
-  image: dataURI,
+  image: await fs.readFile("path/to/image.png"),
 };
 const output = await replicate.run(model, { input });
 // ['https://replicate.delivery/mgxm/e7b0e122-9daa-410e-8cde-006c7308ff4d/output.png']

index.test.ts

Lines changed: 40 additions & 0 deletions
@@ -200,6 +200,46 @@ describe("Replicate client", () => {
       expect(prediction.id).toBe("ufawqhfynnddngldkgtslldrkq");
     });
 
+    test.each([
+      {
+        type: "file",
+        value: new File(["hello world"], "hello.txt", { type: "text/plain" }),
+        expected: "data:text/plain;base64,aGVsbG8gd29ybGQ=",
+      },
+      {
+        type: "blob",
+        value: new Blob(["hello world"], { type: "text/plain" }),
+        expected: "data:text/plain;base64,aGVsbG8gd29ybGQ=",
+      },
+      {
+        type: "buffer",
+        value: Buffer.from("hello world"),
+        expected: "data:application/octet-stream;base64,aGVsbG8gd29ybGQ=",
+      },
+    ])(
+      "converts a $type input into a base64 encoded string",
+      async ({ value: data, expected }) => {
+        let body: Record<string, any>;
+        nock(BASE_URL)
+          .post("/predictions")
+          .reply(201, (_uri, _body: Record<string, any>) => {
+            return (body = _body);
+          });
+
+        await client.predictions.create({
+          version:
+            "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+          input: {
+            prompt: "Tell me a story",
+            data,
+          },
+          stream: true,
+        });
+
+        expect(body!.input.data).toEqual(expected);
+      }
+    );
+
     test("Passes stream parameter to API endpoint", async () => {
       nock(BASE_URL)
         .post("/predictions")

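For reference, the expected values in the test above are plain base64 data URIs of the string "hello world". A minimal standalone sketch that reproduces them in Node (not part of this commit; `toDataURI` is a hypothetical helper, shown only to make the expected strings concrete):

```js
// Standalone sketch: reproduces the expected data URIs from the test above.
// `toDataURI` is a hypothetical helper, not code from this commit.
const toDataURI = (data, mimeType = "application/octet-stream") =>
  `data:${mimeType};base64,${Buffer.from(data).toString("base64")}`;

console.log(toDataURI("hello world", "text/plain"));
// data:text/plain;base64,aGVsbG8gd29ybGQ=

console.log(toDataURI("hello world"));
// data:application/octet-stream;base64,aGVsbG8gd29ybGQ=
```
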
lib/deployments.js

Lines changed: 7 additions & 2 deletions
@@ -1,3 +1,5 @@
+const { transformFileInputs } = require("./util");
+
 /**
  * Create a new prediction with a deployment
  *
@@ -11,7 +13,7 @@
  * @returns {Promise<object>} Resolves with the created prediction data
  */
 async function createPrediction(deployment_owner, deployment_name, options) {
-  const { stream, ...data } = options;
+  const { stream, input: _input, ...data } = options;
 
   if (data.webhook) {
     try {
@@ -22,11 +24,14 @@ async function createPrediction(deployment_owner, deployment_name, options) {
     }
   }
 
+  // Transform any file looking fields into strings (either data uri or urls).
+  const input = await transformFileInputs(_input);
+
   const response = await this.request(
     `/deployments/${deployment_owner}/${deployment_name}/predictions`,
     {
       method: "POST",
-      data: { ...data, stream },
+      data: { ...data, input, stream },
     }
   );
 
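Assuming lib/deployments.js is exposed by the client as `replicate.deployments.predictions.create`, a deployment prediction can now carry a raw file value in `input` and have it converted before the request is sent. A usage sketch with a hypothetical deployment ("acme/my-upscaler") and input field:

```js
import Replicate from "replicate";
import { promises as fs } from "fs";

const replicate = new Replicate({ auth: process.env.REPLICATE_API_TOKEN });

// "acme/my-upscaler" and the "image" field are hypothetical, for illustration only.
const prediction = await replicate.deployments.predictions.create(
  "acme",
  "my-upscaler",
  {
    input: {
      // A Buffer is transformed into a data URI by transformFileInputs
      // before the request body is built.
      image: await fs.readFile("path/to/image.png"),
    },
  }
);
console.log(prediction.id);
```
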
lib/predictions.js

Lines changed: 8 additions & 3 deletions
@@ -1,3 +1,5 @@
+const { transformFileInputs } = require("./util");
+
 /**
  * Create a new prediction
  *
@@ -11,7 +13,7 @@
  * @returns {Promise<object>} Resolves with the created prediction
  */
 async function createPrediction(options) {
-  const { model, version, stream, ...data } = options;
+  const { model, version, stream, input: _input, ...data } = options;
 
   if (data.webhook) {
     try {
@@ -22,16 +24,19 @@ async function createPrediction(options) {
     }
   }
 
+  // Transform any file looking fields into strings (either data uri or urls).
+  const input = await transformFileInputs(_input);
+
   let response;
   if (version) {
     response = await this.request("/predictions", {
       method: "POST",
-      data: { ...data, stream, version },
+      data: { ...data, stream, input, version },
     });
   } else if (model) {
     response = await this.request(`/models/${model}/predictions`, {
       method: "POST",
-      data: { ...data, stream },
+      data: { ...data, stream, input },
     });
   } else {
     throw new Error("Either model or version must be specified");

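lib/util.js itself is not part of this diff, so the implementation of `transformFileInputs` is not shown here. A rough sketch of the behaviour the tests above expect — Buffer, Blob, and File values become base64 data URIs, everything else passes through — assuming a flat input object, might look like this (illustration only, not the actual code):

```js
// Sketch only: lib/util.js is not included in this diff, so this illustrates
// the behaviour the tests expect rather than the actual implementation.
async function transformFileInputs(inputs) {
  if (!inputs) return inputs;

  const entries = await Promise.all(
    Object.entries(inputs).map(async ([key, value]) => {
      if (value instanceof Buffer) {
        // Node Buffers carry no MIME type, so fall back to a generic one.
        const data = value.toString("base64");
        return [key, `data:application/octet-stream;base64,${data}`];
      }
      if (typeof Blob !== "undefined" && value instanceof Blob) {
        // Covers File as well, since File extends Blob.
        const buffer = Buffer.from(await value.arrayBuffer());
        const mime = value.type || "application/octet-stream";
        return [key, `data:${mime};base64,${buffer.toString("base64")}`];
      }
      // URLs, strings, numbers, etc. pass through untouched.
      return [key, value];
    })
  );

  return Object.fromEntries(entries);
}
```
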
0 commit comments
