Skip to content

Commit 6a992a1

Browse files
nicktrnericallam
andauthored
Replicate integration and remote callbacks (#507)
* Support tasks with remote callbacks * Add common integration tsconfig * Add Replicate integration * Basic job catalog example * Integration catalog entry * Check for callbackUrl during executeTask * Fix getAll * Improve JSDoc * Bump version * Remove named queue * Simplify runTask types * Trust the types * Fail tasks on timeout * Callback timeout as param * Mess with types * performRunExecutionV1 * Update runTask docs * Shorten callback task methods * Fix run method return type * Image processing jobs * Replicate docs * Text output example * Changeset * Version bump * Roll back ugly types * Remove missing types * Quicker return when waiting on remote callback * Remote callback example * Bump version * Remove schema parsing * Only schedule positive callback timeout * Decrease callback secret length * Explicit default timeouts * Import deployments tasks * JSDoc * Deployments docs * Fix runTask examples, mention wrappers --------- Co-authored-by: Eric Allam <[email protected]>
1 parent 81e886a commit 6a992a1

File tree

43 files changed

+1600
-47
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1600
-47
lines changed

.changeset/fair-plums-grin.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
---
2+
"@trigger.dev/replicate": patch
3+
"@trigger.dev/airtable": patch
4+
"@trigger.dev/sendgrid": patch
5+
"@trigger.dev/sdk": patch
6+
"@trigger.dev/github": patch
7+
"@trigger.dev/linear": patch
8+
"@trigger.dev/resend": patch
9+
"@trigger.dev/slack": patch
10+
"@trigger.dev/core": patch
11+
---
12+
13+
First release of `@trigger.dev/replicate` integration with remote callback support.

apps/webapp/app/models/task.server.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ export function taskWithAttemptsToServerTask(task: TaskWithAttempts): ServerTask
2323
attempts: task.attempts.length,
2424
idempotencyKey: task.idempotencyKey,
2525
operation: task.operation,
26+
callbackUrl: task.callbackUrl,
2627
};
2728
}
2829

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import type { ActionArgs } from "@remix-run/server-runtime";
2+
import { json } from "@remix-run/server-runtime";
3+
import { RuntimeEnvironmentType } from "@trigger.dev/database";
4+
import { z } from "zod";
5+
import { $transaction, PrismaClient, PrismaClientOrTransaction, prisma } from "~/db.server";
6+
import { enqueueRunExecutionV2 } from "~/models/jobRunExecution.server";
7+
import { logger } from "~/services/logger.server";
8+
9+
const ParamsSchema = z.object({
10+
runId: z.string(),
11+
id: z.string(),
12+
secret: z.string(),
13+
});
14+
15+
export async function action({ request, params }: ActionArgs) {
16+
// Ensure this is a POST request
17+
if (request.method.toUpperCase() !== "POST") {
18+
return { status: 405, body: "Method Not Allowed" };
19+
}
20+
21+
const { runId, id } = ParamsSchema.parse(params);
22+
23+
// Parse body as JSON (no schema parsing)
24+
const body = await request.json();
25+
26+
const service = new CallbackRunTaskService();
27+
28+
try {
29+
// Complete task with request body as output
30+
await service.call(runId, id, body, request.url);
31+
32+
return json({ success: true });
33+
} catch (error) {
34+
if (error instanceof Error) {
35+
logger.error("Error while processing task callback:", { error });
36+
}
37+
38+
return json({ error: "Something went wrong" }, { status: 500 });
39+
}
40+
}
41+
42+
export class CallbackRunTaskService {
43+
#prismaClient: PrismaClient;
44+
45+
constructor(prismaClient: PrismaClient = prisma) {
46+
this.#prismaClient = prismaClient;
47+
}
48+
49+
public async call(runId: string, id: string, taskBody: any, callbackUrl: string): Promise<void> {
50+
const task = await findTask(prisma, id);
51+
52+
if (!task) {
53+
return;
54+
}
55+
56+
if (task.runId !== runId) {
57+
return;
58+
}
59+
60+
if (task.status !== "WAITING") {
61+
return;
62+
}
63+
64+
if (!task.callbackUrl) {
65+
return;
66+
}
67+
68+
if (new URL(task.callbackUrl).pathname !== new URL(callbackUrl).pathname) {
69+
logger.error("Callback URLs don't match", { runId, taskId: id, callbackUrl });
70+
return;
71+
}
72+
73+
logger.debug("CallbackRunTaskService.call()", { task });
74+
75+
await this.#resumeTask(task, taskBody);
76+
}
77+
78+
async #resumeTask(task: NonNullable<FoundTask>, output: any) {
79+
await $transaction(this.#prismaClient, async (tx) => {
80+
await tx.taskAttempt.updateMany({
81+
where: {
82+
taskId: task.id,
83+
status: "PENDING",
84+
},
85+
data: {
86+
status: "COMPLETED",
87+
},
88+
});
89+
90+
await tx.task.update({
91+
where: { id: task.id },
92+
data: {
93+
status: "COMPLETED",
94+
completedAt: new Date(),
95+
output: output ? output : undefined,
96+
},
97+
});
98+
99+
await this.#resumeRunExecution(task, tx);
100+
});
101+
}
102+
103+
async #resumeRunExecution(task: NonNullable<FoundTask>, prisma: PrismaClientOrTransaction) {
104+
await enqueueRunExecutionV2(task.run, prisma, {
105+
skipRetrying: task.run.environment.type === RuntimeEnvironmentType.DEVELOPMENT,
106+
});
107+
}
108+
}
109+
110+
type FoundTask = Awaited<ReturnType<typeof findTask>>;
111+
112+
async function findTask(prisma: PrismaClientOrTransaction, id: string) {
113+
return prisma.task.findUnique({
114+
where: { id },
115+
include: {
116+
run: {
117+
include: {
118+
environment: true,
119+
queue: true,
120+
},
121+
},
122+
},
123+
});
124+
}

