Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions hyperactor/src/attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@

use std::any::Any;
use std::collections::HashMap;
use std::fmt::Display;
use std::ops::Index;
use std::ops::IndexMut;
use std::str::FromStr;
use std::sync::LazyLock;

use chrono::DateTime;
Expand Down Expand Up @@ -355,6 +357,33 @@ impl AttrValue for std::time::SystemTime {
}
}

impl<T, E> AttrValue for std::ops::Range<T>
where
T: Named
+ Display
+ FromStr<Err = E>
+ Send
+ Sync
+ Serialize
+ DeserializeOwned
+ Clone
+ 'static,
E: Into<anyhow::Error> + Send + Sync + 'static,
{
fn display(&self) -> String {
format!("{}..{}", self.start, self.end)
}

fn parse(value: &str) -> Result<Self, anyhow::Error> {
let (start, end) = value.split_once("..").ok_or_else(|| {
anyhow::anyhow!("expected range in format `start..end`, got `{}`", value)
})?;
let start = start.parse().map_err(|e: E| e.into())?;
let end = end.parse().map_err(|e: E| e.into())?;
Ok(start..end)
}
}

// Internal trait for type-erased serialization
#[doc(hidden)]
pub trait SerializableValue: Send + Sync {
Expand Down
6 changes: 6 additions & 0 deletions hyperactor/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ impl<T: Named + 'static, E: Named + 'static> Named for Result<T, E> {
}
}

/// [`Named`] for ranges: the type name is built from the pattern
/// `std::ops::Range<{}>`, with the element type `T` passed to
/// `intern_typename!` to fill the placeholder.
impl<T: Named + 'static> Named for std::ops::Range<T> {
fn typename() -> &'static str {
intern_typename!(Self, "std::ops::Range<{}>", T)
}
}

static SHAPE_CACHED_TYPEHASH: LazyLock<u64> =
LazyLock::new(|| cityhasher::hash(<ndslice::shape::Shape as Named>::typename()));

Expand Down
209 changes: 209 additions & 0 deletions hyperactor_mesh/src/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,33 @@ pub mod sim;

use std::collections::HashMap;
use std::fmt;
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
use std::net::SocketAddr;
use std::net::TcpListener;
use std::ops::Range;
use std::sync::Mutex;
use std::sync::OnceLock;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;

use async_trait::async_trait;
use enum_as_inner::EnumAsInner;
use hyperactor::ActorRef;
use hyperactor::Named;
use hyperactor::ProcId;
use hyperactor::RemoteMessage;
use hyperactor::WorldId;
use hyperactor::attrs::declare_attrs;
use hyperactor::channel;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelRx;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::MetaTlsAddr;
use hyperactor::config;
use hyperactor::config::CONFIG;
use hyperactor::config::ConfigAttr;
pub use local::LocalAlloc;
pub use local::LocalAllocator;
use mockall::predicate::*;
Expand All @@ -45,6 +63,54 @@ use crate::assign::Ranks;
use crate::proc_mesh::mesh_agent::ProcMeshAgent;
use crate::shortuuid::ShortUuid;

declare_attrs! {
/// For Tcp channel types, if true, bind the IP address to INADDR_ANY
/// (0.0.0.0 or [::]) for frontend ports.
///
/// This config is useful in environments where we cannot bind the port to
/// the given IP address. For example, in an AWS setting, we might not be
/// allowed to bind the port to the host's public IP address.
@meta(CONFIG = ConfigAttr {
env_name: Some("HYPERACTOR_REMOTE_ALLOC_BIND_TO_INADDR_ANY".to_string()),
py_name: None,
})
pub attr REMOTE_ALLOC_BIND_TO_INADDR_ANY: bool = false;

/// Specify the address the alloc uses as its bootstrap address. e.g.:
///
/// * "tcp:142.250.81.228:0" means serve at a random port with IPv4 address
/// 142.250.81.228.
/// * "tcp:[2401:db00:eef0:1120:3520:0:7812:4eca]:27001" means serve at port
/// 27001 with IPv6 address 2401:db00:eef0:1120:3520:0:7812:4eca.
///
/// The IP address must be an IP address of the host running the alloc.
///
/// This config is useful when we want the alloc to use a particular IP
/// address. For example, in an AWS setting, we might want to use the host's
/// public IP address.
// TODO: remove this env var, and make it part of alloc spec instead.
@meta(CONFIG = ConfigAttr {
env_name: Some("HYPERACTOR_REMOTE_ALLOC_BOOTSTRAP_ADDR".to_string()),
py_name: None,
})
pub attr REMOTE_ALLOC_BOOTSTRAP_ADDR: String;

/// For Tcp channel types, if set, only use ports in this range for the
/// frontend ports. The input should be in the format "<start>..<end>",
/// where <end> is exclusive. e.g.:
///
/// * "26601..26611" means only use the 10 ports in the range [26601, 26610],
/// including both 26601 and 26610.
///
/// This config is useful in environments where only a certain range of
/// ports is allowed to be used.
@meta(CONFIG = ConfigAttr {
env_name: Some("HYPERACTOR_REMOTE_ALLOC_ALLOWED_PORT_RANGE".to_string()),
py_name: None,
})
pub attr REMOTE_ALLOC_ALLOWED_PORT_RANGE: Range<u16>;
}

