diff --git a/.changeset/fair-plums-grin.md b/.changeset/fair-plums-grin.md new file mode 100644 index 0000000000..31579ad8c2 --- /dev/null +++ b/.changeset/fair-plums-grin.md @@ -0,0 +1,13 @@ +--- +"@trigger.dev/replicate": patch +"@trigger.dev/airtable": patch +"@trigger.dev/sendgrid": patch +"@trigger.dev/sdk": patch +"@trigger.dev/github": patch +"@trigger.dev/linear": patch +"@trigger.dev/resend": patch +"@trigger.dev/slack": patch +"@trigger.dev/core": patch +--- + +First release of `@trigger.dev/replicate` integration with remote callback support. diff --git a/apps/webapp/app/models/task.server.ts b/apps/webapp/app/models/task.server.ts index ac468779af..b674dd0a42 100644 --- a/apps/webapp/app/models/task.server.ts +++ b/apps/webapp/app/models/task.server.ts @@ -23,6 +23,7 @@ export function taskWithAttemptsToServerTask(task: TaskWithAttempts): ServerTask attempts: task.attempts.length, idempotencyKey: task.idempotencyKey, operation: task.operation, + callbackUrl: task.callbackUrl, }; } diff --git a/apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts new file mode 100644 index 0000000000..5c1425e2a0 --- /dev/null +++ b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.$id.callback.$secret.ts @@ -0,0 +1,124 @@ +import type { ActionArgs } from "@remix-run/server-runtime"; +import { json } from "@remix-run/server-runtime"; +import { RuntimeEnvironmentType } from "@trigger.dev/database"; +import { z } from "zod"; +import { $transaction, PrismaClient, PrismaClientOrTransaction, prisma } from "~/db.server"; +import { enqueueRunExecutionV2 } from "~/models/jobRunExecution.server"; +import { logger } from "~/services/logger.server"; + +const ParamsSchema = z.object({ + runId: z.string(), + id: z.string(), + secret: z.string(), +}); + +export async function action({ request, params }: ActionArgs) { + // Ensure this is a POST request + if (request.method.toUpperCase() !== "POST") { + return { status: 405, body: "Method Not Allowed" }; + } + + const { runId, id } = ParamsSchema.parse(params); + + // Parse body as JSON (no schema parsing) + const body = await request.json(); + + const service = new CallbackRunTaskService(); + + try { + // Complete task with request body as output + await service.call(runId, id, body, request.url); + + return json({ success: true }); + } catch (error) { + if (error instanceof Error) { + logger.error("Error while processing task callback:", { error }); + } + + return json({ error: "Something went wrong" }, { status: 500 }); + } +} + +export class CallbackRunTaskService { + #prismaClient: PrismaClient; + + constructor(prismaClient: PrismaClient = prisma) { + this.#prismaClient = prismaClient; + } + + public async call(runId: string, id: string, taskBody: any, callbackUrl: string): Promise { + const task = await findTask(prisma, id); + + if (!task) { + return; + } + + if (task.runId !== runId) { + return; + } + + if (task.status !== "WAITING") { + return; + } + + if (!task.callbackUrl) { + return; + } + + if (new URL(task.callbackUrl).pathname !== new URL(callbackUrl).pathname) { + logger.error("Callback URLs don't match", { runId, taskId: id, callbackUrl }); + return; + } + + logger.debug("CallbackRunTaskService.call()", { task }); + + await this.#resumeTask(task, taskBody); + } + + async #resumeTask(task: NonNullable, output: any) { + await $transaction(this.#prismaClient, async (tx) => { + await tx.taskAttempt.updateMany({ + where: { + taskId: task.id, + status: "PENDING", + }, + data: { + status: "COMPLETED", + }, + }); + + await tx.task.update({ + where: { id: task.id }, + data: { + status: "COMPLETED", + completedAt: new Date(), + output: output ? output : undefined, + }, + }); + + await this.#resumeRunExecution(task, tx); + }); + } + + async #resumeRunExecution(task: NonNullable, prisma: PrismaClientOrTransaction) { + await enqueueRunExecutionV2(task.run, prisma, { + skipRetrying: task.run.environment.type === RuntimeEnvironmentType.DEVELOPMENT, + }); + } +} + +type FoundTask = Awaited>; + +async function findTask(prisma: PrismaClientOrTransaction, id: string) { + return prisma.task.findUnique({ + where: { id }, + include: { + run: { + include: { + environment: true, + queue: true, + }, + }, + }, + }); +} diff --git a/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts index 54c9645348..94a6ba8caf 100644 --- a/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts +++ b/apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts @@ -15,6 +15,8 @@ import { authenticateApiRequest } from "~/services/apiAuth.server"; import { logger } from "~/services/logger.server"; import { ulid } from "~/services/ulid.server"; import { workerQueue } from "~/services/worker.server"; +import { generateSecret } from "~/services/sources/utils.server"; +import { env } from "~/env.server"; const ParamsSchema = z.object({ runId: z.string(), @@ -185,10 +187,13 @@ export class RunTaskService { }, }); + const delayUntilInFuture = taskBody.delayUntil && taskBody.delayUntil.getTime() > Date.now(); + const callbackEnabled = taskBody.callback?.enabled; + if (existingTask) { if (existingTask.status === "CANCELED") { const existingTaskStatus = - (taskBody.delayUntil && taskBody.delayUntil.getTime() > Date.now()) || taskBody.trigger + delayUntilInFuture || callbackEnabled || taskBody.trigger ? "WAITING" : taskBody.noop ? "COMPLETED" @@ -233,16 +238,21 @@ export class RunTaskService { status = "CANCELED"; } else { status = - (taskBody.delayUntil && taskBody.delayUntil.getTime() > Date.now()) || taskBody.trigger + delayUntilInFuture || callbackEnabled || taskBody.trigger ? "WAITING" : taskBody.noop ? "COMPLETED" : "RUNNING"; } + const taskId = ulid(); + const callbackUrl = callbackEnabled + ? `${env.APP_ORIGIN}/api/v1/runs/${runId}/tasks/${taskId}/callback/${generateSecret(12)}` + : undefined; + const task = await tx.task.create({ data: { - id: ulid(), + id: taskId, idempotencyKey, displayKey: taskBody.displayKey, runConnection: taskBody.connectionKey @@ -273,6 +283,7 @@ export class RunTaskService { properties: taskBody.properties ?? undefined, redact: taskBody.redact ?? undefined, operation: taskBody.operation, + callbackUrl, style: taskBody.style ?? { style: "normal" }, attempts: { create: { @@ -296,6 +307,17 @@ export class RunTaskService { }, { tx, runAt: task.delayUntil ?? undefined } ); + } else if (task.status === "WAITING" && callbackUrl && taskBody.callback) { + if (taskBody.callback.timeoutInSeconds > 0) { + // We need to schedule the callback timeout + await workerQueue.enqueue( + "processCallbackTimeout", + { + id: task.id, + }, + { tx, runAt: new Date(Date.now() + taskBody.callback.timeoutInSeconds * 1000) } + ); + } } return task; diff --git a/apps/webapp/app/services/externalApis/integrationCatalog.server.ts b/apps/webapp/app/services/externalApis/integrationCatalog.server.ts index b86c2e496c..3c13f154ac 100644 --- a/apps/webapp/app/services/externalApis/integrationCatalog.server.ts +++ b/apps/webapp/app/services/externalApis/integrationCatalog.server.ts @@ -3,6 +3,7 @@ import { github } from "./integrations/github"; import { linear } from "./integrations/linear"; import { openai } from "./integrations/openai"; import { plain } from "./integrations/plain"; +import { replicate } from "./integrations/replicate"; import { resend } from "./integrations/resend"; import { sendgrid } from "./integrations/sendgrid"; import { slack } from "./integrations/slack"; @@ -37,6 +38,7 @@ export const integrationCatalog = new IntegrationCatalog({ linear, openai, plain, + replicate, resend, slack, stripe, diff --git a/apps/webapp/app/services/externalApis/integrations/replicate.ts b/apps/webapp/app/services/externalApis/integrations/replicate.ts new file mode 100644 index 0000000000..74f20cdafd --- /dev/null +++ b/apps/webapp/app/services/externalApis/integrations/replicate.ts @@ -0,0 +1,50 @@ +import type { HelpSample, Integration } from "../types"; + +function usageSample(hasApiKey: boolean): HelpSample { + const apiKeyPropertyName = "apiKey"; + + return { + title: "Using the client", + code: ` +import { Replicate } from "@trigger.dev/replicate"; + +const replicate = new Replicate({ + id: "__SLUG__",${hasApiKey ? `,\n ${apiKeyPropertyName}: process.env.REPLICATE_API_KEY!` : ""} +}); + +client.defineJob({ + id: "replicate-create-prediction", + name: "Replicate - Create Prediction", + version: "0.1.0", + integrations: { replicate }, + trigger: eventTrigger({ + name: "replicate.predict", + schema: z.object({ + prompt: z.string(), + version: z.string(), + }), + }), + run: async (payload, io, ctx) => { + return io.replicate.predictions.createAndAwait("await-prediction", { + version: payload.version, + input: { prompt: payload.prompt }, + }); + }, +}); + `, + }; +} + +export const replicate: Integration = { + identifier: "replicate", + name: "Replicate", + packageName: "@trigger.dev/replicate@latest", + authenticationMethods: { + apikey: { + type: "apikey", + help: { + samples: [usageSample(true)], + }, + }, + }, +}; diff --git a/apps/webapp/app/services/runs/performRunExecutionV1.server.ts b/apps/webapp/app/services/runs/performRunExecutionV1.server.ts index 7cfdf2d966..c67a9dd180 100644 --- a/apps/webapp/app/services/runs/performRunExecutionV1.server.ts +++ b/apps/webapp/app/services/runs/performRunExecutionV1.server.ts @@ -449,7 +449,9 @@ export class PerformRunExecutionV1Service { // If the task has an operation, then the next performRunExecution will occur // when that operation has finished - if (!data.task.operation) { + // Tasks with callbacks enabled will also get processed separately, i.e. when + // they time out, or on valid requests to their callbackUrl + if (!data.task.operation && !data.task.callbackUrl) { const newJobExecution = await tx.jobRunExecution.create({ data: { runId: run.id, diff --git a/apps/webapp/app/services/runs/performRunExecutionV2.server.ts b/apps/webapp/app/services/runs/performRunExecutionV2.server.ts index 1af87b8c30..f22d13edb1 100644 --- a/apps/webapp/app/services/runs/performRunExecutionV2.server.ts +++ b/apps/webapp/app/services/runs/performRunExecutionV2.server.ts @@ -530,7 +530,9 @@ export class PerformRunExecutionV2Service { // If the task has an operation, then the next performRunExecution will occur // when that operation has finished - if (!data.task.operation) { + // Tasks with callbacks enabled will also get processed separately, i.e. when + // they time out, or on valid requests to their callbackUrl + if (!data.task.operation && !data.task.callbackUrl) { await enqueueRunExecutionV2(run, tx, { runAt: data.task.delayUntil ?? undefined, resumeTaskId: data.task.id, diff --git a/apps/webapp/app/services/sources/utils.server.ts b/apps/webapp/app/services/sources/utils.server.ts index 127ca4b7ab..4c2bc7ae7b 100644 --- a/apps/webapp/app/services/sources/utils.server.ts +++ b/apps/webapp/app/services/sources/utils.server.ts @@ -1,5 +1,5 @@ import crypto from "node:crypto"; -export function generateSecret(): string { - return crypto.randomBytes(32).toString("hex"); +export function generateSecret(sizeInBytes = 32): string { + return crypto.randomBytes(sizeInBytes).toString("hex"); } diff --git a/apps/webapp/app/services/tasks/processCallbackTimeout.ts b/apps/webapp/app/services/tasks/processCallbackTimeout.ts new file mode 100644 index 0000000000..948691990d --- /dev/null +++ b/apps/webapp/app/services/tasks/processCallbackTimeout.ts @@ -0,0 +1,76 @@ +import { RuntimeEnvironmentType } from "@trigger.dev/database"; +import { $transaction, PrismaClient, PrismaClientOrTransaction, prisma } from "~/db.server"; +import { enqueueRunExecutionV2 } from "~/models/jobRunExecution.server"; +import { logger } from "../logger.server"; + +type FoundTask = Awaited>; + +export class ProcessCallbackTimeoutService { + #prismaClient: PrismaClient; + + constructor(prismaClient: PrismaClient = prisma) { + this.#prismaClient = prismaClient; + } + + public async call(id: string) { + const task = await findTask(this.#prismaClient, id); + + if (!task) { + return; + } + + if (task.status !== "WAITING" || !task.callbackUrl) { + return; + } + + logger.debug("ProcessCallbackTimeoutService.call", { task }); + + return await this.#failTask(task, "Remote callback timeout - no requests received"); + } + + async #failTask(task: NonNullable, error: string) { + await $transaction(this.#prismaClient, async (tx) => { + await tx.taskAttempt.updateMany({ + where: { + taskId: task.id, + status: "PENDING", + }, + data: { + status: "ERRORED", + error + }, + }); + + await tx.task.update({ + where: { id: task.id }, + data: { + status: "ERRORED", + completedAt: new Date(), + output: error, + }, + }); + + await this.#resumeRunExecution(task, tx); + }); + } + + async #resumeRunExecution(task: NonNullable, prisma: PrismaClientOrTransaction) { + await enqueueRunExecutionV2(task.run, prisma, { + skipRetrying: task.run.environment.type === RuntimeEnvironmentType.DEVELOPMENT, + }); + } +} + +async function findTask(prisma: PrismaClient, id: string) { + return prisma.task.findUnique({ + where: { id }, + include: { + run: { + include: { + environment: true, + queue: true, + }, + }, + }, + }); +} diff --git a/apps/webapp/app/services/worker.server.ts b/apps/webapp/app/services/worker.server.ts index 59c5e8b52c..07493aab51 100644 --- a/apps/webapp/app/services/worker.server.ts +++ b/apps/webapp/app/services/worker.server.ts @@ -19,6 +19,7 @@ import { DeliverScheduledEventService } from "./schedules/deliverScheduledEvent. import { ActivateSourceService } from "./sources/activateSource.server"; import { DeliverHttpSourceRequestService } from "./sources/deliverHttpSourceRequest.server"; import { PerformTaskOperationService } from "./tasks/performTaskOperation.server"; +import { ProcessCallbackTimeoutService } from "./tasks/processCallbackTimeout"; import { addMissingVersionField } from "@trigger.dev/core"; const workerCatalog = { @@ -30,6 +31,9 @@ const workerCatalog = { }), scheduleEmail: DeliverEmailSchema, startRun: z.object({ id: z.string() }), + processCallbackTimeout: z.object({ + id: z.string(), + }), performTaskOperation: z.object({ id: z.string(), }), @@ -240,6 +244,15 @@ function getWorkerQueue() { await service.call(payload.id); }, }, + processCallbackTimeout: { + priority: 0, // smaller number = higher priority + maxAttempts: 3, + handler: async (payload, job) => { + const service = new ProcessCallbackTimeoutService(); + + await service.call(payload.id); + }, + }, performTaskOperation: { priority: 0, // smaller number = higher priority queueName: (payload) => `tasks:${payload.id}`, diff --git a/config-packages/tsconfig/integration.json b/config-packages/tsconfig/integration.json new file mode 100644 index 0000000000..753e6091dd --- /dev/null +++ b/config-packages/tsconfig/integration.json @@ -0,0 +1,17 @@ +{ + "extends": "./node18.json", + "compilerOptions": { + "lib": ["DOM", "DOM.Iterable", "ES2019"], + "paths": { + "@trigger.dev/sdk/*": ["../../packages/trigger-sdk/src/*"], + "@trigger.dev/sdk": ["../../packages/trigger-sdk/src/index"], + "@trigger.dev/integration-kit/*": ["../../packages/integration-kit/src/*"], + "@trigger.dev/integration-kit": ["../../packages/integration-kit/src/index"] + }, + "declaration": false, + "declarationMap": false, + "baseUrl": ".", + "stripInternal": true + }, + "exclude": ["node_modules"] +} diff --git a/docs/integrations/apis/replicate.mdx b/docs/integrations/apis/replicate.mdx new file mode 100644 index 0000000000..978bb51c19 --- /dev/null +++ b/docs/integrations/apis/replicate.mdx @@ -0,0 +1,170 @@ +--- +title: Replicate +description: "Run machine learning tasks easily at scale" +--- + + + +## Installation + +To get started with the Replicate integration on Trigger.dev, you need to install the `@trigger.dev/replicate` package. +You can do this using npm, pnpm, or yarn: + + + +```bash npm +npm install @trigger.dev/replicate@latest +``` + +```bash pnpm +pnpm add @trigger.dev/replicate@latest +``` + +```bash yarn +yarn add @trigger.dev/replicate@latest +``` + + + +## Authentication + +To use the Replicate API with Trigger.dev, you have to provide an API Key. + +### API Key + +You can create an API Key in your [Account Settings](https://replicate.com/account/api-tokens). + +```ts +import { Replicate } from "@trigger.dev/replicate"; + +//this will use the passed in API key (defined in your environment variables) +const replicate = new Replicate({ + id: "replicate", + apiKey: process.env["REPLICATE_API_KEY"], +}); +``` + +## Usage + +Include the Replicate integration in your Trigger.dev job. + +```ts +client.defineJob({ + id: "replicate-cinematic-prompt", + name: "Replicate - Cinematic Prompt", + version: "0.1.0", + integrations: { replicate }, + trigger: eventTrigger({ + name: "replicate.cinematic", + schema: z.object({ + prompt: z.string().default("rick astley riding a harley through post-apocalyptic miami"), + version: z + .string() + .default("af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33"), + }), + }), + run: async (payload, io, ctx) => { + //wait for prediction completion (uses remote callbacks internally) + const prediction = await io.replicate.predictions.createAndAwait("await-prediction", { + version: payload.version, + input: { + prompt: `${payload.prompt}, cinematic, 70mm, anamorphic, bokeh`, + width: 1280, + height: 720, + }, + }); + return prediction.output; + }, +}); +``` + +### Pagination + +You can paginate responses: + +- Using the `getAll` helper +- Using the `paginate` helper + +```ts +client.defineJob({ + id: "replicate-pagination", + name: "Replicate Pagination", + version: "0.1.0", + integrations: { + replicate, + }, + trigger: eventTrigger({ + name: "replicate.paginate", + }), + run: async (payload, io, ctx) => { + // getAll - returns an array of all results (uses paginate internally) + const all = await io.replicate.getAll(io.replicate.predictions.list, "get-all"); + + // paginate - returns an async generator, useful to process one page at a time + for await (const predictions of io.replicate.paginate( + io.replicate.predictions.list, + "paginate-all" + )) { + await io.logger.info("stats", { + total: predictions.length, + versions: predictions.map((p) => p.version), + }); + } + + return { count: all.length }; + }, +}); +``` + +## Tasks + +### Collections + +| Function Name | Description | +| ------------------ | ---------------------------------------------------------------------- | +| `collections.get` | Gets a collection. | +| `collections.list` | Returns the first page of all collections. Use with pagination helper. | + +### Deployments + +| Function Name | Description | +| ---------------------------------------- | --------------------------------------------------------- | +| `deployments.predictions.create` | Creates a new prediction with a deployment. | +| `deployments.predictions.createAndAwait` | Creates and waits for a new prediction with a deployment. | + +### Models + +| Function Name | Description | +| ----------------- | ------------------------ | +| `models.get` | Gets a model. | +| `models.versions` | Gets a model version. | +| `models.versions` | Gets all model versions. | + +### Predictions + +| Function Name | Description | +| ---------------------------- | ---------------------------------------------------------------------- | +| `predictions.cancel` | Cancels a prediction. | +| `predictions.create` | Creates a prediction. | +| `predictions.createAndAwait` | Creates and waits for a prediction. | +| `predictions.get` | Gets a prediction. | +| `predictions.list` | Returns the first page of all predictions. Use with pagination helper. | + +### Trainings + +| Function Name | Description | +| -------------------------- | -------------------------------------------------------------------- | +| `trainings.cancel` | Cancels a training. | +| `trainings.create` | Creates a training. | +| `trainings.createAndAwait` | Creates and waits for a training. | +| `trainings.get` | Gets a training. | +| `trainings.list` | Returns the first page of all trainings. Use with pagination helper. | + +### Misc + +| Function Name | Description | +| ------------- | --------------------------------------------------- | +| `getAll` | Pagination helper that returns an array of results. | +| `paginate` | Pagination helper that returns an async generator. | +| `request` | Sends authenticated requests to the Replicate API. | +| `run` | Creates and waits for a prediction. | diff --git a/docs/integrations/create-tasks.mdx b/docs/integrations/create-tasks.mdx index fe5386a181..2ed3b09dfc 100644 --- a/docs/integrations/create-tasks.mdx +++ b/docs/integrations/create-tasks.mdx @@ -24,7 +24,7 @@ export class Github implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/docs/integrations/introduction.mdx b/docs/integrations/introduction.mdx index 30be9c1e36..688463510c 100644 --- a/docs/integrations/introduction.mdx +++ b/docs/integrations/introduction.mdx @@ -30,14 +30,15 @@ description: "Integrations make it easy to authenticate and use APIs." Navigate the menu or select Integrations from the table below. -| API | Description | Webhooks | Tasks | -| --------------------------------------- | ---------------------------------------------------------------- | -------- | ----- | -| [GitHub](/integrations/apis/github) | Subscribe to webhooks and perform actions | ✅ | ✅ | -| [Linear](/integrations/apis/linear) | Streamline project and issue tracking | ✅ | ✅ | -| [OpenAI](/integrations/apis/openai) | Generate text and images. Including longer than 30s prompts | N/A | ✅ | -| [Plain](/integrations/apis/plain) | Perform customer support using Plain | 🕘 | ✅ | -| [Resend](/integrations/apis/resend) | Send emails using Resend | 🕘 | ✅ | -| [SendGrid](/integrations/apis/sendgrid) | Send emails using SendGrid | 🕘 | ✅ | -| [Slack](/integrations/apis/slack) | Send Slack messages | 🕘 | ✅ | -| [Supabase](/integrations/apis/supabase) | Interact with your projects and databases | ✅ | ✅ | -| [Typeform](/integrations/apis/typeform) | Interact with the Typeform API and get notified of new responses | ✅ | ✅ | +| API | Description | Webhooks | Tasks | +| ----------------------------------------- | ---------------------------------------------------------------- | -------- | ----- | +| [GitHub](/integrations/apis/github) | Subscribe to webhooks and perform actions | ✅ | ✅ | +| [Linear](/integrations/apis/linear) | Streamline project and issue tracking | ✅ | ✅ | +| [OpenAI](/integrations/apis/openai) | Generate text and images. Including longer than 30s prompts | N/A | ✅ | +| [Plain](/integrations/apis/plain) | Perform customer support using Plain | 🕘 | ✅ | +| [Replicate](/integrations/apis/replicate) | Run machine learning tasks easily at scale | N/A | ✅ | +| [Resend](/integrations/apis/resend) | Send emails using Resend | 🕘 | ✅ | +| [SendGrid](/integrations/apis/sendgrid) | Send emails using SendGrid | 🕘 | ✅ | +| [Slack](/integrations/apis/slack) | Send Slack messages | 🕘 | ✅ | +| [Supabase](/integrations/apis/supabase) | Interact with your projects and databases | ✅ | ✅ | +| [Typeform](/integrations/apis/typeform) | Interact with the Typeform API and get notified of new responses | ✅ | ✅ | diff --git a/docs/mint.json b/docs/mint.json index 7d7c3bd4b5..02e2cf50e8 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -247,6 +247,7 @@ "integrations/apis/linear", "integrations/apis/openai", "integrations/apis/plain", + "integrations/apis/replicate", "integrations/apis/resend", "integrations/apis/sendgrid", "integrations/apis/slack", diff --git a/docs/sdk/io/runtask.mdx b/docs/sdk/io/runtask.mdx index 5791429980..46a8d829d9 100644 --- a/docs/sdk/io/runtask.mdx +++ b/docs/sdk/io/runtask.mdx @@ -6,6 +6,8 @@ description: "`io.runTask()` allows you to run a [Task](/documentation/concepts/ A Task is a resumable unit of a Run that can be retried, resumed and is logged. [Integrations](/integrations) use Tasks internally to perform their actions. +The wrappers at `io.integration.runTask()` expose the underlying Integration client as the first callback parameter (see examples on the right). They will have defaults set for options and `onError` handlers, but should otherwise be considered identical to raw `io.runTask()`. + ## Parameters @@ -112,6 +114,22 @@ A Task is a resumable unit of a Run that can be retried, resumed and is logged. + + + An optional object that exposes settings for the remote callback feature. + + Enabling this feature will expose a `callbackUrl` property on the callback's Task parameter. Additionally, `io.runTask()` will now return a Promise that resolves with the body of the first request sent to that URL. + + + + Whether to enable the remote callback feature. + + + The value of the property. + + + + @@ -133,6 +151,8 @@ A Task is a resumable unit of a Run that can be retried, resumed and is logged. A Promise that resolves with the returned value of the callback. +If the remote callback feature `options.callback` is enabled, the Promise will instead resolve with the body of the first request sent to `task.callbackUrl`. + ```typescript Run a task @@ -150,11 +170,11 @@ client.defineJob({ }, run: async (payload, io, ctx) => { //runTask - const response = await io.runTask( + const response = await io.github.runTask( "create-card", - async () => { + async (client) => { //create a project card using the underlying GitHub Integration client - return io.github.client.rest.projects.createCard({ + return client.rest.projects.createCard({ column_id: 123, note: "test", }); @@ -201,4 +221,43 @@ client.defineJob({ }); ``` +```typescript Remote callbacks +client.defineJob({ + id: "remote-callback-example", + name: "Remote Callback example", + version: "0.1.1", + trigger: eventTrigger({ name: "predict" }), + integrations: { replicate }, + run: async (payload, io, ctx) => { + //runTask + const prediction = await io.replicate.runTask( + "create-and-await-prediction", + async (client, task) => { + //create a prediction using the underlying Replicate Integration client + await client.predictions.create({ + ...payload, + webhook: task.callbackUrl ?? "", + webhook_events_filter: ["completed"], + }); + //the actual return value will be the data sent to callbackUrl + //cast to the exact data type you expect to receive or `any` if unsure + return {} as Prediction; + }, + { + name: "Create and await Prediction", + icon: "replicate", + //remote callback settings + callback: { + enabled: true, + timeoutInSeconds: 300, + }, + } + ); + + //log the prediction output + await io.logger.info(prediction.output); + }, +}); +``` + diff --git a/integrations/airtable/src/index.ts b/integrations/airtable/src/index.ts index f8c2d99f15..b4d9d542cf 100644 --- a/integrations/airtable/src/index.ts +++ b/integrations/airtable/src/index.ts @@ -92,7 +92,7 @@ export class Airtable implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/integrations/github/src/index.ts b/integrations/github/src/index.ts index 59e2dc968f..0716e849b5 100644 --- a/integrations/github/src/index.ts +++ b/integrations/github/src/index.ts @@ -138,7 +138,7 @@ export class Github implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/integrations/linear/src/index.ts b/integrations/linear/src/index.ts index dde947e7d7..0f53f9e7ac 100644 --- a/integrations/linear/src/index.ts +++ b/integrations/linear/src/index.ts @@ -158,7 +158,7 @@ export class Linear implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/integrations/replicate/README.md b/integrations/replicate/README.md new file mode 100644 index 0000000000..67a4b88e87 --- /dev/null +++ b/integrations/replicate/README.md @@ -0,0 +1 @@ +# @trigger.dev/replicate diff --git a/integrations/replicate/package.json b/integrations/replicate/package.json new file mode 100644 index 0000000000..9478a0a36e --- /dev/null +++ b/integrations/replicate/package.json @@ -0,0 +1,37 @@ +{ + "name": "@trigger.dev/replicate", + "version": "2.1.7", + "description": "Trigger.dev integration for replicate", + "main": "./dist/index.js", + "types": "./dist/index.d.ts", + "publishConfig": { + "access": "public" + }, + "files": [ + "dist/index.js", + "dist/index.d.ts", + "dist/index.js.map" + ], + "devDependencies": { + "@trigger.dev/tsconfig": "workspace:*", + "@types/node": "16.x", + "rimraf": "^3.0.2", + "tsup": "7.1.x", + "typescript": "4.9.4" + }, + "scripts": { + "clean": "rimraf dist", + "build": "npm run clean && npm run build:tsup", + "build:tsup": "tsup", + "typecheck": "tsc --noEmit" + }, + "dependencies": { + "@trigger.dev/integration-kit": "workspace:^2.1.0", + "@trigger.dev/sdk": "workspace:^2.1.0", + "replicate": "^0.18.1", + "zod": "3.21.4" + }, + "engines": { + "node": ">=16.8.0" + } +} \ No newline at end of file diff --git a/integrations/replicate/src/collections.ts b/integrations/replicate/src/collections.ts new file mode 100644 index 0000000000..b6f1c01001 --- /dev/null +++ b/integrations/replicate/src/collections.ts @@ -0,0 +1,37 @@ +import { IntegrationTaskKey } from "@trigger.dev/sdk"; +import { Page, Collection } from "replicate"; + +import { ReplicateRunTask } from "./index"; +import { ReplicateReturnType } from "./types"; + +export class Collections { + constructor(private runTask: ReplicateRunTask) {} + + /** Fetch a model collection. */ + get(key: IntegrationTaskKey, params: { slug: string }): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.collections.get(params.slug); + }, + { + name: "Get Collection", + params, + properties: [{ label: "Collection Slug", text: params.slug }], + } + ); + } + + /** Fetch a list of model collections. */ + list(key: IntegrationTaskKey): ReplicateReturnType> { + return this.runTask( + key, + (client) => { + return client.collections.list(); + }, + { + name: "List Collections", + } + ); + } +} diff --git a/integrations/replicate/src/deployments.ts b/integrations/replicate/src/deployments.ts new file mode 100644 index 0000000000..c5c1508d12 --- /dev/null +++ b/integrations/replicate/src/deployments.ts @@ -0,0 +1,76 @@ +import { IntegrationTaskKey } from "@trigger.dev/sdk"; +import ReplicateClient, { Prediction } from "replicate"; + +import { ReplicateRunTask } from "./index"; +import { callbackProperties, createDeploymentProperties } from "./utils"; +import { CallbackTimeout, ReplicateReturnType } from "./types"; + +export class Deployments { + constructor(private runTask: ReplicateRunTask) {} + + get predictions() { + return new Predictions(this.runTask); + } +} + +class Predictions { + constructor(private runTask: ReplicateRunTask) {} + + /** Create a new prediction with a deployment. */ + create( + key: IntegrationTaskKey, + params: { + deployment_owner: string; + deployment_name: string; + } & Parameters[2] + ): ReplicateReturnType { + return this.runTask( + key, + (client) => { + const { deployment_owner, deployment_name, ...options } = params; + + return client.deployments.predictions.create(deployment_owner, deployment_name, options); + }, + { + name: "Create Prediction With Deployment", + params, + properties: createDeploymentProperties(params), + } + ); + } + + /** Create a new prediction with a deployment and await the result. */ + createAndAwait( + key: IntegrationTaskKey, + params: { + deployment_owner: string; + deployment_name: string; + } & Omit< + Parameters[2], + "webhook" | "webhook_events_filter" + >, + options: CallbackTimeout = { timeoutInSeconds: 3600 } + ): ReplicateReturnType { + return this.runTask( + key, + (client, task) => { + const { deployment_owner, deployment_name, ...options } = params; + + return client.deployments.predictions.create(deployment_owner, deployment_name, { + ...options, + webhook: task.callbackUrl ?? "", + webhook_events_filter: ["completed"], + }); + }, + { + name: "Create And Await Prediction With Deployment", + params, + properties: [...createDeploymentProperties(params), ...callbackProperties(options)], + callback: { + enabled: true, + timeoutInSeconds: options.timeoutInSeconds, + }, + } + ); + } +} diff --git a/integrations/replicate/src/index.ts b/integrations/replicate/src/index.ts new file mode 100644 index 0000000000..0093be164d --- /dev/null +++ b/integrations/replicate/src/index.ts @@ -0,0 +1,280 @@ +import { + TriggerIntegration, + RunTaskOptions, + IO, + IOTask, + IntegrationTaskKey, + RunTaskErrorCallback, + Json, + retry, + ConnectionAuth, +} from "@trigger.dev/sdk"; +import ReplicateClient, { Page, Prediction } from "replicate"; + +import { Predictions } from "./predictions"; +import { Models } from "./models"; +import { Trainings } from "./trainings"; +import { Collections } from "./collections"; +import { ReplicateReturnType } from "./types"; +import { Deployments } from "./deployments"; + +export type ReplicateIntegrationOptions = { + id: string; + apiKey: string; +}; + +export type ReplicateRunTask = InstanceType["runTask"]; + +export class Replicate implements TriggerIntegration { + private _options: ReplicateIntegrationOptions; + private _client?: any; + private _io?: IO; + private _connectionKey?: string; + + constructor(private options: ReplicateIntegrationOptions) { + if (Object.keys(options).includes("apiKey") && !options.apiKey) { + throw `Can't create Replicate integration (${options.id}) as apiKey was undefined`; + } + + this._options = options; + } + + get authSource() { + return "LOCAL" as const; + } + + get id() { + return this.options.id; + } + + get metadata() { + return { id: "replicate", name: "Replicate" }; + } + + cloneForRun(io: IO, connectionKey: string, auth?: ConnectionAuth) { + const replicate = new Replicate(this._options); + replicate._io = io; + replicate._connectionKey = connectionKey; + replicate._client = this.createClient(auth); + return replicate; + } + + createClient(auth?: ConnectionAuth) { + return new ReplicateClient({ + auth: this._options.apiKey, + }); + } + + runTask | void>( + key: IntegrationTaskKey, + callback: (client: ReplicateClient, task: IOTask, io: IO) => Promise, + options?: RunTaskOptions, + errorCallback?: RunTaskErrorCallback + ): Promise { + if (!this._io) throw new Error("No IO"); + if (!this._connectionKey) throw new Error("No connection key"); + + return this._io.runTask( + key, + (task, io) => { + if (!this._client) throw new Error("No client"); + return callback(this._client, task, io); + }, + { + icon: "replicate", + retry: retry.standardBackoff, + ...(options ?? {}), + connectionKey: this._connectionKey, + }, + errorCallback ?? onError + ); + } + + get collections() { + return new Collections(this.runTask.bind(this)); + } + + get deployments() { + return new Deployments(this.runTask.bind(this)); + } + + get models() { + return new Models(this.runTask.bind(this)); + } + + get predictions() { + return new Predictions(this.runTask.bind(this)); + } + + get trainings() { + return new Trainings(this.runTask.bind(this)); + } + + /** Paginate through a list of results. */ + async *paginate( + task: (key: string) => Promise>, + key: IntegrationTaskKey, + counter: number = 0 + ): AsyncGenerator { + const boundTask = task.bind(this as any); + + const page = await boundTask(`${key}-${counter}`); + yield page.results; + + if (page.next) { + const nextStep = counter++; + + const nextPage = () => { + return this.request>(`${key}-${nextStep}`, { + route: page.next!, + options: { method: "GET" }, + }); + }; + + yield* this.paginate(nextPage, key, nextStep); + } + } + + /** Auto-paginate and return all results. */ + async getAll( + task: (key: string) => Promise>, + key: IntegrationTaskKey + ): ReplicateReturnType { + const allResults: T[] = []; + + for await (const results of this.paginate(task, key)) { + allResults.push(...results); + } + + return allResults; + } + + /** Make a request to the Replicate API. */ + request( + key: IntegrationTaskKey, + params: { + route: string | URL; + options: Parameters[1]; + } + ): ReplicateReturnType { + return this.runTask( + key, + async (client) => { + const response = await client.request(params.route, params.options); + + return response.json(); + }, + { + name: "Send Request", + params, + properties: [ + { label: "Route", text: params.route.toString() }, + ...(params.options.method ? [{ label: "Method", text: params.options.method }] : []), + ], + callback: { enabled: true }, + } + ); + } + + /** Run a model and await the result. */ + run( + key: IntegrationTaskKey, + params: { + identifier: Parameters[0]; + } & Omit< + Parameters[1], + "webhook" | "webhook_events_filter" | "wait" | "signal" + > + ): ReplicateReturnType { + const { identifier, ...paramsWithoutIdentifier } = params; + + // see: https://github.com/replicate/replicate-javascript/blob/4b0d9cb0e226fab3d3d31de5b32261485acf5626/index.js#L102 + + const namePattern = /[a-zA-Z0-9]+(?:(?:[._]|__|[-]*)[a-zA-Z0-9]+)*/; + const pattern = new RegExp( + `^(?${namePattern.source})/(?${namePattern.source}):(?[0-9a-fA-F]+)$` + ); + + const match = identifier.match(pattern); + + if (!match || !match.groups) { + throw new Error('Invalid version. It must be in the format "owner/name:version"'); + } + + const { version } = match.groups; + + return this.predictions.createAndAwait(key, { ...paramsWithoutIdentifier, version }); + } + + // TODO: wait(prediction) - needs polling +} + +class ApiError extends Error { + constructor( + message: string, + readonly request: Request, + readonly response: Response + ) { + super(message); + this.name = "ApiError"; + } +} + +function isReplicateApiError(error: unknown): error is ApiError { + if (typeof error !== "object" || error === null) { + return false; + } + + const apiError = error as ApiError; + + return ( + apiError.name === "ApiError" && + apiError.request instanceof Request && + apiError.response instanceof Response + ); +} + +function shouldRetry(method: string, status: number) { + return status === 429 || (method === "GET" && status >= 500); +} + +export function onError(error: unknown): ReturnType { + if (!isReplicateApiError(error)) { + return; + } + + if (!shouldRetry(error.request.method, error.response.status)) { + return { + skipRetrying: true, + }; + } + + // see: https://github.com/replicate/replicate-javascript/blob/4b0d9cb0e226fab3d3d31de5b32261485acf5626/lib/util.js#L43 + + const retryAfter = error.response.headers.get("retry-after"); + + if (retryAfter) { + const resetDate = new Date(retryAfter); + + if (!Number.isNaN(resetDate.getTime())) { + return { + retryAt: resetDate, + error, + }; + } + } + + const rateLimitRemaining = error.response.headers.get("ratelimit-remaining"); + const rateLimitReset = error.response.headers.get("ratelimit-reset"); + + if (rateLimitRemaining === "0" && rateLimitReset) { + const resetDate = new Date(Number(rateLimitReset) * 1000); + + if (!Number.isNaN(resetDate.getTime())) { + return { + retryAt: resetDate, + error, + }; + } + } +} diff --git a/integrations/replicate/src/models.ts b/integrations/replicate/src/models.ts new file mode 100644 index 0000000000..d4b3a78ac5 --- /dev/null +++ b/integrations/replicate/src/models.ts @@ -0,0 +1,82 @@ +import { IntegrationTaskKey } from "@trigger.dev/sdk"; +import { Model, ModelVersion } from "replicate"; + +import { ReplicateRunTask } from "./index"; +import { modelProperties } from "./utils"; +import { ReplicateReturnType } from "./types"; + +export class Models { + constructor(private runTask: ReplicateRunTask) {} + + /** Get information about a model. */ + get( + key: IntegrationTaskKey, + params: { + model_owner: string; + model_name: string; + } + ): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.models.get(params.model_owner, params.model_name); + }, + { + name: "Get Model", + params, + properties: modelProperties(params), + } + ); + } + + get versions() { + return new Versions(this.runTask); + } +} + +class Versions { + constructor(private runTask: ReplicateRunTask) {} + + /** Get a specific model version. */ + get( + key: IntegrationTaskKey, + params: { + model_owner: string; + model_name: string; + version_id: string; + } + ): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.models.versions.get(params.model_owner, params.model_name, params.version_id); + }, + { + name: "Get Model Version", + params, + properties: modelProperties(params), + } + ); + } + + /** List model versions. */ + list( + key: IntegrationTaskKey, + params: { + model_owner: string; + model_name: string; + } + ): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.models.versions.list(params.model_owner, params.model_name); + }, + { + name: "List Models", + params, + properties: modelProperties(params), + } + ); + } +} diff --git a/integrations/replicate/src/predictions.ts b/integrations/replicate/src/predictions.ts new file mode 100644 index 0000000000..9f6c604fd2 --- /dev/null +++ b/integrations/replicate/src/predictions.ts @@ -0,0 +1,101 @@ +import { IntegrationTaskKey } from "@trigger.dev/sdk"; +import ReplicateClient, { Page, Prediction } from "replicate"; + +import { ReplicateRunTask } from "./index"; +import { CallbackTimeout, ReplicateReturnType } from "./types"; +import { callbackProperties, createPredictionProperties } from "./utils"; + +export class Predictions { + constructor(private runTask: ReplicateRunTask) {} + + /** Cancel a prediction. */ + cancel(key: IntegrationTaskKey, params: { id: string }): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.predictions.cancel(params.id); + }, + { + name: "Cancel Prediction", + params, + properties: [{ label: "Prediction ID", text: params.id }], + } + ); + } + + /** Create a new prediction. */ + create( + key: IntegrationTaskKey, + params: Parameters[0] + ): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.predictions.create(params); + }, + { + name: "Create Prediction", + params, + properties: createPredictionProperties(params), + } + ); + } + + /** Create a new prediction and await the result. */ + createAndAwait( + key: IntegrationTaskKey, + params: Omit< + Parameters[0], + "webhook" | "webhook_events_filter" + >, + options: CallbackTimeout = { timeoutInSeconds: 3600 } + ): ReplicateReturnType { + return this.runTask( + key, + (client, task) => { + return client.predictions.create({ + ...params, + webhook: task.callbackUrl ?? "", + webhook_events_filter: ["completed"], + }); + }, + { + name: "Create And Await Prediction", + params, + properties: [...createPredictionProperties(params), ...callbackProperties(options)], + callback: { + enabled: true, + timeoutInSeconds: options.timeoutInSeconds, + }, + } + ); + } + + /** Fetch a prediction. */ + get(key: IntegrationTaskKey, params: { id: string }): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.predictions.get(params.id); + }, + { + name: "Get Prediction", + params, + properties: [{ label: "Prediction ID", text: params.id }], + } + ); + } + + /** List all predictions. */ + list(key: IntegrationTaskKey): ReplicateReturnType> { + return this.runTask( + key, + (client) => { + return client.predictions.list(); + }, + { + name: "List Predictions", + } + ); + } +} diff --git a/integrations/replicate/src/trainings.ts b/integrations/replicate/src/trainings.ts new file mode 100644 index 0000000000..10a3ae576b --- /dev/null +++ b/integrations/replicate/src/trainings.ts @@ -0,0 +1,113 @@ +import { IntegrationTaskKey } from "@trigger.dev/sdk"; +import ReplicateClient, { Page, Training } from "replicate"; + +import { ReplicateRunTask } from "./index"; +import { CallbackTimeout, ReplicateReturnType } from "./types"; +import { callbackProperties, modelProperties } from "./utils"; + +export class Trainings { + constructor(private runTask: ReplicateRunTask) {} + + /** Cancel a training. */ + cancel(key: IntegrationTaskKey, params: { id: string }): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.trainings.cancel(params.id); + }, + { + name: "Cancel Training", + params, + properties: [{ label: "Training ID", text: params.id }], + } + ); + } + + /** Create a new training. */ + create( + key: IntegrationTaskKey, + params: { + model_owner: string; + model_name: string; + version_id: string; + } & Parameters[3] + ): ReplicateReturnType { + return this.runTask( + key, + (client) => { + const { model_owner, model_name, version_id, ...options } = params; + + return client.trainings.create(model_owner, model_name, version_id, options); + }, + { + name: "Create Training", + params, + properties: modelProperties(params), + } + ); + } + + /** Create a new training and await the result. */ + createAndAwait( + key: IntegrationTaskKey, + params: { + model_owner: string; + model_name: string; + version_id: string; + } & Omit< + Parameters[3], + "webhook" | "webhook_events_filter" + >, + options: CallbackTimeout = { timeoutInSeconds: 3600 } + ): ReplicateReturnType { + return this.runTask( + key, + (client, task) => { + const { model_owner, model_name, version_id, ...options } = params; + + return client.trainings.create(model_owner, model_name, version_id, { + ...options, + webhook: task.callbackUrl ?? "", + webhook_events_filter: ["completed"], + }); + }, + { + name: "Create And Await Training", + params, + properties: [...modelProperties(params), ...callbackProperties(options)], + callback: { + enabled: true, + timeoutInSeconds: options.timeoutInSeconds, + }, + } + ); + } + + /** Fetch a training. */ + get(key: IntegrationTaskKey, params: { id: string }): ReplicateReturnType { + return this.runTask( + key, + (client) => { + return client.trainings.get(params.id); + }, + { + name: "Get Training", + params, + properties: [{ label: "Training ID", text: params.id }], + } + ); + } + + /** List all trainings. */ + list(key: IntegrationTaskKey): ReplicateReturnType> { + return this.runTask( + key, + async (client) => { + return client.trainings.list(); + }, + { + name: "List Trainings", + } + ); + } +} diff --git a/integrations/replicate/src/types.ts b/integrations/replicate/src/types.ts new file mode 100644 index 0000000000..d8fafcd1dc --- /dev/null +++ b/integrations/replicate/src/types.ts @@ -0,0 +1,3 @@ +export type CallbackTimeout = { timeoutInSeconds?: number }; + +export type ReplicateReturnType = Promise; diff --git a/integrations/replicate/src/utils.ts b/integrations/replicate/src/utils.ts new file mode 100644 index 0000000000..0a510690fe --- /dev/null +++ b/integrations/replicate/src/utils.ts @@ -0,0 +1,58 @@ +import { CallbackTimeout } from "./types"; + +export const createPredictionProperties = ( + params: Partial<{ + version: string; + stream: boolean; + }> +) => { + return [ + ...(params.version ? [{ label: "Model Version", text: params.version }] : []), + ...streamingProperty(params), + ]; +}; + +export const createDeploymentProperties = ( + params: Partial<{ + deployment_owner: string; + deployment_name: string; + stream: boolean; + }> +) => { + return [ + ...(params.deployment_owner + ? [{ label: "Deployment Owner", text: params.deployment_owner }] + : []), + ...(params.deployment_name ? [{ label: "Deployment Name", text: params.deployment_name }] : []), + ...streamingProperty(params), + ]; +}; + +export const modelProperties = ( + params: Partial<{ + model_owner: string; + model_name: string; + version_id: string; + destination: string; + }> +) => { + return [ + ...(params.model_owner ? [{ label: "Model Owner", text: params.model_owner }] : []), + ...(params.model_name ? [{ label: "Model Name", text: params.model_name }] : []), + ...(params.version_id ? [{ label: "Model Version", text: params.version_id }] : []), + ...(params.destination ? [{ label: "Destination Model", text: params.destination }] : []), + ]; +}; + +export const streamingProperty = (params: { stream?: boolean }) => { + return [{ label: "Streaming Enabled", text: String(!!params.stream) }]; +}; + +export const callbackProperties = (options: CallbackTimeout) => { + return [ + { + label: "Callback Timeout", + text: options.timeoutInSeconds ? `${options.timeoutInSeconds}s` : "default", + }, + ]; +}; diff --git a/integrations/replicate/tsconfig.json b/integrations/replicate/tsconfig.json new file mode 100644 index 0000000000..36ae307e42 --- /dev/null +++ b/integrations/replicate/tsconfig.json @@ -0,0 +1,4 @@ +{ + "extends": "@trigger.dev/tsconfig/integration.json", + "include": ["./src/**/*.ts", "tsup.config.ts"], +} diff --git a/integrations/replicate/tsup.config.ts b/integrations/replicate/tsup.config.ts new file mode 100644 index 0000000000..483aba1d59 --- /dev/null +++ b/integrations/replicate/tsup.config.ts @@ -0,0 +1,22 @@ +import { defineConfig } from "tsup"; + +export default defineConfig([ + { + name: "main", + entry: ["./src/index.ts"], + outDir: "./dist", + platform: "node", + format: ["cjs"], + legacyOutput: true, + sourcemap: true, + clean: true, + bundle: true, + splitting: false, + dts: true, + treeshake: { + preset: "smallest", + }, + esbuildPlugins: [], + external: ["http", "https", "util", "events", "tty", "os", "timers"], + }, +]); diff --git a/integrations/resend/src/index.ts b/integrations/resend/src/index.ts index f93be1964d..b483fa48c9 100644 --- a/integrations/resend/src/index.ts +++ b/integrations/resend/src/index.ts @@ -100,7 +100,7 @@ export class Resend implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/integrations/sendgrid/src/index.ts b/integrations/sendgrid/src/index.ts index 49a1a55f2d..d08a39379b 100644 --- a/integrations/sendgrid/src/index.ts +++ b/integrations/sendgrid/src/index.ts @@ -70,7 +70,7 @@ export class SendGrid implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/integrations/slack/src/index.ts b/integrations/slack/src/index.ts index 9a8571d4ce..4934a99d1c 100644 --- a/integrations/slack/src/index.ts +++ b/integrations/slack/src/index.ts @@ -92,7 +92,7 @@ export class Slack implements TriggerIntegration { if (!this._io) throw new Error("No IO"); if (!this._connectionKey) throw new Error("No connection key"); - return this._io.runTask( + return this._io.runTask( key, (task, io) => { if (!this._client) throw new Error("No client"); diff --git a/packages/core/src/schemas/api.ts b/packages/core/src/schemas/api.ts index 62edc1f15d..ac9828e7a8 100644 --- a/packages/core/src/schemas/api.ts +++ b/packages/core/src/schemas/api.ts @@ -629,6 +629,16 @@ export const RunTaskOptionsSchema = z.object({ params: z.any(), /** The style of the log entry. */ style: StyleSchema.optional(), + /** Allows you to expose a `task.callbackUrl` to use in your tasks. Enabling this feature will cause the task to return the data sent to the callbackUrl instead of the usual async callback result. */ + callback: z + .object({ + /** Causes the task to wait for and return the data of the first request sent to `task.callbackUrl`. */ + enabled: z.boolean(), + /** Time to wait for the first request to `task.callbackUrl`. Default: One hour. */ + timeoutInSeconds: z.number(), + }) + .partial() + .optional(), /** Allows you to link the Integration connection in the logs. This is handled automatically in integrations. */ connectionKey: z.string().optional(), /** An operation you want to perform on the Trigger.dev platform, current only "fetch" is supported. If you wish to `fetch` use [`io.backgroundFetch()`](https://trigger.dev/docs/sdk/io/backgroundfetch) instead. */ @@ -655,6 +665,12 @@ export type RunTaskBodyInput = z.infer; export const RunTaskBodyOutputSchema = RunTaskBodyInputSchema.extend({ params: DeserializedJsonSchema.optional().nullable(), + callback: z + .object({ + enabled: z.boolean(), + timeoutInSeconds: z.number().default(3600), + }) + .optional(), }); export type RunTaskBodyOutput = z.infer; diff --git a/packages/core/src/schemas/tasks.ts b/packages/core/src/schemas/tasks.ts index fe6d43bbdb..559dc4bab3 100644 --- a/packages/core/src/schemas/tasks.ts +++ b/packages/core/src/schemas/tasks.ts @@ -31,6 +31,7 @@ export const TaskSchema = z.object({ parentId: z.string().optional().nullable(), style: StyleSchema.optional().nullable(), operation: z.string().optional().nullable(), + callbackUrl: z.string().optional().nullable(), }); export const ServerTaskSchema = TaskSchema.extend({ diff --git a/packages/database/prisma/migrations/20230925174509_add_callback_url/migration.sql b/packages/database/prisma/migrations/20230925174509_add_callback_url/migration.sql new file mode 100644 index 0000000000..4808101efc --- /dev/null +++ b/packages/database/prisma/migrations/20230925174509_add_callback_url/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "Task" ADD COLUMN "callbackUrl" TEXT; diff --git a/packages/database/prisma/schema.prisma b/packages/database/prisma/schema.prisma index 192e0d66a4..0bcdc15477 100644 --- a/packages/database/prisma/schema.prisma +++ b/packages/database/prisma/schema.prisma @@ -798,6 +798,7 @@ model Task { redact Json? style Json? operation String? + callbackUrl String? startedAt DateTime? completedAt DateTime? diff --git a/packages/trigger-sdk/src/io.ts b/packages/trigger-sdk/src/io.ts index 58d5a95e9c..b4dacfc5e0 100644 --- a/packages/trigger-sdk/src/io.ts +++ b/packages/trigger-sdk/src/io.ts @@ -694,28 +694,18 @@ export class IO { throw new Error(task.error ?? task?.output ? JSON.stringify(task.output) : "Task errored"); } - if (task.status === "WAITING") { - this._logger.debug("Task waiting", { - idempotencyKey, - task, - }); - - throw new ResumeWithTaskError(task); - } - - if (task.status === "RUNNING" && typeof task.operation === "string") { - this._logger.debug("Task running operation", { - idempotencyKey, - task, - }); - - throw new ResumeWithTaskError(task); - } - const executeTask = async () => { try { const result = await callback(task, this); + if (task.status === "WAITING" && task.callbackUrl) { + this._logger.debug("Waiting for remote callback", { + idempotencyKey, + task, + }); + return {} as T; + } + const output = SerializableJsonSchema.parse(result) as T; this._logger.debug("Completing using output", { @@ -800,6 +790,28 @@ export class IO { } }; + if (task.status === "WAITING") { + this._logger.debug("Task waiting", { + idempotencyKey, + task, + }); + + if (task.callbackUrl) { + await this._taskStorage.run({ taskId: task.id }, executeTask); + } + + throw new ResumeWithTaskError(task); + } + + if (task.status === "RUNNING" && typeof task.operation === "string") { + this._logger.debug("Task running operation", { + idempotencyKey, + task, + }); + + throw new ResumeWithTaskError(task); + } + return this._taskStorage.run({ taskId: task.id }, executeTask); } diff --git a/references/job-catalog/package.json b/references/job-catalog/package.json index ab55bf836e..ef34a56c98 100644 --- a/references/job-catalog/package.json +++ b/references/job-catalog/package.json @@ -25,6 +25,7 @@ "status": "nodemon --watch src/status.ts -r tsconfig-paths/register -r dotenv/config src/status.ts", "byo-auth": "nodemon --watch src/byo-auth.ts -r tsconfig-paths/register -r dotenv/config src/byo-auth.ts", "redacted": "nodemon --watch src/redacted.ts -r tsconfig-paths/register -r dotenv/config src/redacted.ts", + "replicate": "nodemon --watch src/replicate.ts -r tsconfig-paths/register -r dotenv/config src/replicate.ts", "dev:trigger": "trigger-cli dev --port 8080" }, "dependencies": { @@ -44,7 +45,8 @@ "@types/node": "20.4.2", "typescript": "5.1.6", "zod": "3.21.4", - "@trigger.dev/linear": "workspace:*" + "@trigger.dev/linear": "workspace:*", + "@trigger.dev/replicate": "workspace:*" }, "trigger.dev": { "endpointId": "job-catalog" diff --git a/references/job-catalog/src/replicate.ts b/references/job-catalog/src/replicate.ts new file mode 100644 index 0000000000..5cd416bf08 --- /dev/null +++ b/references/job-catalog/src/replicate.ts @@ -0,0 +1,146 @@ +import { createExpressServer } from "@trigger.dev/express"; +import { TriggerClient, eventTrigger } from "@trigger.dev/sdk"; +import { Replicate } from "@trigger.dev/replicate"; +import { z } from "zod"; + +export const client = new TriggerClient({ + id: "job-catalog", + apiKey: process.env["TRIGGER_API_KEY"], + apiUrl: process.env["TRIGGER_API_URL"], + verbose: false, + ioLogLocalEnabled: true, +}); + +const replicate = new Replicate({ + id: "replicate", + apiKey: process.env["REPLICATE_API_KEY"]!, +}); + +client.defineJob({ + id: "replicate-forge-image", + name: "Replicate - Forge Image", + version: "0.1.0", + integrations: { replicate }, + trigger: eventTrigger({ + name: "replicate.bad.forgery", + schema: z.object({ + imageUrl: z + .string() + .url() + .default("https://trigger.dev/blog/supabase-integration/postgres-meme.png"), + }), + }), + run: async (payload, io, ctx) => { + const blipVersion = "2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746"; + const sdVersion = "ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4"; + + const blipPrediction = await io.replicate.run("caption-image", { + identifier: `salesforce/blip:${blipVersion}`, + input: { + image: payload.imageUrl, + }, + }); + + if (typeof blipPrediction.output !== "string") { + throw new Error(`Expected string output, got ${typeof blipPrediction.output}`); + } + + const caption = blipPrediction.output.replace("Caption: ", ""); + + const sdPrediction = await io.replicate.predictions.createAndAwait("draw-image", { + version: sdVersion, + input: { + prompt: caption, + }, + }); + + return { + caption, + output: sdPrediction.output, + }; + }, +}); + +client.defineJob({ + id: "replicate-python-answers", + name: "Replicate - Python Answers", + version: "0.1.0", + integrations: { replicate }, + trigger: eventTrigger({ + name: "replicate.serious.monty", + schema: z.object({ + prompt: z.string().default("why are apples not oranges?"), + }), + }), + run: async (payload, io, ctx) => { + const prediction = await io.replicate.run("await-prediction", { + identifier: + "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d", + input: { + prompt: payload.prompt, + system_prompt: "Answer like John Cleese. Don't be funny.", + max_new_tokens: 200, + }, + }); + + return Array.isArray(prediction.output) ? prediction.output.join("") : prediction.output; + }, +}); + +client.defineJob({ + id: "replicate-cinematic-prompt", + name: "Replicate - Cinematic Prompt", + version: "0.1.0", + integrations: { replicate }, + trigger: eventTrigger({ + name: "replicate.cinematic", + schema: z.object({ + prompt: z.string().default("rick astley riding a harley through post-apocalyptic miami"), + version: z + .string() + .default("af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33"), + }), + }), + run: async (payload, io, ctx) => { + const prediction = await io.replicate.predictions.createAndAwait("await-prediction", { + version: payload.version, + input: { + prompt: `${payload.prompt}, cinematic, 70mm, anamorphic, bokeh`, + width: 1280, + height: 720, + }, + }); + return prediction.output; + }, +}); + +client.defineJob({ + id: "replicate-pagination", + name: "Replicate - Pagination", + version: "0.1.0", + integrations: { + replicate, + }, + trigger: eventTrigger({ + name: "replicate.paginate", + }), + run: async (payload, io, ctx) => { + // getAll - returns an array of all results (uses paginate internally) + const all = await io.replicate.getAll(io.replicate.predictions.list, "get-all"); + + // paginate - returns an async generator, useful to process one page at a time + for await (const predictions of io.replicate.paginate( + io.replicate.predictions.list, + "paginate-all" + )) { + await io.logger.info("stats", { + total: predictions.length, + versions: predictions.map((p) => p.version), + }); + } + + return { count: all.length }; + }, +}); + +createExpressServer(client); diff --git a/references/job-catalog/tsconfig.json b/references/job-catalog/tsconfig.json index 6ec3167e67..80823d1a8e 100644 --- a/references/job-catalog/tsconfig.json +++ b/references/job-catalog/tsconfig.json @@ -97,6 +97,12 @@ ], "@trigger.dev/linear/*": [ "../../integrations/linear/src/*" + ], + "@trigger.dev/replicate": [ + "../../integrations/replicate/src/index" + ], + "@trigger.dev/replicate/*": [ + "../../integrations/replicate/src/*" ] } }