apps/webapp/app/routes/api.v1.runs.$runId.tasks.ts

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ import { authenticateApiRequest } from "~/services/apiAuth.server";
1515
import { logger } from "~/services/logger.server";
1616
import { ulid } from "~/services/ulid.server";
1717
import { workerQueue } from "~/services/worker.server";
18+
import { generateSecret } from "~/services/sources/utils.server";
19+
import { env } from "~/env.server";
1820

1921
const ParamsSchema = z.object({
2022
runId: z.string(),
@@ -185,10 +187,13 @@ export class RunTaskService {
185187
},
186188
});
187189

190+
const delayUntilInFuture = taskBody.delayUntil && taskBody.delayUntil.getTime() > Date.now();
191+
const callbackEnabled = taskBody.callback?.enabled;
192+
188193
if (existingTask) {
189194
if (existingTask.status === "CANCELED") {
190195
const existingTaskStatus =
191-
(taskBody.delayUntil && taskBody.delayUntil.getTime() > Date.now()) || taskBody.trigger
196+
delayUntilInFuture || callbackEnabled || taskBody.trigger
192197
? "WAITING"
193198
: taskBody.noop
194199
? "COMPLETED"
@@ -233,16 +238,21 @@ export class RunTaskService {
233238
status = "CANCELED";
234239
} else {
235240
status =
236-
(taskBody.delayUntil && taskBody.delayUntil.getTime() > Date.now()) || taskBody.trigger
241+
delayUntilInFuture || callbackEnabled || taskBody.trigger
237242
? "WAITING"
238243
: taskBody.noop
239244
? "COMPLETED"
240245
: "RUNNING";
241246
}
242247

248+
const taskId = ulid();
249+
const callbackUrl = callbackEnabled
250+
? `${env.APP_ORIGIN}/api/v1/runs/${runId}/tasks/${taskId}/callback/${generateSecret(12)}`
251+
: undefined;
252+
243253
const task = await tx.task.create({
244254
data: {
245-
id: ulid(),
255+
id: taskId,
246256
idempotencyKey,
247257
displayKey: taskBody.displayKey,
248258
runConnection: taskBody.connectionKey
@@ -273,6 +283,7 @@ export class RunTaskService {
273283
properties: taskBody.properties ?? undefined,
274284
redact: taskBody.redact ?? undefined,
275285
operation: taskBody.operation,
286+
callbackUrl,
276287
style: taskBody.style ?? { style: "normal" },
277288
attempts: {
278289
create: {
@@ -296,6 +307,17 @@ export class RunTaskService {
296307
},
297308
{ tx, runAt: task.delayUntil ?? undefined }
298309
);
310+
} else if (task.status === "WAITING" && callbackUrl && taskBody.callback) {
311+
if (taskBody.callback.timeoutInSeconds > 0) {
312+
// We need to schedule the callback timeout
313+
await workerQueue.enqueue(
314+
"processCallbackTimeout",
315+
{
316+
id: task.id,
317+
},
318+
{ tx, runAt: new Date(Date.now() + taskBody.callback.timeoutInSeconds * 1000) }
319+
);
320+
}
299321
}
300322

301323
return task;

apps/webapp/app/services/externalApis/integrationCatalog.server.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { github } from "./integrations/github";
33
import { linear } from "./integrations/linear";
44
import { openai } from "./integrations/openai";
55
import { plain } from "./integrations/plain";
6+
import { replicate } from "./integrations/replicate";
67
import { resend } from "./integrations/resend";
78
import { sendgrid } from "./integrations/sendgrid";
89
import { slack } from "./integrations/slack";
@@ -37,6 +38,7 @@ export const integrationCatalog = new IntegrationCatalog({
3738
linear,
3839
openai,
3940
plain,
41+
replicate,
4042
resend,
4143
slack,
4244
stripe,
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import type { HelpSample, Integration } from "../types";
2+
3+
function usageSample(hasApiKey: boolean): HelpSample {
4+
const apiKeyPropertyName = "apiKey";
5+
6+
return {
7+
title: "Using the client",
8+
code: `
9+
import { Replicate } from "@trigger.dev/replicate";
10+
11+
const replicate = new Replicate({
12+
id: "__SLUG__",${hasApiKey ? `,\n ${apiKeyPropertyName}: process.env.REPLICATE_API_KEY!` : ""}
13+
});
14+
15+
client.defineJob({
16+
id: "replicate-create-prediction",
17+
name: "Replicate - Create Prediction",
18+
version: "0.1.0",
19+
integrations: { replicate },
20+
trigger: eventTrigger({
21+
name: "replicate.predict",
22+
schema: z.object({
23+
prompt: z.string(),
24+
version: z.string(),
25+
}),
26+
}),
27+
run: async (payload, io, ctx) => {
28+
return io.replicate.predictions.createAndAwait("await-prediction", {
29+
version: payload.version,
30+
input: { prompt: payload.prompt },
31+
});
32+
},
33+
});
34+
`,
35+
};
36+
}
37+
38+
export const replicate: Integration = {
39+
identifier: "replicate",
40+
name: "Replicate",
41+
packageName: "@trigger.dev/replicate@latest",
42+
authenticationMethods: {
43+
apikey: {
44+
type: "apikey",
45+
help: {
46+
samples: [usageSample(true)],
47+
},
48+
},
49+
},
50+
};

apps/webapp/app/services/runs/performRunExecutionV1.server.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,9 @@ export class PerformRunExecutionV1Service {
449449

450450
// If the task has an operation, then the next performRunExecution will occur
451451
// when that operation has finished
452-
if (!data.task.operation) {
452+
// Tasks with callbacks enabled will also get processed separately, i.e. when
453+
// they time out, or on valid requests to their callbackUrl
454+
if (!data.task.operation && !data.task.callbackUrl) {
453455
const newJobExecution = await tx.jobRunExecution.create({
454456
data: {
455457
runId: run.id,

apps/webapp/app/services/runs/performRunExecutionV2.server.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,9 @@ export class PerformRunExecutionV2Service {
530530

531531
// If the task has an operation, then the next performRunExecution will occur
532532
// when that operation has finished
533-
if (!data.task.operation) {
533+
// Tasks with callbacks enabled will also get processed separately, i.e. when
534+
// they time out, or on valid requests to their callbackUrl
535+
if (!data.task.operation && !data.task.callbackUrl) {
534536
await enqueueRunExecutionV2(run, tx, {
535537
runAt: data.task.delayUntil ?? undefined,
536538
resumeTaskId: data.task.id,
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import crypto from "node:crypto";
22

3-
export function generateSecret(): string {
4-
return crypto.randomBytes(32).toString("hex");
3+
export function generateSecret(sizeInBytes = 32): string {
4+
return crypto.randomBytes(sizeInBytes).toString("hex");
55
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import { RuntimeEnvironmentType } from "@trigger.dev/database";
2+
import { $transaction, PrismaClient, PrismaClientOrTransaction, prisma } from "~/db.server";
3+
import { enqueueRunExecutionV2 } from "~/models/jobRunExecution.server";
4+
import { logger } from "../logger.server";
5+
6+
type FoundTask = Awaited<ReturnType<typeof findTask>>;
7+
8+
export class ProcessCallbackTimeoutService {
9+
#prismaClient: PrismaClient;
10+
11+
constructor(prismaClient: PrismaClient = prisma) {
12+
this.#prismaClient = prismaClient;
13+
}
14+
15+
public async call(id: string) {
16+
const task = await findTask(this.#prismaClient, id);
17+
18+
if (!task) {
19+
return;
20+
}
21+
22+
if (task.status !== "WAITING" || !task.callbackUrl) {
23+
return;
24+
}
25+
26+
logger.debug("ProcessCallbackTimeoutService.call", { task });
27+
28+
return await this.#failTask(task, "Remote callback timeout - no requests received");
29+
}
30+
31+
async #failTask(task: NonNullable<FoundTask>, error: string) {
32+
await $transaction(this.#prismaClient, async (tx) => {
33+
await tx.taskAttempt.updateMany({
34+
where: {
35+
taskId: task.id,
36+
status: "PENDING",
37+
},
38+
data: {
39+
status: "ERRORED",
40+
error
41+
},
42+
});
43+
44+
await tx.task.update({
45+
where: { id: task.id },
46+
data: {
47+
status: "ERRORED",
48+
completedAt: new Date(),
49+
output: error,
50+
},
51+
});
52+
53+
await this.#resumeRunExecution(task, tx);
54+
});
55+
}
56+
57+
async #resumeRunExecution(task: NonNullable<FoundTask>, prisma: PrismaClientOrTransaction) {
58+
await enqueueRunExecutionV2(task.run, prisma, {
59+
skipRetrying: task.run.environment.type === RuntimeEnvironmentType.DEVELOPMENT,
60+
});
61+
}
62+
}
63+
64+
async function findTask(prisma: PrismaClient, id: string) {
65+
return prisma.task.findUnique({
66+
where: { id },
67+
include: {
68+
run: {
69+
include: {
70+
environment: true,
71+
queue: true,
72+
},
73+
},
74+
},
75+
});
76+
}

0 commit comments

Comments
 (0)