/// Errors that occur during allocation operations.
#[derive(Debug, thiserror::Error)]
pub enum AllocatorError {
Expand Down Expand Up @@ -276,6 +342,11 @@ pub trait Alloc {
fn is_local(&self) -> bool {
// Default answer for implementors that do not override this method.
false
}

/// Returns the address on which the client's router should be served.
///
/// By default this is the "any" address for this alloc's transport,
/// wrapped as an alloc-assigned address.
fn client_router_addr(&self) -> AllocAssignedAddr {
    let transport = self.transport();
    let any_addr = ChannelAddr::any(transport);
    AllocAssignedAddr(any_addr)
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -409,6 +480,144 @@ impl<A: ?Sized + Send + Alloc> AllocExt for A {
}
}

/// A newtype around [`ChannelAddr`] indicating that the address was assigned
/// by an alloc.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AllocAssignedAddr(ChannelAddr);

impl AllocAssignedAddr {
    /// Wraps `addr` as an alloc-assigned address.
    pub(crate) fn new(addr: ChannelAddr) -> AllocAssignedAddr {
        AllocAssignedAddr(addr)
    }

    /// If `addr` is Tcp or MetaTls, use its IP address or hostname to create
    /// a new addr with the port unspecified (0).
    ///
    /// For other types of addr, return the "any" address of its transport.
    pub(crate) fn with_unspecified_port_or_any(addr: &ChannelAddr) -> AllocAssignedAddr {
        let new_addr = match addr {
            ChannelAddr::Tcp(socket) => {
                // `SocketAddr` is `Copy`; copying it and zeroing the port
                // leaves every other part of the address untouched.
                let mut new_socket = *socket;
                new_socket.set_port(0);
                ChannelAddr::Tcp(new_socket)
            }
            ChannelAddr::MetaTls(MetaTlsAddr::Socket(socket)) => {
                let mut new_socket = *socket;
                new_socket.set_port(0);
                ChannelAddr::MetaTls(MetaTlsAddr::Socket(new_socket))
            }
            ChannelAddr::MetaTls(MetaTlsAddr::Host { hostname, port: _ }) => {
                ChannelAddr::MetaTls(MetaTlsAddr::Host {
                    hostname: hostname.clone(),
                    port: 0,
                })
            }
            _ => addr.transport().any(),
        };
        AllocAssignedAddr(new_addr)
    }

    /// Serves a channel on this address, applying the alloc-related global
    /// configuration first.
    ///
    /// For Tcp addresses:
    ///
    /// * if `REMOTE_ALLOC_BIND_TO_INADDR_ANY` is set, the socket is bound to
    ///   the unspecified address of the same IP family, and the returned
    ///   address has the original IP put back (keeping the bound port);
    /// * if the port is 0, it is replaced by the result of
    ///   `next_allowed_port`.
    ///
    /// Other address kinds are served unchanged; a debug message is logged
    /// if INADDR_ANY binding was requested for them.
    ///
    /// # Errors
    ///
    /// Returns an error if `next_allowed_port` or `channel::serve` fails.
    pub(crate) fn serve_with_config<M: RemoteMessage>(
        self,
    ) -> anyhow::Result<(ChannelAddr, ChannelRx<M>)> {
        /// Replaces the IP of `original` with the unspecified address of the
        /// same IP family, leaving the port untouched.
        fn set_as_inaddr_any(original: &mut SocketAddr) {
            let inaddr_any: IpAddr = match *original {
                SocketAddr::V4(_) => Ipv4Addr::UNSPECIFIED.into(),
                SocketAddr::V6(_) => Ipv6Addr::UNSPECIFIED.into(),
            };
            original.set_ip(inaddr_any);
        }

        let use_inaddr_any = config::global::get(REMOTE_ALLOC_BIND_TO_INADDR_ANY);
        let mut bind_to = self.0;
        // Set only when the IP is actually swapped for INADDR_ANY, so that
        // `Some` is exactly the "must restore afterwards" condition.
        let mut original_ip: Option<IpAddr> = None;
        match &mut bind_to {
            ChannelAddr::Tcp(socket) => {
                if use_inaddr_any {
                    let ip = socket.ip();
                    set_as_inaddr_any(socket);
                    tracing::debug!("binding {} to INADDR_ANY", ip,);
                    original_ip = Some(ip);
                }
                if socket.port() == 0 {
                    // Probe with the IP we are going to bind to, which may
                    // be INADDR_ANY at this point.
                    socket.set_port(next_allowed_port(socket.ip())?);
                }
            }
            _ => {
                if use_inaddr_any {
                    tracing::debug!(
                        "can only bind to INADDR_ANY for TCP; got transport {}, addr {}",
                        bind_to.transport(),
                        bind_to
                    );
                }
            }
        };

        let (mut bound, rx) = channel::serve(bind_to)?;

        // Restore the original IP address if we used INADDR_ANY, so that the
        // returned address carries the requested IP rather than the
        // unspecified one.
        if let (ChannelAddr::Tcp(socket), Some(ip)) = (&mut bound, original_ip) {
            socket.set_ip(ip);
        }

        Ok((bound, rx))
    }
}

/// Where [`AllowedPorts::next`] takes its ports from.
enum AllowedPorts {
/// Ports are limited to the entries of `range`. `next` is a running
/// counter; its value modulo `range.len()` selects the next candidate.
Config { range: Vec<u16>, next: AtomicUsize },
/// No restriction: `next` always yields port 0.
Any,
}

impl AllowedPorts {
    /// Returns the next port to use for a socket on `ip`.
    ///
    /// * `Any`: returns 0.
    /// * `Config`: walks the allowed list round-robin, starting where the
    ///   previous call left off, and returns the first port on which a
    ///   probe `TcpListener::bind` on `ip` succeeds. The probe listener is
    ///   dropped before returning, so the caller still has to bind the port
    ///   itself.
    ///
    /// # Errors
    ///
    /// Returns an error if the allowed list is empty, or if every port in
    /// it was tried once without a successful probe bind.
    fn next(&self, ip: IpAddr) -> anyhow::Result<u16> {
        let (range, next) = match self {
            Self::Any => return Ok(0),
            Self::Config { range, next } => (range, next),
        };

        // An empty list (e.g. configured as "5..5") has no candidates, and
        // would otherwise make the modulo below divide by zero.
        if range.is_empty() {
            anyhow::bail!("fail to find a port because the allowed port range is empty");
        }

        // Try each port at most once per call.
        for _ in 0..range.len() {
            let i = next.fetch_add(1, Ordering::Relaxed);
            // Since we do not have a good way to put released ports back
            // into the list, we opportunistically hope that ports taken
            // earlier have been released by now. If not, we just see an
            // error when binding later, which is not much different from
            // raising an error here.
            let port = range[i % range.len()];
            if TcpListener::bind(SocketAddr::new(ip, port)).is_ok() {
                tracing::debug!("taking port {port} from the allowed list",);
                return Ok(port);
            }
        }

        anyhow::bail!(
            "fail to find a port because all ports in the allowed list are already bound"
        );
    }
}

/// Process-wide port source, built on the first call to
/// `next_allowed_port`.
static ALLOWED_PORTS: OnceLock<Mutex<AllowedPorts>> = OnceLock::new();

/// Returns the next port to use for a socket on `ip`.
///
/// The port source is initialized on first use from
/// `REMOTE_ALLOC_ALLOWED_PORT_RANGE`: if that config is set, ports come
/// from the configured range; otherwise any port may be used.
fn next_allowed_port(ip: IpAddr) -> anyhow::Result<u16> {
    let allowed = ALLOWED_PORTS.get_or_init(|| {
        let configured = config::global::try_get_cloned(REMOTE_ALLOC_ALLOWED_PORT_RANGE);
        Mutex::new(match configured {
            None => AllowedPorts::Any,
            Some(range) => AllowedPorts::Config {
                range: range.into_iter().collect(),
                next: AtomicUsize::new(0),
            },
        })
    });
    let guard = allowed.lock().unwrap();
    guard.next(ip)
}

pub mod test_utils {
use std::time::Duration;

Expand Down
Loading