diff --git a/Cargo.toml b/Cargo.toml index ba5111f..204d9b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ name = "init4-bin-base" description = "Internal utilities for binaries produced by the init4 team" keywords = ["init4", "bin", "base"] -version = "0.4.3" +version = "0.5.0" edition = "2021" rust-version = "1.81" authors = ["init4", "James Prestwich"] diff --git a/examples/oauth.rs b/examples/oauth.rs index 1320190..f538051 100644 --- a/examples/oauth.rs +++ b/examples/oauth.rs @@ -9,7 +9,7 @@ async fn main() -> eyre::Result<()> { let _jh = authenticator.spawn(); tokio::time::sleep(std::time::Duration::from_secs(5)).await; - dbg!(token.read()); + dbg!(token.secret().await.unwrap()); Ok(()) } diff --git a/src/perms/oauth.rs b/src/perms/oauth.rs index c0e9002..654cc38 100644 --- a/src/perms/oauth.rs +++ b/src/perms/oauth.rs @@ -4,13 +4,17 @@ use crate::{ deps::tracing::{error, info}, utils::from_env::FromEnv, }; +use core::fmt; use oauth2::{ basic::{BasicClient, BasicTokenType}, - AuthUrl, ClientId, ClientSecret, EmptyExtraTokenFields, EndpointNotSet, EndpointSet, - HttpClientError, RequestTokenError, StandardErrorResponse, StandardTokenResponse, TokenUrl, + AccessToken, AuthUrl, ClientId, ClientSecret, EmptyExtraTokenFields, EndpointNotSet, + EndpointSet, HttpClientError, RefreshToken, RequestTokenError, Scope, StandardErrorResponse, + StandardTokenResponse, TokenResponse, TokenUrl, +}; +use tokio::{ + sync::watch::{self, Ref}, + task::JoinHandle, }; -use std::sync::{Arc, Mutex}; -use tokio::task::JoinHandle; type Token = StandardTokenResponse; @@ -57,38 +61,17 @@ impl OAuthConfig { } } -/// A shared token that can be read and written to by multiple threads. -#[derive(Debug, Clone, Default)] -pub struct SharedToken(Arc>>); - -impl SharedToken { - /// Read the token from the shared token. - pub fn read(&self) -> Option { - self.0.lock().unwrap().clone() - } - - /// Write a new token to the shared token. - pub fn write(&self, token: Token) { - let mut lock = self.0.lock().unwrap(); - *lock = Some(token); - } - - /// Check if the token is authenticated. - pub fn is_authenticated(&self) -> bool { - self.0.lock().unwrap().is_some() - } -} - /// A self-refreshing, periodically fetching authenticator for the block -/// builder. This task periodically fetches a new token, and stores it in a -/// [`SharedToken`]. +/// builder. This task periodically fetches a new token, and sends it to all +/// active [`SharedToken`]s via a [`tokio::sync::watch`] channel.. #[derive(Debug)] pub struct Authenticator { /// Configuration - pub config: OAuthConfig, + config: OAuthConfig, client: MyOAuthClient, - token: SharedToken, reqwest: reqwest::Client, + + token: watch::Sender>, } impl Authenticator { @@ -99,6 +82,8 @@ impl Authenticator { .set_auth_uri(AuthUrl::from_url(config.oauth_authenticate_url.clone())) .set_token_uri(TokenUrl::from_url(config.oauth_token_url.clone())); + // NB: this is MANDATORY + // https://docs.rs/oauth2/latest/oauth2/#security-warning let rq_client = reqwest::Client::builder() .redirect(reqwest::redirect::Policy::none()) .build() @@ -107,8 +92,8 @@ impl Authenticator { Self { config: config.clone(), client, - token: Default::default(), reqwest: rq_client, + token: watch::channel(None).0, } } @@ -129,20 +114,20 @@ impl Authenticator { /// Returns true if there is Some token set pub fn is_authenticated(&self) -> bool { - self.token.is_authenticated() + self.token.borrow().is_some() } /// Sets the Authenticator's token to the provided value fn set_token(&self, token: StandardTokenResponse) { - self.token.write(token); + self.token.send_replace(Some(token)); } /// Returns the currently set token pub fn token(&self) -> SharedToken { - self.token.clone() + self.token.subscribe().into() } - /// Fetches an oauth token + /// Fetches an oauth token. pub async fn fetch_oauth_token( &self, ) -> Result< @@ -161,25 +146,184 @@ impl Authenticator { Ok(token_result) } - /// Spawns a task that periodically fetches a new token every 300 seconds. - pub fn spawn(self) -> JoinHandle<()> { + /// Get a reference to the OAuth configuration. + pub const fn config(&self) -> &OAuthConfig { + &self.config + } + + /// Create a future that contains the periodic refresh loop. + async fn task_future(self) { let interval = self.config.oauth_token_refresh_interval; - let handle: JoinHandle<()> = tokio::spawn(async move { - loop { - info!("Refreshing oauth token"); - match self.authenticate().await { - Ok(_) => { - info!("Successfully refreshed oauth token"); - } - Err(e) => { - error!(%e, "Failed to refresh oauth token"); - } - }; - let _sleep = tokio::time::sleep(tokio::time::Duration::from_secs(interval)).await; - } - }); - - handle + loop { + info!("Refreshing oauth token"); + match self.authenticate().await { + Ok(_) => { + info!("Successfully refreshed oauth token"); + } + Err(e) => { + error!(%e, "Failed to refresh oauth token"); + } + }; + let _sleep = tokio::time::sleep(tokio::time::Duration::from_secs(interval)).await; + } + } + + /// Spawns a task that periodically fetches a new token. The refresh + /// interval may be configured via the + /// [`OAuthConfig::oauth_token_refresh_interval`] property. + pub fn spawn(self) -> JoinHandle<()> { + tokio::spawn(self.task_future()) + } +} + +/// A shared token, wrapped in a [`tokio::sync::watch`] Receiver. The token is +/// periodically refreshed by an [`Authenticator`] task, and can be awaited +/// for when it becomes available. +/// +/// This allows multiple tasks to wait for the token to be available, and +/// provides a way to check if the token is authenticated without blocking. +/// Please consult the [`Receiver`] documentation for caveats regarding +/// usage. +/// +/// [`Receiver`]: tokio::sync::watch::Receiver +#[derive(Debug, Clone)] +pub struct SharedToken(watch::Receiver>); + +impl From>> for SharedToken { + fn from(inner: watch::Receiver>) -> Self { + Self(inner) + } +} + +impl SharedToken { + /// Wait for the token to be available, and get a reference to the secret. + /// + /// This is implemented using [`Receiver::wait_for`], and has the same + /// blocking, panics, errors, and cancel safety. However, it uses a clone + /// of the [`watch::Receiver`] and will not update the local view of the + /// channel. + /// + /// [`Receiver::wait_for`]: tokio::sync::watch::Receiver::wait_for + pub async fn secret(&self) -> Result { + Ok(self + .clone() + .token() + .await? + .access_token() + .secret() + .to_owned()) + } + + /// Wait for the token to be available, then get a reference to it. + /// + /// Holding this reference will block the background task from updating + /// the token until it is dropped, so it is recommended to drop this + /// reference as soon as possible. + /// + /// This is implemented using [`Receiver::wait_for`], and has the same + /// blocking, panics, errors, and cancel safety. Unlike [`Self::secret`] + /// it is NOT implemented using a clone, and will update the local view of + /// the channel. + /// + /// Generally, prefer using [`Self::secret`] for simple use cases, and + /// this when deeper inspection of the token is required. + /// + /// [`Receiver::wait_for`]: tokio::sync::watch::Receiver::wait_for + pub async fn token(&mut self) -> Result, watch::error::RecvError> { + self.0.wait_for(Option::is_some).await.map(Into::into) + } + + /// Create a future that will resolve when the token is ready. + /// + /// This is implemented using [`Receiver::wait_for`], and has the same + /// blocking, panics, errors, and cancel safety. + /// + /// [`Receiver::wait_for`]: tokio::sync::watch::Receiver::wait_for + pub async fn wait(&self) -> Result<(), watch::error::RecvError> { + self.clone().0.wait_for(Option::is_some).await.map(drop) + } + + /// Borrow the current token, if available. If called before the token is + /// set by the authentication task, this will return `None`. + /// + /// Holding this reference will block the background task from updating + /// the token until it is dropped, so it is recommended to drop this + /// reference as soon as possible. + /// + /// This is implemented using [`Receiver::borrow`]. + /// + /// [`Receiver::borrow`]: tokio::sync::watch::Receiver::borrow + pub fn borrow(&mut self) -> Ref<'_, Option> { + self.0.borrow() + } + + /// Check if the background task has produced an authentication token. + /// + /// Holding this reference will block the background task from updating + /// the token until it is dropped, so it is recommended to drop this + /// reference as soon as possible. + /// + /// This is implemented using [`Receiver::borrow`]. + /// + /// [`Receiver::borrow`]: tokio::sync::watch::Receiver::borrow + pub fn is_authenticated(&self) -> bool { + self.0.borrow().is_some() + } +} + +/// A reference to token data, contained in a [`SharedToken`]. +/// +/// This is implemented using [`watch::Ref`], and as a result holds a lock on +/// the token data. Holding this reference will block the background task +/// from updating the token until it is dropped, so it is recommended to drop +/// this reference as soon as possible. +pub struct TokenRef<'a> { + inner: Ref<'a, Option>, +} + +impl<'a> From>> for TokenRef<'a> { + fn from(inner: Ref<'a, Option>) -> Self { + Self { inner } + } +} + +impl fmt::Debug for TokenRef<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TokenRef").finish_non_exhaustive() + } +} + +impl<'a> TokenRef<'a> { + /// Get a reference to the inner token. + pub fn inner(&'a self) -> &'a Token { + self.inner.as_ref().unwrap() + } + + /// Get a reference to the [`AccessToken`] contained in the token. + pub fn access_token(&self) -> &AccessToken { + self.inner().access_token() + } + + /// Get a reference to the [`TokenType`] instance contained in the token. + /// + /// [`TokenType`]: oauth2::TokenType + pub fn token_type(&self) -> &::TokenType { + self.inner().token_type() + } + + /// Get a reference to the current token's expiration time, if it has one. + pub fn expires_in(&self) -> Option { + self.inner().expires_in() + } + + /// Get a reference to the refresh token, if it exists. + pub fn refresh_token(&self) -> Option<&RefreshToken> { + self.inner().refresh_token() + } + + /// Get a reference to the scopes associated with the token, if any. + pub fn scopes(&self) -> Option<&Vec> { + self.inner().scopes() } } diff --git a/src/perms/tx_cache.rs b/src/perms/tx_cache.rs index 25113bc..b575176 100644 --- a/src/perms/tx_cache.rs +++ b/src/perms/tx_cache.rs @@ -1,6 +1,5 @@ use crate::perms::oauth::SharedToken; -use eyre::{bail, Result}; -use oauth2::TokenResponse; +use eyre::Result; use serde::de::DeserializeOwned; use signet_tx_cache::{ client::TxCache, @@ -53,14 +52,12 @@ impl BuilderTxCache { async fn get_inner_with_token(&self, join: &str) -> Result { let url = self.tx_cache.url().join(join)?; - let Some(token) = self.token.read() else { - bail!("No token available for authentication"); - }; + let secret = self.token.secret().await?; self.tx_cache .client() .get(url) - .bearer_auth(token.access_token().secret()) + .bearer_auth(secret) .send() .await .inspect_err(|e| warn!(%e, "Failed to get object from transaction cache"))?