Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 10 additions & 19 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -188,28 +188,19 @@ declare module "replicate" {
version_id: string
): Promise<ModelVersion>;
};
predictions: {
create(
model_owner: string,
model_name: string,
options: {
input: object;
stream?: boolean;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
}
): Promise<Prediction>;
};
};

predictions: {
create(options: {
version: string;
input: object;
stream?: boolean;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
}): Promise<Prediction>;
create(
options: {
model?: string;
version?: string;
input: object;
stream?: boolean;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
} & ({ version: string } | { model: string })
): Promise<Prediction>;
get(prediction_id: string): Promise<Prediction>;
cancel(prediction_id: string): Promise<Prediction>;
list(): Promise<Page<Prediction>>;
Expand Down
26 changes: 13 additions & 13 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ class Replicate {
list: models.versions.list.bind(this),
get: models.versions.get.bind(this),
},
predictions: {
create: models.predictions.create.bind(this),
},
};

this.predictions = {
Expand Down Expand Up @@ -117,12 +114,13 @@ class Replicate {
...data,
version: identifier.version,
});
} else if (identifier.owner && identifier.name) {
prediction = await this.predictions.create({
...data,
model: `${identifier.owner}/${identifier.name}`,
});
} else {
prediction = await this.models.predictions.create(
identifier.owner,
identifier.name,
data
);
throw new Error("Invalid model version identifier");
}

// Call progress callback with the initial prediction object
Expand Down Expand Up @@ -260,12 +258,14 @@ class Replicate {
version: identifier.version,
stream: true,
});
} else if (identifier.owner && identifier.name) {
prediction = await this.predictions.create({
...data,
model: `${identifier.owner}/${identifier.name}`,
stream: true,
});
} else {
prediction = await this.models.predictions.create(
identifier.owner,
identifier.name,
{ ...data, stream: true }
);
throw new Error("Invalid model version identifier");
}

if (prediction.urls && prediction.urls.stream) {
Expand Down
21 changes: 9 additions & 12 deletions index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ describe("Replicate client", () => {
// Add more tests for error handling, edge cases, etc.
});

describe("models.predictions.create", () => {
describe("predictions.create with model", () => {
test("Calls the correct API route with the correct payload", async () => {
nock(BASE_URL)
.post("/models/meta/llama-2-70b-chat/predictions")
Expand All @@ -721,17 +721,14 @@ describe("Replicate client", () => {
get: "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci",
},
});
const prediction = await client.models.predictions.create(
"meta",
"llama-2-70b-chat",
{
input: {
prompt: "Please write a haiku about llamas.",
},
webhook: "http://test.host/webhook",
webhook_events_filter: ["output", "completed"],
}
);
const prediction = await client.predictions.create({
model: "meta/llama-2-70b-chat",
input: {
prompt: "Please write a haiku about llamas.",
},
webhook: "http://test.host/webhook",
webhook_events_filter: ["output", "completed"],
});
expect(prediction.id).toBe("heat2o3bzn3ahtr6bjfftvbaci");
});
// Add more tests for error handling, edge cases, etc.
Expand Down
27 changes: 0 additions & 27 deletions lib/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -89,36 +89,9 @@ async function createModel(model_owner, model_name, options) {
return response.json();
}

/**
* Create a new prediction
*
* @param {string} model_owner - Required. The name of the user or organization that owns the model
* @param {string} model_name - Required. The name of the model
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
* @returns {Promise<object>} Resolves with the created prediction
*/
async function createPrediction(model_owner, model_name, options) {
const { stream, ...data } = options;

const response = await this.request(
`/models/${model_owner}/${model_name}/predictions`,
{
method: "POST",
data: { ...data, stream },
}
);

return response.json();
}

module.exports = {
get: getModel,
list: listModels,
create: createModel,
versions: { list: listModelVersions, get: getModelVersion },
predictions: { create: createPrediction },
};
23 changes: 17 additions & 6 deletions lib/predictions.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
* Create a new prediction
*
* @param {object} options
* @param {string} options.version - Required. The model version
* @param {string} options.model - The model.
* @param {string} options.version - The model version.
* @param {object} options.input - Required. An object with the model inputs
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
* @returns {Promise<object>} Resolves with the created prediction
*/
async function createPrediction(options) {
const { stream, ...data } = options;
const { model, version, stream, ...data } = options;

if (data.webhook) {
try {
Expand All @@ -21,10 +22,20 @@ async function createPrediction(options) {
}
}

const response = await this.request("/predictions", {
method: "POST",
data: { ...data, stream },
});
let response;
if (version) {
response = await this.request("/predictions", {
method: "POST",
data: { ...data, stream, version },
});
} else if (model) {
response = await this.request(`/models/${model}/predictions`, {
method: "POST",
data: { ...data, stream },
});
} else {
throw new Error("Either model or version must be specified");
}

return response.json();
}
Expand Down