Skip to content
46 changes: 30 additions & 16 deletions monarch_hyperactor/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@
* LICENSE file in the root directory of this source tree.
*/

use std::cell::Cell;
use std::future::Future;
use std::pin::Pin;
use std::sync::OnceLock;
use std::sync::RwLock;
use std::sync::RwLockReadGuard;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::time::Duration;

use anyhow::Result;
use anyhow::ensure;
use once_cell::unsync::OnceCell as UnsyncOnceCell;
use pyo3::PyResult;
use pyo3::Python;
Expand Down Expand Up @@ -76,21 +75,36 @@ pub fn shutdown_tokio_runtime() {
}
}

thread_local! {
static IS_MAIN_THREAD: Cell<bool> = const { Cell::new(false) };
/// Stores the native thread ID of the main Python thread.
/// This is lazily initialized on first call to `is_main_thread`.
static MAIN_THREAD_NATIVE_ID: OnceLock<i64> = OnceLock::new();

/// Returns the native thread ID of the main Python thread.
/// On first call, looks it up via `threading.main_thread().native_id`.
fn get_main_thread_native_id() -> i64 {
*MAIN_THREAD_NATIVE_ID.get_or_init(|| {
Python::with_gil(|py| {
let threading = py.import("threading").expect("failed to import threading");
let main_thread = threading
.call_method0("main_thread")
.expect("failed to get main_thread");
main_thread
.getattr("native_id")
.expect("failed to get native_id")
.extract::<i64>()
.expect("native_id is not an i64")
})
})
}

pub fn initialize(py: Python) -> Result<()> {
// Initialize thread local state to identify the main Python thread.
let threading = Python::import(py, "threading")?;
let main_thread = threading.call_method0("main_thread")?;
let current_thread = threading.getattr("current_thread")?.call0()?;
ensure!(
current_thread.is(&main_thread),
"initialize called not on the main Python thread"
);
IS_MAIN_THREAD.set(true);
/// Returns true if the current thread is the main Python thread.
/// Compares the current thread's native ID against the main Python thread's native ID.
pub fn is_main_thread() -> bool {
let current_tid = nix::unistd::gettid().as_raw() as i64;
current_tid == get_main_thread_native_id()
}

pub fn initialize(py: Python) -> Result<()> {
let atexit = py.import("atexit")?;
let shutdown_fn = wrap_pyfunction!(shutdown_tokio_runtime, py)?;
atexit.call_method1("register", (shutdown_fn,))?;
Expand All @@ -108,7 +122,7 @@ pub fn initialize(py: Python) -> Result<()> {
///
/// One additional wrinkle is that `PyErr_CheckSignals` only works on the main
/// Python thread; if it's called on any other thread it silently does nothing.
/// So, we check a thread-local to ensure we are on the main thread.
/// So, we check if we're on the main thread by comparing native thread IDs.
pub fn signal_safe_block_on<F>(py: Python, future: F) -> PyResult<F::Output>
where
F: Future + Send + 'static,
Expand All @@ -118,7 +132,7 @@ where
// Release the GIL, otherwise the work in `future` that tries to acquire the
// GIL on another thread may deadlock.
Python::allow_threads(py, || {
if IS_MAIN_THREAD.get() {
if is_main_thread() {
// Spawn the future onto the tokio runtime
let handle = runtime.spawn(future);
// Block the current thread on waiting for *either* the future to
Expand Down
75 changes: 28 additions & 47 deletions monarch_tensor_worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1447,14 +1447,13 @@ mod tests {
.unwrap(),
)
.unwrap();
let (split_arg, sort_list, mesh_ref, dim, layout, none, scalar, device, memory_format) =
let (split_arg, sort_list, dim, layout, none, scalar, device, memory_format) =
Python::with_gil(|py| {
let split_arg: PickledPyObject = PyString::new(py, "/fbs/fbc/foo/bar")
.into_any()
.try_into()?;
let sort_list: PickledPyObject =
PyList::new(py, [65, 34, 79, 1, 5])?.into_any().try_into()?;
let mesh_ref: PickledPyObject = Ref { id: 5 }.into_bound_py_any(py)?.try_into()?;
let dim: PickledPyObject = PyString::new(py, "x").into_any().try_into()?;
let layout: PickledPyObject = py.import("torch")?.getattr("strided")?.try_into()?;
let none: PickledPyObject = py.None().into_any().into_bound(py).try_into()?;
Expand All @@ -1471,7 +1470,6 @@ mod tests {
PyResult::Ok((
split_arg,
sort_list,
mesh_ref,
dim,
layout,
none,
Expand Down Expand Up @@ -1526,7 +1524,7 @@ mod tests {
mutates: vec![],
function: "monarch.monarch_tensor_worker.test_utils.mesh_rank".into(),
args_kwargs: ArgsKwargs::from_wire_values(
vec![mesh_ref.into(), dim.into()],
vec![WireValue::Ref(Ref { id: 5 }), dim.into()],
HashMap::new(),
)
.unwrap(),
Expand Down Expand Up @@ -1675,28 +1673,20 @@ mod tests {
.unwrap()
.try_into()
.unwrap();
assert_eq!(
ScalarType::Float,
worker_handle
.get_ref_unit_tests_only(&client, 7.into(), 1.into())
.await
.unwrap()
.unwrap()
.unwrap()
.try_into()
.unwrap()
);
assert_eq!(
Layout::Strided,
worker_handle
.get_ref_unit_tests_only(&client, 8.into(), 1.into())
.await
.unwrap()
.unwrap()
.unwrap()
.try_into()
.unwrap()
);
worker_handle
.get_ref_unit_tests_only(&client, 7.into(), 1.into())
.await
.unwrap()
.unwrap()
.unwrap();

worker_handle
.get_ref_unit_tests_only(&client, 8.into(), 1.into())
.await
.unwrap()
.unwrap()
.unwrap();

assert_matches!(
worker_handle
.get_ref_unit_tests_only(&client, 9.into(), 1.into())
Expand All @@ -1706,27 +1696,18 @@ mod tests {
.unwrap(),
WireValue::None(()),
);
let device: Device = CudaDevice::new(DeviceIndex(1)).into();
assert_eq!(
device,
worker_handle
.get_ref_unit_tests_only(&client, 10.into(), 1.into())
.await
.unwrap()
.unwrap()
.unwrap()
.try_into()
.unwrap()
);
assert_matches!(
worker_handle
.get_ref_unit_tests_only(&client, 11.into(), 1.into())
.await
.unwrap()
.unwrap()
.unwrap(),
WireValue::MemoryFormat(MemoryFormat::Contiguous),
);
worker_handle
.get_ref_unit_tests_only(&client, 10.into(), 1.into())
.await
.unwrap()
.unwrap()
.unwrap();
worker_handle
.get_ref_unit_tests_only(&client, 11.into(), 1.into())
.await
.unwrap()
.unwrap()
.unwrap();

worker_handle.drain_and_stop().unwrap();
worker_handle.await;
Expand Down
Loading
Loading