Skip to content

Commit 9f5d249

Browse files
pzhan9facebook-github-bot
authored andcommitted
Refactoring: move verify_casting out so it can be re-used by other tests (#1303)
Summary: Two refactorings to make the next diff D82537988 looks less busy: 1. extract `cast_v0`; 2. extract `verify_casting`. Differential Revision: D83001963
1 parent a925382 commit 9f5d249

File tree

3 files changed

+73
-75
lines changed

3 files changed

+73
-75
lines changed

hyperactor_mesh/src/v1/actor_mesh.rs

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use serde::Deserializer;
3434
use serde::Serialize;
3535
use serde::Serializer;
3636

37+
use crate::CommActor;
3738
use crate::actor_mesh as v0_actor_mesh;
3839
use crate::comm::multicast;
3940
use crate::reference::ActorMeshId;
@@ -135,33 +136,7 @@ impl<A: Actor + RemoteActor> ActorMeshRef<A> {
135136
M: Castable + RemoteMessage + Clone, // Clone is required until we are fully onto comm actor
136137
{
137138
if let Some(root_comm_actor) = self.proc_mesh.root_comm_actor() {
138-
let cast_mesh_shape = view::Ranked::region(self).into();
139-
let actor_mesh_id = ActorMeshId::V1(self.name.clone());
140-
match &self.proc_mesh.root_region {
141-
Some(root_region) => {
142-
let root_mesh_shape = root_region.into();
143-
v0_actor_mesh::cast_to_sliced_mesh::<A, M>(
144-
cx,
145-
actor_mesh_id,
146-
root_comm_actor,
147-
&sel!(*),
148-
message,
149-
&cast_mesh_shape,
150-
&root_mesh_shape,
151-
)
152-
.map_err(|e| Error::CastingError(self.name.clone(), e.into()))
153-
}
154-
None => v0_actor_mesh::actor_mesh_cast::<A, M>(
155-
cx,
156-
actor_mesh_id,
157-
root_comm_actor,
158-
sel!(*),
159-
&cast_mesh_shape,
160-
&cast_mesh_shape,
161-
message,
162-
)
163-
.map_err(|e| Error::CastingError(self.name.clone(), e.into())),
164-
}
139+
self.cast_v0(cx, message, root_comm_actor)
165140
} else {
166141
for (point, actor) in self.iter() {
167142
let mut headers = Attrs::new();
@@ -179,6 +154,45 @@ impl<A: Actor + RemoteActor> ActorMeshRef<A> {
179154
}
180155
}
181156

157+
fn cast_v0<M>(
158+
&self,
159+
cx: &impl context::Actor,
160+
message: M,
161+
root_comm_actor: &ActorRef<CommActor>,
162+
) -> v1::Result<()>
163+
where
164+
A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
165+
M: Castable + RemoteMessage + Clone, // Clone is required until we are fully onto comm actor
166+
{
167+
let cast_mesh_shape = view::Ranked::region(self).into();
168+
let actor_mesh_id = ActorMeshId::V1(self.name.clone());
169+
match &self.proc_mesh.root_region {
170+
Some(root_region) => {
171+
let root_mesh_shape = root_region.into();
172+
v0_actor_mesh::cast_to_sliced_mesh::<A, M>(
173+
cx,
174+
actor_mesh_id,
175+
root_comm_actor,
176+
&sel!(*),
177+
message,
178+
&cast_mesh_shape,
179+
&root_mesh_shape,
180+
)
181+
.map_err(|e| Error::CastingError(self.name.clone(), e.into()))
182+
}
183+
None => v0_actor_mesh::actor_mesh_cast::<A, M>(
184+
cx,
185+
actor_mesh_id,
186+
root_comm_actor,
187+
sel!(*),
188+
&cast_mesh_shape,
189+
&cast_mesh_shape,
190+
message,
191+
)
192+
.map_err(|e| Error::CastingError(self.name.clone(), e.into())),
193+
}
194+
}
195+
182196
pub async fn supervision_events(
183197
&self,
184198
cx: &impl context::Actor,

hyperactor_mesh/src/v1/proc_mesh.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -631,11 +631,6 @@ impl view::RankedSliceable for ProcMeshRef {
631631

632632
#[cfg(test)]
633633
mod tests {
634-
use std::collections::HashSet;
635-
636-
use hyperactor::clock::Clock;
637-
use hyperactor::clock::RealClock;
638-
use hyperactor::mailbox;
639634
use ndslice::ViewExt;
640635
use ndslice::extent;
641636
use timed_test::async_timed_test;

hyperactor_mesh/src/v1/testactor.rs

Lines changed: 32 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
//! the bootstrap binary, which is not built in test mode (and anyway, test mode
1212
//! does not work across crate boundaries)
1313
14+
#[cfg(test)]
1415
use std::collections::HashSet;
1516
use std::collections::VecDeque;
17+
#[cfg(test)]
1618
use std::time::Duration;
1719

1820
use async_trait::async_trait;
@@ -27,17 +29,23 @@ use hyperactor::Named;
2729
use hyperactor::PortRef;
2830
use hyperactor::RefClient;
2931
use hyperactor::Unbind;
32+
#[cfg(test)]
3033
use hyperactor::clock::Clock as _;
34+
#[cfg(test)]
3135
use hyperactor::clock::RealClock;
36+
#[cfg(test)]
3237
use hyperactor::mailbox;
3338
use hyperactor::supervision::ActorSupervisionEvent;
3439
use ndslice::Point;
35-
use ndslice::ViewExt;
40+
#[cfg(test)]
41+
use ndslice::ViewExt as _;
3642
use serde::Deserialize;
3743
use serde::Serialize;
3844

3945
use crate::comm::multicast::CastInfo;
46+
#[cfg(test)]
4047
use crate::v1::ActorMesh;
48+
#[cfg(test)]
4149
use crate::v1::ActorMeshRef;
4250
#[cfg(test)]
4351
use crate::v1::testing;
@@ -212,28 +220,7 @@ impl Handler<GetCastInfo> for TestActor {
212220
pub async fn assert_mesh_shape(actor_mesh: ActorMesh<TestActor>) {
213221
let instance = testing::instance().await;
214222
// Verify casting to the root actor mesh
215-
{
216-
let (port, mut rx) = mailbox::open_port(&instance);
217-
actor_mesh.cast(instance, GetActorId(port.bind())).unwrap();
218-
219-
let mut expected_actor_ids: HashSet<_> = actor_mesh
220-
.values()
221-
.map(|actor_ref| actor_ref.actor_id().clone())
222-
.collect();
223-
224-
while !expected_actor_ids.is_empty() {
225-
let actor_id = rx.recv().await.unwrap();
226-
assert!(
227-
expected_actor_ids.remove(&actor_id),
228-
"got {actor_id}, expect {expected_actor_ids:?}"
229-
);
230-
}
231-
232-
// No more messages
233-
RealClock.sleep(Duration::from_secs(1)).await;
234-
let result = rx.try_recv();
235-
assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
236-
}
223+
verify_casting(&actor_mesh, instance).await;
237224

238225
// Just pick the first dimension. Slice half of it off.
239226
// actor_mesh.extent().
@@ -242,28 +229,30 @@ pub async fn assert_mesh_shape(actor_mesh: ActorMesh<TestActor>) {
242229

243230
// Verify casting to the sliced actor mesh
244231
let sliced_actor_mesh = actor_mesh.range(&label, 0..size).unwrap();
245-
{
246-
let (port, mut rx) = mailbox::open_port(instance);
247-
sliced_actor_mesh
248-
.cast(instance, GetActorId(port.bind()))
249-
.unwrap();
232+
verify_casting(&sliced_actor_mesh, instance).await;
233+
}
250234

251-
let mut expected_actor_ids: HashSet<_> = sliced_actor_mesh
252-
.values()
253-
.map(|actor_ref| actor_ref.actor_id().clone())
254-
.collect();
235+
#[cfg(test)]
236+
/// Cast to the actor mesh, and verify that all actors are reached.
237+
pub async fn verify_casting(actor_mesh: &ActorMeshRef<TestActor>, instance: &Instance<()>) {
238+
let (port, mut rx) = mailbox::open_port(instance);
239+
actor_mesh.cast(instance, GetActorId(port.bind())).unwrap();
255240

256-
while !expected_actor_ids.is_empty() {
257-
let actor_id = rx.recv().await.unwrap();
258-
assert!(
259-
expected_actor_ids.remove(&actor_id),
260-
"got {actor_id}, expect {expected_actor_ids:?}"
261-
);
262-
}
241+
let mut expected_actor_ids: HashSet<_> = actor_mesh
242+
.values()
243+
.map(|actor_ref| actor_ref.actor_id().clone())
244+
.collect();
263245

264-
// No more messages
265-
RealClock.sleep(Duration::from_secs(1)).await;
266-
let result = rx.try_recv();
267-
assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
246+
while !expected_actor_ids.is_empty() {
247+
let actor_id = rx.recv().await.unwrap();
248+
assert!(
249+
expected_actor_ids.remove(&actor_id),
250+
"got {actor_id}, expect {expected_actor_ids:?}"
251+
);
268252
}
253+
254+
// No more messages
255+
RealClock.sleep(Duration::from_secs(1)).await;
256+
let result = rx.try_recv();
257+
assert!(result.as_ref().unwrap().is_none(), "got {result:?}");
269258
}

0 commit comments

Comments
 (0)