Skip to content

Commit f13068b

Browse files
pzhan9facebook-github-bot
authored andcommitted
Refactoring: move verify_casting out so it can be re-used by other tests
Differential Revision: D83001963
1 parent cd7f711 commit f13068b

File tree

2 files changed

+82
-76
lines changed

2 files changed

+82
-76
lines changed

hyperactor_mesh/src/v1/actor_mesh.rs

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use serde::Deserializer;
3434
use serde::Serialize;
3535
use serde::Serializer;
3636

37+
use crate::CommActor;
3738
use crate::actor_mesh as v0_actor_mesh;
3839
use crate::comm::multicast;
3940
use crate::reference::ActorMeshId;
@@ -135,33 +136,7 @@ impl<A: Actor + RemoteActor> ActorMeshRef<A> {
135136
M: Castable + RemoteMessage + Clone, // Clone is required until we are fully onto comm actor
136137
{
137138
if let Some(root_comm_actor) = self.proc_mesh.root_comm_actor() {
138-
let cast_mesh_shape = view::Ranked::region(self).into();
139-
let actor_mesh_id = ActorMeshId::V1(self.name.clone());
140-
match &self.proc_mesh.root_region {
141-
Some(root_region) => {
142-
let root_mesh_shape = root_region.into();
143-
v0_actor_mesh::cast_to_sliced_mesh::<A, M>(
144-
cx,
145-
actor_mesh_id,
146-
root_comm_actor,
147-
&sel!(*),
148-
message,
149-
&cast_mesh_shape,
150-
&root_mesh_shape,
151-
)
152-
.map_err(|e| Error::CastingError(self.name.clone(), e.into()))
153-
}
154-
None => v0_actor_mesh::actor_mesh_cast::<A, M>(
155-
cx,
156-
actor_mesh_id,
157-
root_comm_actor,
158-
sel!(*),
159-
&cast_mesh_shape,
160-
&cast_mesh_shape,
161-
message,
162-
)
163-
.map_err(|e| Error::CastingError(self.name.clone(), e.into())),
164-
}
139+
self.cast_v0(cx, message, root_comm_actor)
165140
} else {
166141
for (point, actor) in self.iter() {
167142
let mut headers = Attrs::new();
@@ -179,6 +154,45 @@ impl<A: Actor + RemoteActor> ActorMeshRef<A> {
179154
}
180155
}
181156

157+
fn cast_v0<M>(
158+
&self,
159+
cx: &impl context::Actor,
160+
message: M,
161+
root_comm_actor: &ActorRef<CommActor>,
162+
) -> v1::Result<()>
163+
where
164+
A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
165+
M: Castable + RemoteMessage + Clone, // Clone is required until we are fully onto comm actor
166+
{
167+
let cast_mesh_shape = view::Ranked::region(self).into();
168+
let actor_mesh_id = ActorMeshId::V1(self.name.clone());
169+
match &self.proc_mesh.root_region {
170+
Some(root_region) => {
171+
let root_mesh_shape = root_region.into();
172+
v0_actor_mesh::cast_to_sliced_mesh::<A, M>(
173+
cx,
174+
actor_mesh_id,
175+
root_comm_actor,
176+
&sel!(*),
177+
message,
178+
&cast_mesh_shape,
179+
&root_mesh_shape,
180+
)
181+
.map_err(|e| Error::CastingError(self.name.clone(), e.into()))
182+
}
183+
None => v0_actor_mesh::actor_mesh_cast::<A, M>(
184+
cx,
185+
actor_mesh_id,
186+
root_comm_actor,
187+
sel!(*),
188+
&cast_mesh_shape,
189+
&cast_mesh_shape,
190+
message,
191+
)
192+
.map_err(|e| Error::CastingError(self.name.clone(), e.into())),
193+
}
194+
}
195+
182196
pub async fn supervision_events(
183197
&self,
184198
cx: &impl context::Actor,

hyperactor_mesh/src/v1/proc_mesh.rs

Lines changed: 41 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,9 @@ impl view::RankedSliceable for ProcMeshRef {
591591
#[cfg(test)]
592592
mod tests {
593593
use std::collections::HashSet;
594+
use std::time::Duration;
594595

596+
use hyperactor::Instance;
595597
use hyperactor::clock::Clock;
596598
use hyperactor::clock::RealClock;
597599
use hyperactor::mailbox;
@@ -626,6 +628,34 @@ mod tests {
626628
);
627629
}
628630

631+
async fn verify_casting(
632+
actor_mesh: &ActorMeshRef<testactor::TestActor>,
633+
instance: &Instance<()>,
634+
) {
635+
let (port, mut rx) = mailbox::open_port(&instance);
636+
actor_mesh
637+
.cast(&instance, testactor::GetActorId(port.bind()))
638+
.unwrap();
639+
640+
let mut expected_actor_ids: HashSet<_> = actor_mesh
641+
.values()
642+
.map(|actor_ref| actor_ref.actor_id().clone())
643+
.collect();
644+
645+
while !expected_actor_ids.is_empty() {
646+
let actor_id = rx.recv().await.unwrap();
647+
assert!(
648+
expected_actor_ids.remove(&actor_id),
649+
"got {actor_id}, expect {expected_actor_ids:?}"
650+
);
651+
}
652+
653+
// No more messages
654+
RealClock.sleep(Duration::from_secs(1)).await;
655+
let result = rx.try_recv();
656+
assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
657+
}
658+
629659
#[async_timed_test(timeout_secs = 30)]
630660
async fn test_spawn_actor() {
631661
hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default());
@@ -636,58 +666,20 @@ mod tests {
636666
let actor_mesh: ActorMesh<testactor::TestActor> =
637667
proc_mesh.spawn(instance, "test", &()).await.unwrap();
638668

639-
// Verify casting to the root actor mesh
640-
{
641-
let (port, mut rx) = mailbox::open_port(&instance);
642-
actor_mesh
643-
.cast(instance, testactor::GetActorId(port.bind()))
644-
.unwrap();
645-
646-
let mut expected_actor_ids: HashSet<_> = actor_mesh
647-
.values()
648-
.map(|actor_ref| actor_ref.actor_id().clone())
649-
.collect();
650-
651-
while !expected_actor_ids.is_empty() {
652-
let actor_id = rx.recv().await.unwrap();
653-
assert!(
654-
expected_actor_ids.remove(&actor_id),
655-
"got {actor_id}, expect {expected_actor_ids:?}"
656-
);
657-
}
658-
659-
// No more messages
660-
RealClock.sleep(Duration::from_secs(1)).await;
661-
let result = rx.try_recv();
662-
assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
663-
}
669+
// First cast. The seq should be 1 for all actors.
670+
verify_casting(&actor_mesh, &instance).await;
664671

665672
// Verify casting to the sliced actor mesh
666673
let sliced_actor_mesh = actor_mesh.range("replicas", 1..3).unwrap();
667-
{
668-
let (port, mut rx) = mailbox::open_port(instance);
669-
sliced_actor_mesh
670-
.cast(instance, testactor::GetActorId(port.bind()))
671-
.unwrap();
672-
673-
let mut expected_actor_ids: HashSet<_> = sliced_actor_mesh
674-
.values()
675-
.map(|actor_ref| actor_ref.actor_id().clone())
676-
.collect();
677-
678-
while !expected_actor_ids.is_empty() {
679-
let actor_id = rx.recv().await.unwrap();
680-
assert!(
681-
expected_actor_ids.remove(&actor_id),
682-
"got {actor_id}, expect {expected_actor_ids:?}"
683-
);
684-
}
685-
686-
// No more messages
687-
RealClock.sleep(Duration::from_secs(1)).await;
688-
let result = rx.try_recv();
689-
assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
690-
}
674+
// Second cast. The seq should be 2 for actors in the sliced mesh.
675+
verify_casting(&sliced_actor_mesh, &instance).await;
676+
677+
// Verify casting to a different sliced actor mesh
678+
let sliced_actor_mesh = actor_mesh.range("replicas", 0..2).unwrap();
679+
// For actors in the previous sliced mesh, the seq should be 3 since
680+
// this is the third cast for them. For other actors, the seq should
681+
// be 2.
682+
verify_casting(&sliced_actor_mesh, &instance).await;
691683
}
692684
}
693685
}

0 commit comments

Comments
 (0)