Skip to content

Commit 7ce4914

Browse files
committed
add sub agent escalation
1 parent b3ca86e commit 7ce4914

File tree

2 files changed

+153
-4
lines changed

2 files changed

+153
-4
lines changed

core/src/main/java/com/google/adk/agents/ParallelAgent.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import com.google.adk.agents.ConfigAgentUtils.ConfigurationException;
2121
import com.google.adk.events.Event;
2222
import io.reactivex.rxjava3.core.Flowable;
23+
import io.reactivex.rxjava3.core.Scheduler;
2324
import java.util.ArrayList;
2425
import java.util.List;
2526
import org.slf4j.Logger;
@@ -35,6 +36,7 @@
3536
public class ParallelAgent extends BaseAgent {
3637

3738
private static final Logger logger = LoggerFactory.getLogger(ParallelAgent.class);
39+
private final Scheduler scheduler;
3840

3941
/**
4042
* Constructor for ParallelAgent.
@@ -44,24 +46,34 @@ public class ParallelAgent extends BaseAgent {
4446
* @param subAgents The list of sub-agents to run in parallel.
4547
* @param beforeAgentCallback Optional callback before the agent runs.
4648
* @param afterAgentCallback Optional callback after the agent runs.
49+
* @param scheduler The scheduler to use for parallel execution.
4750
*/
4851
private ParallelAgent(
4952
String name,
5053
String description,
5154
List<? extends BaseAgent> subAgents,
5255
List<Callbacks.BeforeAgentCallback> beforeAgentCallback,
53-
List<Callbacks.AfterAgentCallback> afterAgentCallback) {
56+
List<Callbacks.AfterAgentCallback> afterAgentCallback,
57+
Scheduler scheduler) {
5458

5559
super(name, description, subAgents, beforeAgentCallback, afterAgentCallback);
60+
this.scheduler = scheduler;
5661
}
5762

5863
/** Builder for {@link ParallelAgent}. */
5964
public static class Builder extends BaseAgent.Builder<Builder> {
6065

66+
private Scheduler scheduler = io.reactivex.rxjava3.schedulers.Schedulers.io();
67+
68+
public Builder scheduler(Scheduler scheduler) {
69+
this.scheduler = scheduler;
70+
return this;
71+
}
72+
6173
@Override
6274
public ParallelAgent build() {
6375
return new ParallelAgent(
64-
name, description, subAgents, beforeAgentCallback, afterAgentCallback);
76+
name, description, subAgents, beforeAgentCallback, afterAgentCallback, scheduler);
6577
}
6678
}
6779

@@ -131,9 +143,10 @@ protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
131143

132144
List<Flowable<Event>> agentFlowables = new ArrayList<>();
133145
for (BaseAgent subAgent : currentSubAgents) {
134-
agentFlowables.add(subAgent.runAsync(invocationContext));
146+
agentFlowables.add(subAgent.runAsync(invocationContext).subscribeOn(scheduler));
135147
}
136-
return Flowable.merge(agentFlowables);
148+
return Flowable.merge(agentFlowables)
149+
.takeUntil((Event event) -> event.actions().escalate().orElse(false));
137150
}
138151

139152
/**
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.agents;
18+
19+
import static com.google.adk.testing.TestUtils.createInvocationContext;
20+
import static com.google.common.truth.Truth.assertThat;
21+
import static java.util.concurrent.TimeUnit.MILLISECONDS;
22+
23+
import com.google.adk.events.Event;
24+
import com.google.adk.events.EventActions;
25+
import com.google.common.collect.ImmutableList;
26+
import com.google.genai.types.Content;
27+
import com.google.genai.types.Part;
28+
import io.reactivex.rxjava3.core.Flowable;
29+
import io.reactivex.rxjava3.core.Scheduler;
30+
import io.reactivex.rxjava3.schedulers.TestScheduler;
31+
import org.junit.Test;
32+
import org.junit.runner.RunWith;
33+
import org.junit.runners.JUnit4;
34+
35+
@RunWith(JUnit4.class)
36+
public final class ParallelAgentEscalationTest {
37+
38+
static class EscalatingAgent extends BaseAgent {
39+
private final long delayMillis;
40+
private final Scheduler scheduler;
41+
42+
private EscalatingAgent(String name, long delayMillis, Scheduler scheduler) {
43+
super(name, "Escalating Agent", ImmutableList.of(), null, null);
44+
this.delayMillis = delayMillis;
45+
this.scheduler = scheduler;
46+
}
47+
48+
@Override
49+
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
50+
Flowable<Event> event =
51+
Flowable.fromCallable(
52+
() ->
53+
Event.builder()
54+
.author(name())
55+
.branch(invocationContext.branch())
56+
.invocationId(invocationContext.invocationId())
57+
.content(Content.fromParts(Part.fromText("Escalating!")))
58+
.actions(EventActions.builder().escalate(true).build())
59+
.build());
60+
61+
if (delayMillis > 0) {
62+
return event.delay(delayMillis, MILLISECONDS, scheduler);
63+
}
64+
return event;
65+
}
66+
67+
@Override
68+
protected Flowable<Event> runLiveImpl(InvocationContext invocationContext) {
69+
throw new UnsupportedOperationException("Not implemented");
70+
}
71+
}
72+
73+
static class SlowAgent extends BaseAgent {
74+
private final long delayMillis;
75+
private final Scheduler scheduler;
76+
77+
private SlowAgent(String name, long delayMillis, Scheduler scheduler) {
78+
super(name, "Slow Agent", ImmutableList.of(), null, null);
79+
this.delayMillis = delayMillis;
80+
this.scheduler = scheduler;
81+
}
82+
83+
@Override
84+
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
85+
Flowable<Event> event =
86+
Flowable.fromCallable(
87+
() ->
88+
Event.builder()
89+
.author(name())
90+
.branch(invocationContext.branch())
91+
.invocationId(invocationContext.invocationId())
92+
.content(Content.fromParts(Part.fromText("Finished")))
93+
.build());
94+
95+
if (delayMillis > 0) {
96+
return event.delay(delayMillis, MILLISECONDS, scheduler);
97+
}
98+
return event;
99+
}
100+
101+
@Override
102+
protected Flowable<Event> runLiveImpl(InvocationContext invocationContext) {
103+
throw new UnsupportedOperationException("Not implemented");
104+
}
105+
}
106+
107+
@Test
108+
public void runAsync_escalationEvent_shortCircuitsOtherAgents() {
109+
TestScheduler testScheduler = new TestScheduler();
110+
111+
EscalatingAgent agent1 = new EscalatingAgent("agent1", 100, testScheduler);
112+
SlowAgent agent2 = new SlowAgent("agent2", 500, testScheduler);
113+
114+
ParallelAgent parallelAgent =
115+
ParallelAgent.builder()
116+
.name("parallel_agent")
117+
.subAgents(agent1, agent2)
118+
.scheduler(testScheduler)
119+
.build();
120+
121+
InvocationContext invocationContext = createInvocationContext(parallelAgent);
122+
123+
var subscriber = parallelAgent.runAsync(invocationContext).test();
124+
125+
testScheduler.advanceTimeBy(200, MILLISECONDS);
126+
127+
subscriber.assertValueCount(1);
128+
Event event = subscriber.values().get(0);
129+
assertThat(event.author()).isEqualTo("agent1");
130+
assertThat(event.actions().escalate()).hasValue(true);
131+
132+
subscriber.assertComplete();
133+
testScheduler.advanceTimeBy(1000, MILLISECONDS);
134+
subscriber.assertValueCount(1);
135+
}
136+
}

0 commit comments

Comments
 (0)