Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,23 @@ const response = await replicate.trainings.list();
}
```

### `replicate.deployments.predictions.create`

```js
const response = await replicate.deployments.predictions.create(deployment_owner, deployment_name, options);
```

| name | type | description |
| ------------------------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------- |
| `deployment_owner` | string | **Required**. The name of the user or organization that owns the deployment |
| `deployment_name` | string | **Required**. The name of the deployment |
| `options.input` | object | **Required**. An object with the model's inputs |
| `options.webhook` | string | An HTTPS URL for receiving a webhook when the prediction has new output |
| `options.webhook_events_filter` | string[] | You can change which events trigger webhook requests by specifying webhook events (`start` \| `output` \| `logs` \| `completed`) |

Use `replicate.wait` to wait for a prediction to finish,
or `replicate.predictions.cancel` to cancel a prediction before it finishes.

### `replicate.paginate`

Pass another method as an argument to iterate over results
Expand Down
17 changes: 16 additions & 1 deletion index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ declare module 'replicate' {
logs?: string;
metrics?: {
predict_time?: number;
}
};
webhook?: string;
webhook_events_filter?: WebhookEventType[];
created_at: string;
Expand Down Expand Up @@ -156,5 +156,20 @@ declare module 'replicate' {
cancel(training_id: string): Promise<Training>;
list(): Promise<Page<Training>>;
};

deployments: {
predictions: {
create(
deployment_name: string,
deployment_owner: string,
options: {
input: object;
stream?: boolean;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
}
): Promise<Prediction>;
};
};
}
}
7 changes: 7 additions & 0 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ const ApiError = require('./lib/error');
const { withAutomaticRetries } = require('./lib/util');

const collections = require('./lib/collections');
const deployments = require('./lib/deployments');
const models = require('./lib/models');
const predictions = require('./lib/predictions');
const trainings = require('./lib/trainings');
Expand Down Expand Up @@ -69,6 +70,12 @@ class Replicate {
cancel: trainings.cancel.bind(this),
list: trainings.list.bind(this),
};

this.deployments = {
predictions: {
create: deployments.predictions.create.bind(this),
}
};
}

/**
Expand Down
37 changes: 37 additions & 0 deletions index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,43 @@ describe('Replicate client', () => {
});
});

describe('deployments.predictions.create', () => {
test('Calls the correct API route with the correct payload', async () => {
nock(BASE_URL)
.post('/deployments/replicate/greeter/predictions')
.reply(200, {
id: 'mfrgcyzzme2wkmbwgzrgmntcg',
version:
'5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa',
urls: {
get: 'https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq',
cancel:
'https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel',
},
created_at: '2022-09-10T09:44:22.165836Z',
started_at: null,
completed_at: null,
status: 'starting',
input: {
text: 'Alice',
},
output: null,
error: null,
logs: null,
metrics: {},
});
const prediction = await client.deployments.predictions.create("replicate", "greeter", {
input: {
text: 'Alice',
},
webhook: 'http://test.host/webhook',
webhook_events_filter: [ 'output', 'completed' ],
});
expect(prediction.id).toBe('mfrgcyzzme2wkmbwgzrgmntcg');
});
// Add more tests for error handling, edge cases, etc.
});

describe('run', () => {
test('Calls the correct API routes', async () => {
let firstPollingRequest = true;
Expand Down
37 changes: 37 additions & 0 deletions lib/deployments.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/**
* Create a new prediction with a deployment
*
* @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
* @param {string} deployment_name - Required. The name of the deployment
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @returns {Promise<object>} Resolves with the created prediction data
*/
async function createPrediction(deployment_owner, deployment_name, options) {
const { stream, ...data } = options;

if (data.webhook) {
try {
// eslint-disable-next-line no-new
new URL(data.webhook);
} catch (err) {
throw new Error('Invalid webhook URL');
}
}

const response = await this.request(`/deployments/${deployment_owner}/${deployment_name}/predictions`, {
method: 'POST',
data: { ...data, stream },
});

return response.json();
}

module.exports = {
predictions: {
create: createPrediction,
}
};