12 changes: 12 additions & 0 deletions index.d.ts
@@ -164,6 +164,18 @@ declare module 'replicate' {
          version_id: string
        ): Promise<ModelVersion>;
      };
      predictions: {
        create(
          model_owner: string,
          model_name: string,
          options: {
            input: object;
            stream?: boolean;
            webhook?: string;
            webhook_events_filter?: WebhookEventType[];
          }
        ): Promise<Prediction>;
      };
    };

    predictions: {
3 changes: 3 additions & 0 deletions index.js
@@ -68,6 +68,9 @@ class Replicate {
        list: models.versions.list.bind(this),
        get: models.versions.get.bind(this),
      },
      predictions: {
        create: models.predictions.create.bind(this),
      },
    };

    this.predictions = {
32 changes: 32 additions & 0 deletions index.test.ts
@@ -668,6 +668,38 @@ describe('Replicate client', () => {
    // Add more tests for error handling, edge cases, etc.
  });

  describe('models.predictions.create', () => {
    test('Calls the correct API route with the correct payload', async () => {
      nock(BASE_URL)
        .post('/models/meta/llama-2-70b-chat/predictions')
        .reply(200, {
          id: "heat2o3bzn3ahtr6bjfftvbaci",
          model: "replicate/lifeboat-70b",
          version: "d-c6559c5791b50af57b69f4a73f8e021c",
          input: {
            prompt: "Please write a haiku about llamas."
          },
          logs: "",
          error: null,
          status: "starting",
          created_at: "2023-11-27T13:35:45.99397566Z",
          urls: {
            cancel: "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel",
            get: "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci"
          }
        });
      const prediction = await client.models.predictions.create("meta", "llama-2-70b-chat", {
        input: {
          prompt: "Please write a haiku about llamas."
        },
        webhook: 'http://test.host/webhook',
        webhook_events_filter: ['output', 'completed'],
      });
      expect(prediction.id).toBe('heat2o3bzn3ahtr6bjfftvbaci');
    });
    // Add more tests for error handling, edge cases, etc.
  });

  describe('hardware.list', () => {
    test('Calls the correct API route', async () => {
      nock(BASE_URL)
24 changes: 24 additions & 0 deletions lib/models.js
@@ -83,9 +83,33 @@ async function createModel(model_owner, model_name, options) {
  return response.json();
}

/**
* Create a new prediction
*
* @param {string} model_owner - Required. The name of the user or organization that owns the model
* @param {string} model_name - Required. The name of the model
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
* @returns {Promise<object>} Resolves with the created prediction
*/
async function createPrediction(model_owner, model_name, options) {
  const { stream, ...data } = options;

  const response = await this.request(`/models/${model_owner}/${model_name}/predictions`, {
    method: 'POST',
    data: { ...data, stream },
  });

  return response.json();
}

module.exports = {
  get: getModel,
  list: listModels,
  create: createModel,
  versions: { list: listModelVersions, get: getModelVersion },
  predictions: { create: createPrediction },
};
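
Taken together, these changes expose the new route as client.models.predictions.create. A minimal usage sketch, assuming a client authenticated with an API token from the environment; the webhook URL is illustrative, and the model and input mirror the test above:

const Replicate = require('replicate');

const client = new Replicate({ auth: process.env.REPLICATE_API_TOKEN });

async function main() {
  // POSTs to /models/meta/llama-2-70b-chat/predictions and resolves with
  // the newly created prediction, initially in the "starting" state.
  const prediction = await client.models.predictions.create(
    'meta',
    'llama-2-70b-chat',
    {
      input: { prompt: 'Please write a haiku about llamas.' },
      webhook: 'https://example.com/webhook', // illustrative webhook URL
      webhook_events_filter: ['output', 'completed'],
    }
  );
  console.log(prediction.id, prediction.status);
}

main().catch(console.error);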