Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/every-windows-lose.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@langchain/langgraph": patch
---

Improve performance of scheduling tasks with large graphs
79 changes: 71 additions & 8 deletions libs/langgraph/src/pregel/algo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ function triggersNextStep(
}

// Avoids unnecessary double iteration
// Cache of the current max version for a given `channel_versions` record,
// keyed by the record object itself (WeakMap keys are compared by identity
// and do not keep the record alive). `_applyWrites` stores the bumped max
// version here and consults the cache before falling back to
// `maxChannelMapVersion`, so the full map is not rescanned on every call.
// NOTE(review): entries are only refreshed where `_applyWrites` bumps the
// version — confirm no other code path mutates a cached record's versions.
const MAX_CHANNEL_MAP_VERSION_CACHE = new WeakMap<
Record<string, number | string>,
number | string
>();
function maxChannelMapVersion(
channelVersions: Record<string, number | string>
): number | string | undefined {
Expand Down Expand Up @@ -249,7 +253,7 @@ export function _applyWrites<Cc extends Record<string, BaseChannel>>(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
getNextVersion: ((version: any) => any) | undefined,
triggerToNodes: Record<string, string[]> | undefined
): void {
): Set<string> {
// Sort tasks by first 3 path elements for deterministic order
// Later path parts (like task IDs) are ignored for sorting
tasks.sort((a, b) => {
Expand Down Expand Up @@ -285,7 +289,9 @@ export function _applyWrites<Cc extends Record<string, BaseChannel>>(
}

// Find the highest version of all channels
let maxVersion = maxChannelMapVersion(checkpoint.channel_versions);
let maxVersion =
MAX_CHANNEL_MAP_VERSION_CACHE.get(checkpoint.channel_versions) ??
maxChannelMapVersion(checkpoint.channel_versions);

// Consume all channels that were read
const channelsToConsume = new Set(
Expand All @@ -294,12 +300,12 @@ export function _applyWrites<Cc extends Record<string, BaseChannel>>(
.filter((chan) => !RESERVED.includes(chan))
);

let usedNewVersion = false;
let usedNextVersion = false;
for (const chan of channelsToConsume) {
if (chan in onlyChannels && onlyChannels[chan].consume()) {
if (getNextVersion !== undefined) {
checkpoint.channel_versions[chan] = getNextVersion(maxVersion);
usedNewVersion = true;
usedNextVersion = true;
}
}
}
Expand All @@ -318,8 +324,11 @@ export function _applyWrites<Cc extends Record<string, BaseChannel>>(
}

// Advance the max version once if it was already handed out above, so that later writes receive a fresh version
if (maxVersion != null && getNextVersion != null) {
maxVersion = usedNewVersion ? getNextVersion(maxVersion) : maxVersion;
if (maxVersion != null && getNextVersion != null && usedNextVersion) {
maxVersion = getNextVersion(maxVersion);
usedNextVersion = false;

MAX_CHANNEL_MAP_VERSION_CACHE.set(checkpoint.channel_versions, maxVersion!);
}

const updatedChannels: Set<string> = new Set();
Expand All @@ -346,6 +355,7 @@ export function _applyWrites<Cc extends Record<string, BaseChannel>>(
}
if (updated && getNextVersion !== undefined) {
checkpoint.channel_versions[chan] = getNextVersion(maxVersion);
usedNextVersion = true;

// unavailable channels can't trigger tasks, so don't add them
if (channel.isAvailable()) updatedChannels.add(chan);
Expand All @@ -364,6 +374,7 @@ export function _applyWrites<Cc extends Record<string, BaseChannel>>(

if (updated && getNextVersion !== undefined) {
checkpoint.channel_versions[chan] = getNextVersion(maxVersion);
usedNextVersion = true;

// unavailable channels can't trigger tasks, so don't add them
if (channel.isAvailable()) updatedChannels.add(chan);
Expand All @@ -380,12 +391,63 @@ export function _applyWrites<Cc extends Record<string, BaseChannel>>(
const channel = onlyChannels[chan];
if (channel.finish() && getNextVersion !== undefined) {
checkpoint.channel_versions[chan] = getNextVersion(maxVersion);
usedNextVersion = true;

// unavailable channels can't trigger tasks, so don't add them
if (channel.isAvailable()) updatedChannels.add(chan);
}
}
}

if (maxVersion != null && getNextVersion != null && usedNextVersion) {
maxVersion = getNextVersion(maxVersion);
usedNextVersion = false;
MAX_CHANNEL_MAP_VERSION_CACHE.set(checkpoint.channel_versions, maxVersion!);
}

return updatedChannels;
}

/**
 * Lazily yields the names of the nodes that should be considered when
 * preparing PULL tasks for the next step.
 *
 * Fast path: when `extra` carries both the set of channels updated in the
 * previous step and the channel -> nodes trigger mapping, only the nodes
 * triggered by an updated channel are yielded (de-duplicated and sorted for
 * a deterministic order), so the full node list is never scanned.
 *
 * Fallback: if no channel in the checkpoint has a non-null version, nothing
 * is yielded; otherwise every own key of `processes` is yielded.
 *
 * @param checkpoint Checkpoint whose `channel_versions` are inspected.
 * @param processes Node name -> node mapping; only its keys are used.
 * @param extra Optional `updatedChannels` / `triggerToNodes` hints.
 */
function* candidateNodes(
checkpoint: ReadonlyCheckpoint,
processes: StrRecord<string, PregelNode>,
extra: NextTaskExtraFields
) {
const { updatedChannels, triggerToNodes } = extra;

if (updatedChannels != null && triggerToNodes != null) {
// Union of every node listed under any updated channel.
const triggered = new Set<string>();
for (const channel of updatedChannels) {
const targets = triggerToNodes[channel];
if (targets == null) continue;
for (const nodeId of targets) triggered.add(nodeId);
}

// Sorted so the order does not depend on channel iteration order.
const ordered = Array.from(triggered);
ordered.sort();
yield* ordered;
return;
}

// With no versioned channel at all there is nothing to pull from,
// so skip walking the PULL candidates entirely.
let hasVersionedChannel = false;
for (const chan in checkpoint.channel_versions) {
if (checkpoint.channel_versions[chan] !== null) {
hasVersionedChannel = true;
break;
}
}
if (!hasVersionedChannel) return;

for (const name in processes) {
if (Object.prototype.hasOwnProperty.call(processes, name)) yield name;
}
}

export type NextTaskExtraFields = {
Expand All @@ -395,6 +457,8 @@ export type NextTaskExtraFields = {
manager?: CallbackManagerForChainRun;
store?: BaseStore;
stream?: IterableReadableWritableStream;
triggerToNodes?: Record<string, string[]>;
updatedChannels?: Set<string>;
};

export type NextTaskExtraFieldsWithStore = NextTaskExtraFields & {
Expand Down Expand Up @@ -478,8 +542,7 @@ export function _prepareNextTasks<

// Check if any processes should be run in next step
// If so, prepare the values to be passed to them
for (const name in processes) {
if (!Object.prototype.hasOwnProperty.call(processes, name)) continue;
for (const name of candidateNodes(checkpoint, processes, extra)) {
const task = _prepareSingleTask(
[PULL, name],
checkpoint,
Expand Down
1 change: 1 addition & 0 deletions libs/langgraph/src/pregel/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ export class Pregel<
this.store = fields.store;
this.cache = fields.cache;
this.name = fields.name;
this.triggerToNodes = fields.triggerToNodes ?? this.triggerToNodes;

if (this.autoValidate) {
this.validate();
Expand Down
10 changes: 7 additions & 3 deletions libs/langgraph/src/pregel/loop.ts
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ export class PregelLoop {

protected prevCheckpointConfig: RunnableConfig | undefined;

protected updatedChannels: Set<string> | undefined;

status:
| "pending"
| "done"
Expand Down Expand Up @@ -701,7 +703,7 @@ export class PregelLoop {
// finish superstep
const writes = Object.values(this.tasks).flatMap((t) => t.writes);
// All tasks have finished
_applyWrites(
this.updatedChannels = _applyWrites(
this.checkpoint,
this.channels,
Object.values(this.tasks),
Expand Down Expand Up @@ -757,6 +759,8 @@ export class PregelLoop {
manager: this.manager,
store: this.store,
stream: this.stream,
triggerToNodes: this.triggerToNodes,
updatedChannels: this.updatedChannels,
}
);
this.tasks = nextTasks;
Expand Down Expand Up @@ -857,7 +861,7 @@ export class PregelLoop {
this.checkpointPendingWrites.length > 0 &&
Object.values(this.tasks).some((task) => task.writes.length > 0)
) {
_applyWrites(
this.updatedChannels = _applyWrites(
this.checkpoint,
this.channels,
Object.values(this.tasks),
Expand Down Expand Up @@ -1067,7 +1071,7 @@ export class PregelLoop {
true,
{ step: this.step }
);
_applyWrites(
this.updatedChannels = _applyWrites(
this.checkpoint,
this.channels,
(Object.values(discardTasks) as WritesProtocol[]).concat([
Expand Down
6 changes: 6 additions & 0 deletions libs/langgraph/src/pregel/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,12 @@ export type PregelParams<
* Storage used for node caching.
*/
cache?: BaseCache;

/**
* The trigger to node mapping for the graph run.
* @internal
*/
triggerToNodes?: Record<string, string[]>;
};

export interface PregelTaskDescription {
Expand Down
Loading