From 177b88aace9441ca5d51b84129de5c6ef82d61e0 Mon Sep 17 00:00:00 2001 From: Ashley Stanton-Nurse Date: Thu, 2 Oct 2025 19:00:52 +0000 Subject: [PATCH 1/8] add container routing map and tests --- Cargo.lock | 44 +++ sdk/cosmos/azure_data_cosmos/Cargo.toml | 2 + sdk/cosmos/azure_data_cosmos/src/cache.rs | 0 sdk/cosmos/azure_data_cosmos/src/lib.rs | 5 + .../azure_data_cosmos/src/models/mod.rs | 2 + sdk/cosmos/azure_data_cosmos/src/routing.rs | 320 ++++++++++++++++++ sdk/cosmos/azure_data_cosmos/src/types.rs | 41 +++ 7 files changed, 414 insertions(+) create mode 100644 sdk/cosmos/azure_data_cosmos/src/cache.rs create mode 100644 sdk/cosmos/azure_data_cosmos/src/routing.rs create mode 100644 sdk/cosmos/azure_data_cosmos/src/types.rs diff --git a/Cargo.lock b/Cargo.lock index 289136ca88..a4b5736e15 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -321,6 +321,8 @@ dependencies = [ "azure_identity", "clap", "futures", + "moka", + "pin-project", "reqwest", "serde", "serde_json", @@ -884,6 +886,15 @@ dependencies = [ "itertools", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -1792,6 +1803,27 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "moka" +version = "0.12.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8261cd88c312e0004c1d51baad2980c66528dfdb2bee62003e643a4d8f86b077" +dependencies = [ + "async-lock", + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "equivalent", + "event-listener", + "futures-util", + "parking_lot", + "portable-atomic", + "rustc_version", + "smallvec", + "tagptr", + "uuid", +] + [[package]] name = "native-tls" version = "0.2.14" @@ -2092,6 +2124,12 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + [[package]] name = "potential_utf" version = "0.1.2" @@ -2789,6 +2827,12 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tap" version = "1.0.1" diff --git a/sdk/cosmos/azure_data_cosmos/Cargo.toml b/sdk/cosmos/azure_data_cosmos/Cargo.toml index 22dce2f890..12d137e488 100644 --- a/sdk/cosmos/azure_data_cosmos/Cargo.toml +++ b/sdk/cosmos/azure_data_cosmos/Cargo.toml @@ -17,11 +17,13 @@ categories = ["api-bindings"] async-trait.workspace = true azure_core.workspace = true futures.workspace = true +pin-project.workspace = true serde_json.workspace = true serde.workspace = true tracing.workspace = true typespec_client_core = { workspace = true, features = ["derive"] } url.workspace = true +moka = { version = "0.12.11", features = ["future"] } [dev-dependencies] azure_identity.workspace = true diff --git a/sdk/cosmos/azure_data_cosmos/src/cache.rs b/sdk/cosmos/azure_data_cosmos/src/cache.rs new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sdk/cosmos/azure_data_cosmos/src/lib.rs b/sdk/cosmos/azure_data_cosmos/src/lib.rs index f5caf27e35..b0ceef2f0e 100644 --- a/sdk/cosmos/azure_data_cosmos/src/lib.rs +++ b/sdk/cosmos/azure_data_cosmos/src/lib.rs @@ -19,11 +19,16 @@ mod partition_key; pub(crate) mod pipeline; pub mod query; pub(crate) mod resource_context; +pub(crate) mod routing; pub(crate) mod utils; pub mod models; +mod cache; mod location_cache; +mod types; + +pub(crate) use types::*; #[doc(inline)] pub use clients::CosmosClient; diff --git a/sdk/cosmos/azure_data_cosmos/src/models/mod.rs b/sdk/cosmos/azure_data_cosmos/src/models/mod.rs index dbd3355383..92421a39d5 100644 --- a/sdk/cosmos/azure_data_cosmos/src/models/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/src/models/mod.rs @@ -18,6 +18,8 @@ pub use partition_key_definition::*; pub use patch_operations::*; pub use throughput_properties::*; +use crate::types::{EffectivePartitionKey, PartitionKeyRangeId}; + fn deserialize_cosmos_timestamp<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, diff --git a/sdk/cosmos/azure_data_cosmos/src/routing.rs b/sdk/cosmos/azure_data_cosmos/src/routing.rs new file mode 100644 index 0000000000..d4a9372604 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos/src/routing.rs @@ -0,0 +1,320 @@ +use std::ops::Range; + +use crate::types::{EffectivePartitionKey, PartitionKeyRangeId}; +use serde::{Deserialize, Serialize}; + +/// Represents a partition key range within a Cosmos DB collection. +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)] +pub(crate) struct PartitionKeyRange { + // The ID of the partition key range. + pub id: PartitionKeyRangeId, + + /// The minimum inclusive value of the partition key range. + pub min_inclusive: EffectivePartitionKey, + + /// The maximum exclusive value of the partition key range. + pub max_exclusive: EffectivePartitionKey, + + /// The parents of the partition key range, if they still exist. + /// + /// During a split or merge, this will contain the ID of the partition key range(s) that were split or merged to create this range. + #[serde(default)] + pub parents: Vec, +} + +impl PartitionKeyRange { + /// Checks if the given effective partition key falls within this partition key range. + pub fn contains(&self, epk: &EffectivePartitionKey) -> bool { + self.min_inclusive <= *epk && *epk < self.max_exclusive + } + + pub fn overlaps(&self, range: &Range) -> bool { + // Empty ranges don't overlap with anything + if range.start >= range.end { + return false; + } + !(self.max_exclusive <= range.start || self.min_inclusive >= range.end) + } +} + +#[derive(Debug)] +pub struct ContainerRoutingMap { + /// The list of partition key ranges for the container, sorted by their minimum inclusive value. + pk_ranges: Vec, +} + +impl ContainerRoutingMap { + /// Creates a new `ContainerRoutingMap` from the provided partition key ranges. + pub fn new(mut pk_ranges: Vec) -> Self { + normalize_ranges(&mut pk_ranges); + Self { pk_ranges } + } + + /// Replaces the partition key ranges with a new set, normalizing them in the process. + pub fn with_ranges(mut self, mut pk_ranges: Vec) -> Self { + normalize_ranges(&mut pk_ranges); + self.pk_ranges = pk_ranges; + self + } + + /// Gets the list of partition key ranges for the container, sorted by their minimum inclusive value. + pub fn ranges(&self) -> &[PartitionKeyRange] { + &self.pk_ranges + } + + pub fn range_containing(&self, epk: &EffectivePartitionKey) -> Option<&PartitionKeyRange> { + // TODO: Could be optimized with binary search if needed + self.pk_ranges.iter().find(|pkr| pkr.contains(epk)) + } + + pub fn range(&self, id: &PartitionKeyRangeId) -> Option<&PartitionKeyRange> { + self.pk_ranges.iter().find(|pkr| &pkr.id == id) + } + + pub fn overlapping_ranges( + &self, + ranges: &[Range], + ) -> Vec<&PartitionKeyRange> { + // TODO: Could be optimized further later + self.pk_ranges + .iter() + .filter(|pkr| ranges.iter().any(|r| pkr.overlaps(r))) + .collect() + } +} + +/// Discards any [`PartitionKeyRange`] that is the parent of another range, leaving only the leaf ranges. +fn normalize_ranges(pk_ranges: &mut Vec) { + let parent_ids: std::collections::HashSet<_> = pk_ranges + .iter() + .flat_map(|pkr| pkr.parents.iter().cloned()) + .collect(); + pk_ranges.retain(|pkr| !parent_ids.contains(&pkr.id)); + pk_ranges.sort_by(|a, b| a.min_inclusive.cmp(&b.min_inclusive)); +} + +#[cfg(test)] +mod tests { + use super::*; + + fn create_range(id: &str, min: &str, max: &str, parents: Vec<&str>) -> PartitionKeyRange { + PartitionKeyRange { + id: id.into(), + min_inclusive: min.into(), + max_exclusive: max.into(), + parents: parents.into_iter().map(|p| p.into()).collect(), + } + } + + #[test] + fn partition_key_range_contains() { + let range = create_range("1", "00", "33", vec![]); + + assert!(range.contains(&"00".into())); + assert!(range.contains(&"15".into())); + assert!(range.contains(&"32".into())); + assert!(!range.contains(&"33".into())); + assert!(!range.contains(&"34".into())); + assert!(!range.contains(&"FF".into())); + } + + #[test] + fn partition_key_range_overlaps() { + let range = create_range("1", "20", "60", vec![]); + + assert!(range.overlaps(&("30".into().."40".into()))); + assert!(range.overlaps(&("10".into().."30".into()))); + assert!(range.overlaps(&("50".into().."70".into()))); + assert!(range.overlaps(&("10".into().."70".into()))); + assert!(range.overlaps(&("20".into().."60".into()))); + assert!(!range.overlaps(&("00".into().."20".into()))); + assert!(!range.overlaps(&("60".into().."80".into()))); + assert!(!range.overlaps(&("30".into().."30".into()))); + } + + #[test] + fn container_routing_map_construction() { + let ranges = vec![ + create_range("2", "33", "66", vec![]), + create_range("1", "00", "33", vec![]), + create_range("3", "66", "FF", vec![]), + ]; + + let map = ContainerRoutingMap::new(ranges); + assert_eq!(map.ranges().len(), 3); + assert_eq!(map.ranges()[0].id.value(), "1"); + assert_eq!(map.ranges()[1].id.value(), "2"); + assert_eq!(map.ranges()[2].id.value(), "3"); + + let new_ranges = vec![ + create_range("2", "33", "66", vec![]), + create_range("1", "00", "33", vec![]), + ]; + let map = map.with_ranges(new_ranges); + assert_eq!(map.ranges().len(), 2); + assert_eq!(map.ranges()[0].id.value(), "1"); + assert_eq!(map.ranges()[1].id.value(), "2"); + } + + #[test] + fn container_routing_map_range_containing() { + let ranges = vec![ + create_range("1", "00", "33", vec![]), + create_range("2", "33", "66", vec![]), + create_range("3", "66", "FF", vec![]), + ]; + + let map = ContainerRoutingMap::new(ranges); + + assert_eq!(map.range_containing(&"00".into()).unwrap().id.value(), "1"); + assert_eq!(map.range_containing(&"32".into()).unwrap().id.value(), "1"); + assert_eq!(map.range_containing(&"33".into()).unwrap().id.value(), "2"); + assert_eq!(map.range_containing(&"50".into()).unwrap().id.value(), "2"); + assert_eq!(map.range_containing(&"66".into()).unwrap().id.value(), "3"); + assert_eq!(map.range_containing(&"AA".into()).unwrap().id.value(), "3"); + + let single_range_map = + ContainerRoutingMap::new(vec![create_range("1", "33", "66", vec![])]); + assert!(single_range_map.range_containing(&"00".into()).is_none()); + assert!(single_range_map.range_containing(&"99".into()).is_none()); + } + + #[test] + fn container_routing_map_range_by_id() { + let ranges = vec![ + create_range("1", "00", "33", vec![]), + create_range("2", "33", "66", vec![]), + create_range("3", "66", "FF", vec![]), + ]; + + let map = ContainerRoutingMap::new(ranges); + + assert_eq!(map.range(&"1".into()).unwrap().min_inclusive.value(), "00"); + assert_eq!(map.range(&"2".into()).unwrap().min_inclusive.value(), "33"); + assert_eq!(map.range(&"3".into()).unwrap().min_inclusive.value(), "66"); + assert!(map.range(&"99".into()).is_none()); + } + + #[test] + fn container_routing_map_overlapping_ranges() { + let ranges = vec![ + create_range("1", "00", "33", vec![]), + create_range("2", "33", "66", vec![]), + create_range("3", "66", "FF", vec![]), + ]; + + let map = ContainerRoutingMap::new(ranges); + + let overlapping = map.overlapping_ranges(&["00".into().."20".into()]); + assert_eq!(overlapping.len(), 1); + assert_eq!(overlapping[0].id.value(), "1"); + + let overlapping = map.overlapping_ranges(&["20".into().."50".into()]); + assert_eq!(overlapping.len(), 2); + assert_eq!(overlapping[0].id.value(), "1"); + assert_eq!(overlapping[1].id.value(), "2"); + + let overlapping = map.overlapping_ranges(&["00".into().."FF".into()]); + assert_eq!(overlapping.len(), 3); + + let overlapping = map.overlapping_ranges(&[]); + assert_eq!(overlapping.len(), 0); + + let overlapping = + map.overlapping_ranges(&["10".into().."20".into(), "70".into().."80".into()]); + assert_eq!(overlapping.len(), 2); + assert_eq!(overlapping[0].id.value(), "1"); + assert_eq!(overlapping[1].id.value(), "3"); + } + + #[test] + fn normalize_ranges_removes_parents() { + let mut ranges = vec![ + create_range("parent1", "00", "50", vec![]), + create_range("parent2", "50", "FF", vec![]), + create_range("child1", "00", "25", vec!["parent1"]), + create_range("child2", "25", "50", vec!["parent1"]), + create_range("child3", "50", "FF", vec!["parent2"]), + ]; + + normalize_ranges(&mut ranges); + + assert_eq!(ranges.len(), 3); + assert_eq!(ranges[0].id.value(), "child1"); + assert_eq!(ranges[1].id.value(), "child2"); + assert_eq!(ranges[2].id.value(), "child3"); + assert!(ranges[0].min_inclusive <= ranges[1].min_inclusive); + assert!(ranges[1].min_inclusive <= ranges[2].min_inclusive); + } + + #[test] + fn normalize_ranges_sorts_by_min_inclusive() { + let mut ranges = vec![ + create_range("3", "66", "99", vec![]), + create_range("1", "00", "33", vec![]), + create_range("2", "33", "66", vec![]), + ]; + + normalize_ranges(&mut ranges); + assert_eq!(ranges[0].id.value(), "1"); + assert_eq!(ranges[1].id.value(), "2"); + assert_eq!(ranges[2].id.value(), "3"); + } + + #[test] + fn normalize_ranges_handles_empty_list() { + let mut ranges: Vec = vec![]; + normalize_ranges(&mut ranges); + assert_eq!(ranges.len(), 0); + } + + #[test] + fn normalize_ranges_handles_no_parents() { + let mut ranges = vec![ + create_range("2", "33", "66", vec![]), + create_range("1", "00", "33", vec![]), + ]; + + normalize_ranges(&mut ranges); + assert_eq!(ranges.len(), 2); + assert_eq!(ranges[0].id.value(), "1"); + assert_eq!(ranges[1].id.value(), "2"); + } + + #[test] + fn container_routing_map_full_test() { + let ranges = vec![ + create_range("parent1", "00", "80", vec![]), + create_range("parent2", "80", "FF", vec![]), + create_range("child1", "00", "40", vec!["parent1"]), + create_range("child2", "40", "80", vec!["parent1"]), + create_range("child3", "80", "C0", vec!["parent2"]), + create_range("child4", "C0", "FF", vec!["parent2"]), + ]; + + let map = ContainerRoutingMap::new(ranges); + + assert_eq!(map.ranges().len(), 4); + assert_eq!( + map.range_containing(&"10".into()).unwrap().id.value(), + "child1" + ); + assert_eq!( + map.range_containing(&"50".into()).unwrap().id.value(), + "child2" + ); + assert_eq!( + map.range_containing(&"90".into()).unwrap().id.value(), + "child3" + ); + assert_eq!( + map.range_containing(&"D0".into()).unwrap().id.value(), + "child4" + ); + + let overlapping = map.overlapping_ranges(&["30".into().."70".into()]); + assert_eq!(overlapping.len(), 2); + assert_eq!(overlapping[0].id.value(), "child1"); + assert_eq!(overlapping[1].id.value(), "child2"); + } +} diff --git a/sdk/cosmos/azure_data_cosmos/src/types.rs b/sdk/cosmos/azure_data_cosmos/src/types.rs new file mode 100644 index 0000000000..08db8ba505 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos/src/types.rs @@ -0,0 +1,41 @@ +//! Internal module to define several newtypes used in the SDK. + +macro_rules! string_newtype { + ($name:ident) => { + #[derive(serde::Deserialize, serde::Serialize, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] + #[serde(transparent)] + pub struct $name(String); + + impl $name { + pub fn new(value: String) -> Self { + Self(value) + } + + pub fn value(&self) -> &str { + &self.0 + } + } + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } + } + + impl From<&str> for $name { + fn from(s: &str) -> Self { + Self(s.to_string()) + } + } + + impl From for $name { + fn from(s: String) -> Self { + Self(s) + } + } + }; +} + +string_newtype!(ResourceId); +string_newtype!(PartitionKeyRangeId); +string_newtype!(EffectivePartitionKey); From 1165fc51bf33339b977a1a873f5d85fc5c42fd67 Mon Sep 17 00:00:00 2001 From: Ashley Stanton-Nurse Date: Thu, 2 Oct 2025 20:22:48 +0000 Subject: [PATCH 2/8] start on CMC --- sdk/cosmos/azure_data_cosmos/src/cache.rs | 42 +++++++++++++++++++++ sdk/cosmos/azure_data_cosmos/src/routing.rs | 23 ++++++++++- 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure_data_cosmos/src/cache.rs b/sdk/cosmos/azure_data_cosmos/src/cache.rs index e69de29bb2..0da0a21cae 100644 --- a/sdk/cosmos/azure_data_cosmos/src/cache.rs +++ b/sdk/cosmos/azure_data_cosmos/src/cache.rs @@ -0,0 +1,42 @@ +use std::sync::Arc; + +use moka::future::Cache; + +use crate::{ + models::{ContainerProperties, DatabaseProperties}, + routing::ContainerRoutingMap, + ResourceId, +}; + +pub struct ContainerMetadataCache { + /// Caches a mapping from container ID (the "name") to container properties, including the RID. + container_properties_cache: Cache>, + + /// Caches a mapping from database ID (the "name") to database properties, including the RID. + database_properties_cache: Cache>, + + /// Caches container routing information, mapping from container RID to routing info. + routing_map_cache: Cache>, +} + +// TODO: Review this value. +// Cosmos has a backend limit of 500 databases and containers per account by default. +// This value affects when Moka will start evicting entries from the cache. +// It could probably be much lower without much impact, but we need to do the research to be sure. +const MAX_CACHE_CAPACITY: u64 = 500; + +impl ContainerMetadataCache { + /// Creates a new `ContainerMetadataCache` with default settings. + /// + /// Since the cache is designed to be shared, it is returned inside an `Arc`. + pub fn new() -> Arc { + let container_properties_cache = Cache::new(MAX_CACHE_CAPACITY); + let database_properties_cache = Cache::new(MAX_CACHE_CAPACITY); + let routing_map_cache = Cache::new(MAX_CACHE_CAPACITY); + Arc::new(Self { + container_properties_cache, + database_properties_cache, + routing_map_cache, + }) + } +} diff --git a/sdk/cosmos/azure_data_cosmos/src/routing.rs b/sdk/cosmos/azure_data_cosmos/src/routing.rs index d4a9372604..2b2a4e9623 100644 --- a/sdk/cosmos/azure_data_cosmos/src/routing.rs +++ b/sdk/cosmos/azure_data_cosmos/src/routing.rs @@ -28,6 +28,17 @@ impl PartitionKeyRange { self.min_inclusive <= *epk && *epk < self.max_exclusive } + /// Compares the given effective partition key to this partition key range, returning an [`Ordering`](std::cmp::Ordering) that can be used for searching a sorted list of ranges for the range containing the key. + pub fn compare_to(&self, epk: &EffectivePartitionKey) -> std::cmp::Ordering { + if self.contains(epk) { + std::cmp::Ordering::Equal + } else if self.min_inclusive > *epk { + std::cmp::Ordering::Greater + } else { + std::cmp::Ordering::Less + } + } + pub fn overlaps(&self, range: &Range) -> bool { // Empty ranges don't overlap with anything if range.start >= range.end { @@ -62,12 +73,19 @@ impl ContainerRoutingMap { &self.pk_ranges } + /// Finds the partition key range that contains the given effective partition key, if any. pub fn range_containing(&self, epk: &EffectivePartitionKey) -> Option<&PartitionKeyRange> { - // TODO: Could be optimized with binary search if needed - self.pk_ranges.iter().find(|pkr| pkr.contains(epk)) + // It's critical that pk_ranges is sorted by min_inclusive for this to work correctly, so assert that in debug builds + debug_assert!(self.pk_ranges.is_sorted_by_key(|pkr| &pkr.min_inclusive)); + self.pk_ranges + .binary_search_by(|pkr| pkr.compare_to(epk)) + .ok() + .map(|idx| &self.pk_ranges[idx]) } + /// Finds the partition key range with the given ID, if any. pub fn range(&self, id: &PartitionKeyRangeId) -> Option<&PartitionKeyRange> { + // We don't know that IDs are sorted, so just do a linear search self.pk_ranges.iter().find(|pkr| &pkr.id == id) } @@ -84,6 +102,7 @@ impl ContainerRoutingMap { } /// Discards any [`PartitionKeyRange`] that is the parent of another range, leaving only the leaf ranges. +/// Also sorts the ranges by their minimum inclusive value. fn normalize_ranges(pk_ranges: &mut Vec) { let parent_ids: std::collections::HashSet<_> = pk_ranges .iter() From 595692ff7c567ce81cce12ee0ef5856d4be9c0c4 Mon Sep 17 00:00:00 2001 From: Ashley Stanton-Nurse Date: Thu, 2 Oct 2025 22:22:03 +0000 Subject: [PATCH 3/8] rename CosmosPipeline to CosmosConnection --- .../src/{pipeline => connection}/authorization_policy.rs | 0 .../azure_data_cosmos/src/{pipeline => connection}/mod.rs | 6 +++--- .../src/{pipeline => connection}/signature_target.rs | 0 3 files changed, 3 insertions(+), 3 deletions(-) rename sdk/cosmos/azure_data_cosmos/src/{pipeline => connection}/authorization_policy.rs (100%) rename sdk/cosmos/azure_data_cosmos/src/{pipeline => connection}/mod.rs (98%) rename sdk/cosmos/azure_data_cosmos/src/{pipeline => connection}/signature_target.rs (100%) diff --git a/sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs b/sdk/cosmos/azure_data_cosmos/src/connection/authorization_policy.rs similarity index 100% rename from sdk/cosmos/azure_data_cosmos/src/pipeline/authorization_policy.rs rename to sdk/cosmos/azure_data_cosmos/src/connection/authorization_policy.rs diff --git a/sdk/cosmos/azure_data_cosmos/src/pipeline/mod.rs b/sdk/cosmos/azure_data_cosmos/src/connection/mod.rs similarity index 98% rename from sdk/cosmos/azure_data_cosmos/src/pipeline/mod.rs rename to sdk/cosmos/azure_data_cosmos/src/connection/mod.rs index 0f84843a30..fc3f2a17d6 100644 --- a/sdk/cosmos/azure_data_cosmos/src/pipeline/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/src/connection/mod.rs @@ -26,18 +26,18 @@ use crate::{ /// Newtype that wraps an Azure Core pipeline to provide a Cosmos-specific pipeline which configures our authorization policy and enforces that a [`ResourceType`] is set on the context. #[derive(Debug, Clone)] -pub struct CosmosPipeline { +pub struct CosmosConnection { pub endpoint: Url, pipeline: azure_core::http::Pipeline, } -impl CosmosPipeline { +impl CosmosConnection { pub fn new( endpoint: Url, auth_policy: AuthorizationPolicy, client_options: ClientOptions, ) -> Self { - CosmosPipeline { + CosmosConnection { endpoint, pipeline: azure_core::http::Pipeline::new( option_env!("CARGO_PKG_NAME"), diff --git a/sdk/cosmos/azure_data_cosmos/src/pipeline/signature_target.rs b/sdk/cosmos/azure_data_cosmos/src/connection/signature_target.rs similarity index 100% rename from sdk/cosmos/azure_data_cosmos/src/pipeline/signature_target.rs rename to sdk/cosmos/azure_data_cosmos/src/connection/signature_target.rs From 30fea5be6fa62706d721623811d2d78626f1ef7f Mon Sep 17 00:00:00 2001 From: Ashley Stanton-Nurse Date: Fri, 3 Oct 2025 21:20:28 +0000 Subject: [PATCH 4/8] container metadata cache, and caching for read_throughput --- sdk/cosmos/azure_data_cosmos/src/cache.rs | 137 ++++++- .../src/clients/container_client.rs | 115 +++--- .../src/clients/cosmos_client.rs | 25 +- .../src/clients/database_client.rs | 33 +- .../src/connection/authorization_policy.rs | 4 +- .../azure_data_cosmos/src/connection/mod.rs | 16 +- .../src/connection/signature_target.rs | 2 +- sdk/cosmos/azure_data_cosmos/src/lib.rs | 17 +- .../azure_data_cosmos/src/models/mod.rs | 4 +- .../azure_data_cosmos/src/query/executor.rs | 26 +- .../azure_data_cosmos/src/resource_context.rs | 4 +- sdk/cosmos/azure_data_cosmos/src/routing.rs | 339 ------------------ sdk/cosmos/azure_data_cosmos/src/types.rs | 2 - .../tests/cosmos_containers.rs | 73 +++- .../tests/framework/local_recorder.rs | 70 ++++ .../azure_data_cosmos/tests/framework/mod.rs | 4 +- .../tests/framework/test_account.rs | 10 + 17 files changed, 413 insertions(+), 468 deletions(-) delete mode 100644 sdk/cosmos/azure_data_cosmos/src/routing.rs create mode 100644 sdk/cosmos/azure_data_cosmos/tests/framework/local_recorder.rs diff --git a/sdk/cosmos/azure_data_cosmos/src/cache.rs b/sdk/cosmos/azure_data_cosmos/src/cache.rs index 0da0a21cae..afd72e86ba 100644 --- a/sdk/cosmos/azure_data_cosmos/src/cache.rs +++ b/sdk/cosmos/azure_data_cosmos/src/cache.rs @@ -3,20 +3,100 @@ use std::sync::Arc; use moka::future::Cache; use crate::{ - models::{ContainerProperties, DatabaseProperties}, - routing::ContainerRoutingMap, + models::{ContainerProperties, PartitionKeyDefinition}, + resource_context::ResourceLink, ResourceId, }; -pub struct ContainerMetadataCache { - /// Caches a mapping from container ID (the "name") to container properties, including the RID. - container_properties_cache: Cache>, +#[derive(Debug)] +pub enum CacheError { + FetchError(Arc), +} + +impl From> for CacheError { + fn from(e: Arc) -> Self { + CacheError::FetchError(e) + } +} + +impl From for azure_core::Error { + fn from(e: CacheError) -> Self { + match e { + CacheError::FetchError(e) => { + let message = format!("error updating Container Metadata Cache: {}", e); + azure_core::Error::with_error(azure_core::error::ErrorKind::Other, e, message) + } + } + } +} + +impl std::fmt::Display for CacheError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CacheError::FetchError(e) => write!(f, "error fetching latest value: {}", e), + } + } +} + +impl std::error::Error for CacheError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + CacheError::FetchError(e) => Some(&**e), + } + } +} - /// Caches a mapping from database ID (the "name") to database properties, including the RID. - database_properties_cache: Cache>, +/// A subset of container properties that are stable and suitable for caching. +pub(crate) struct ContainerMetadata { + pub self_link: String, + pub resource_id: ResourceId, + pub partition_key: PartitionKeyDefinition, + pub container_link: ResourceLink, +} - /// Caches container routing information, mapping from container RID to routing info. - routing_map_cache: Cache>, +impl ContainerMetadata { + // We can't use From because we also want the container link. + pub fn from_properties( + properties: &ContainerProperties, + container_link: ResourceLink, + ) -> azure_core::Result { + let self_link = properties + .system_properties + .self_link + .as_ref() + .ok_or_else(|| { + azure_core::Error::new( + azure_core::error::ErrorKind::Other, + "container properties is missing expected value 'self_link'", + ) + })? + .clone(); + let resource_id = properties + .system_properties + .resource_id + .clone() + .ok_or_else(|| { + azure_core::Error::new( + azure_core::error::ErrorKind::Other, + "container properties is missing expected value 'resource_id'", + ) + })?; + Ok(Self { + self_link, + resource_id, + partition_key: properties.partition_key.clone(), + container_link, + }) + } +} + +/// A cache for container metadata, including properties and routing information. +/// +/// The cache can be cloned cheaply, and all clones share the same underlying cache data. +#[derive(Clone)] +pub struct ContainerMetadataCache { + /// Caches stable container metadata, mapping from container link and RID to metadata. + container_properties_cache: Cache>, } // TODO: Review this value. @@ -29,14 +109,39 @@ impl ContainerMetadataCache { /// Creates a new `ContainerMetadataCache` with default settings. /// /// Since the cache is designed to be shared, it is returned inside an `Arc`. - pub fn new() -> Arc { + pub fn new() -> Self { let container_properties_cache = Cache::new(MAX_CACHE_CAPACITY); - let database_properties_cache = Cache::new(MAX_CACHE_CAPACITY); - let routing_map_cache = Cache::new(MAX_CACHE_CAPACITY); - Arc::new(Self { + Self { container_properties_cache, - database_properties_cache, - routing_map_cache, - }) + } + } + + /// Unconditionally updates the cache with the provided container metadata. + pub async fn set_container_metadata(&self, metadata: ContainerMetadata) { + let metadata = Arc::new(metadata); + + self.container_properties_cache + .insert(metadata.container_link.clone(), metadata) + .await; + } + + /// Gets the container metadata from the cache, or initializes it using the provided async function if not present. + pub async fn get_container_metadata( + &self, + key: &ResourceLink, + init: impl std::future::Future>, + ) -> Result, CacheError> { + // TODO: Background refresh. We can do background refresh by storing an expiry time in the cache entry. + // Then, if the entry is stale, we can return the stale entry and spawn a background task to refresh it. + // There's a little trickiness here in that + Ok(self + .container_properties_cache + .try_get_with_by_ref(key, async { init.await.map(Arc::new) }) + .await?) + } + + /// Clears the cached container metadata for the specified key, so that the next request will fetch fresh data. + pub async fn clear_container_metadata(&self, key: &ResourceLink) { + self.container_properties_cache.invalidate(key).await; } } diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs index 7d28ee80f2..e6f249bcc2 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs @@ -1,11 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +use std::sync::Arc; + use crate::{ + cache::ContainerMetadata, + connection::CosmosConnection, constants, models::{ContainerProperties, PatchDocument, ThroughputProperties}, options::{QueryOptions, ReadContainerOptions}, - pipeline::CosmosPipeline, resource_context::{ResourceLink, ResourceType}, DeleteContainerOptions, FeedPager, ItemOptions, PartitionKey, Query, ReplaceContainerOptions, ThroughputOptions, @@ -14,7 +17,7 @@ use crate::{ use azure_core::http::{ request::{options::ContentType, Request}, response::Response, - Method, + Method, RawResponse, }; use serde::{de::DeserializeOwned, Serialize}; @@ -23,14 +26,15 @@ use serde::{de::DeserializeOwned, Serialize}; /// You can get a `Container` by calling [`DatabaseClient::container_client()`](crate::clients::DatabaseClient::container_client()). #[derive(Clone)] pub struct ContainerClient { + id: String, link: ResourceLink, items_link: ResourceLink, - pipeline: CosmosPipeline, + connection: CosmosConnection, } impl ContainerClient { pub(crate) fn new( - pipeline: CosmosPipeline, + connection: CosmosConnection, database_link: &ResourceLink, container_id: &str, ) -> Self { @@ -40,9 +44,10 @@ impl ContainerClient { let items_link = link.feed(ResourceType::Items); Self { + id: container_id.to_string(), link, items_link, - pipeline, + connection, } } @@ -66,12 +71,15 @@ impl ContainerClient { &self, options: Option>, ) -> azure_core::Result> { - let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.link); - let mut req = Request::new(url, Method::Get); - self.pipeline - .send(options.method_options.context, &mut req, self.link.clone()) - .await + let response: RawResponse = self.read_properties(options).await?.into(); + + // Read the properties and cache the stable metadata (things that don't change for the life of a container) + // TODO: Replace with `response.body().json()` when that becomes borrowing. + let properties = serde_json::from_slice::(response.body())?; + let metadata = ContainerMetadata::from_properties(&properties, self.link.clone())?; + self.connection.cache.set_container_metadata(metadata).await; + + Ok(response.into()) } /// Updates the indexing policy of the container. @@ -110,11 +118,11 @@ impl ContainerClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.link); + let url = self.connection.url(&self.link); let mut req = Request::new(url, Method::Put); req.insert_headers(&ContentType::APPLICATION_JSON)?; req.set_json(&properties)?; - self.pipeline + self.connection .send(options.method_options.context, &mut req, self.link.clone()) .await } @@ -130,16 +138,9 @@ impl ContainerClient { options: Option>, ) -> azure_core::Result>> { let options = options.unwrap_or_default(); - - // We need to get the RID for the database. - let db = self.read(None).await?.into_body()?; - let resource_id = db - .system_properties - .resource_id - .expect("service should always return a '_rid' for a container"); - - self.pipeline - .read_throughput_offer(options.method_options.context, &resource_id) + let resource_id = &self.metadata().await?.resource_id; + self.connection + .read_throughput_offer(options.method_options.context, resource_id) .await } @@ -156,14 +157,9 @@ impl ContainerClient { let options = options.unwrap_or_default(); // We need to get the RID for the database. - let db = self.read(None).await?.into_body()?; - let resource_id = db - .system_properties - .resource_id - .expect("service should always return a '_rid' for a container"); - - self.pipeline - .replace_throughput_offer(options.method_options.context, &resource_id, throughput) + let resource_id = &self.metadata().await?.resource_id; + self.connection + .replace_throughput_offer(options.method_options.context, resource_id, throughput) .await } @@ -178,9 +174,9 @@ impl ContainerClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.link); + let url = self.connection.url(&self.link); let mut req = Request::new(url, Method::Delete); - self.pipeline + self.connection .send(options.method_options.context, &mut req, self.link.clone()) .await } @@ -257,13 +253,13 @@ impl ContainerClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.items_link); + let url = self.connection.url(&self.items_link); let mut req = Request::new(url, Method::Post); req.insert_headers(&options)?; req.insert_headers(&partition_key.into())?; req.insert_headers(&ContentType::APPLICATION_JSON)?; req.set_json(&item)?; - self.pipeline + self.connection .send( options.method_options.context, &mut req, @@ -346,13 +342,13 @@ impl ContainerClient { ) -> azure_core::Result> { let options = options.unwrap_or_default(); let link = self.items_link.item(item_id); - let url = self.pipeline.url(&link); + let url = self.connection.url(&link); let mut req = Request::new(url, Method::Put); req.insert_headers(&options)?; req.insert_headers(&partition_key.into())?; req.insert_headers(&ContentType::APPLICATION_JSON)?; req.set_json(&item)?; - self.pipeline + self.connection .send(options.method_options.context, &mut req, link) .await } @@ -432,14 +428,14 @@ impl ContainerClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.items_link); + let url = self.connection.url(&self.items_link); let mut req = Request::new(url, Method::Post); req.insert_headers(&options)?; req.insert_header(constants::IS_UPSERT, "true"); req.insert_headers(&partition_key.into())?; req.insert_headers(&ContentType::APPLICATION_JSON)?; req.set_json(&item)?; - self.pipeline + self.connection .send( options.method_options.context, &mut req, @@ -490,11 +486,11 @@ impl ContainerClient { options.enable_content_response_on_write = true; let link = self.items_link.item(item_id); - let url = self.pipeline.url(&link); + let url = self.connection.url(&link); let mut req = Request::new(url, Method::Get); req.insert_headers(&options)?; req.insert_headers(&partition_key.into())?; - self.pipeline + self.connection .send(options.method_options.context, &mut req, link) .await } @@ -527,11 +523,11 @@ impl ContainerClient { ) -> azure_core::Result> { let options = options.unwrap_or_default(); let link = self.items_link.item(item_id); - let url = self.pipeline.url(&link); + let url = self.connection.url(&link); let mut req = Request::new(url, Method::Delete); req.insert_headers(&options)?; req.insert_headers(&partition_key.into())?; - self.pipeline + self.connection .send(options.method_options.context, &mut req, link) .await } @@ -600,14 +596,14 @@ impl ContainerClient { ) -> azure_core::Result> { let options = options.unwrap_or_default(); let link = self.items_link.item(item_id); - let url = self.pipeline.url(&link); + let url = self.connection.url(&link); let mut req = Request::new(url, Method::Patch); req.insert_headers(&options)?; req.insert_headers(&partition_key.into())?; req.insert_headers(&ContentType::APPLICATION_JSON)?; req.set_json(&patch)?; - self.pipeline + self.connection .send(options.method_options.context, &mut req, link) .await } @@ -686,7 +682,7 @@ impl ContainerClient { if partition_key.is_empty() { if let Some(query_engine) = options.query_engine.take() { return crate::query::executor::QueryExecutor::new( - self.pipeline.clone(), + self.connection.clone(), self.link.clone(), query, options, @@ -696,8 +692,8 @@ impl ContainerClient { } } - let url = self.pipeline.url(&self.items_link); - self.pipeline.send_query_request( + let url = self.connection.url(&self.items_link); + self.connection.send_query_request( options.method_options.context, query, url, @@ -705,4 +701,27 @@ impl ContainerClient { |r| r.insert_headers(&partition_key), ) } + + async fn metadata(&self) -> azure_core::Result> { + Ok(self + .connection + .cache + .get_container_metadata(&self.link, async { + let properties = self.read_properties(None).await?.into_body()?; + ContainerMetadata::from_properties(&properties, self.link.clone()) + }) + .await?) + } + + async fn read_properties( + &self, + options: Option>, + ) -> azure_core::Result> { + let options = options.unwrap_or_default(); + let url = self.connection.url(&self.link); + let mut req = Request::new(url, Method::Get); + self.connection + .send(options.method_options.context, &mut req, self.link.clone()) + .await + } } diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs index 209fadc284..b02600d61a 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs @@ -3,8 +3,8 @@ use crate::{ clients::DatabaseClient, + connection::{AuthorizationPolicy, CosmosConnection}, models::DatabaseProperties, - pipeline::{AuthorizationPolicy, CosmosPipeline}, resource_context::{ResourceLink, ResourceType}, CosmosClientOptions, CreateDatabaseOptions, FeedPager, Query, QueryDatabasesOptions, }; @@ -23,10 +23,13 @@ use std::sync::Arc; use azure_core::credentials::Secret; /// Client for Azure Cosmos DB. -#[derive(Debug, Clone)] +/// +/// A [`CosmosClient`] can be safely shared between threads and is cheap to clone, as it holds most of the connection state in an [`Arc`]. +/// However, it's generally preferred to have a single `CosmosClient` per Cosmos account in your application, and share that between threads as needed. +#[derive(Clone)] pub struct CosmosClient { databases_link: ResourceLink, - pipeline: CosmosPipeline, + connection: CosmosConnection, } impl CosmosClient { @@ -55,7 +58,7 @@ impl CosmosClient { let options = options.unwrap_or_default(); Ok(Self { databases_link: ResourceLink::root(ResourceType::Databases), - pipeline: CosmosPipeline::new( + connection: CosmosConnection::new( endpoint.parse()?, AuthorizationPolicy::from_token_credential(credential), options.client_options, @@ -88,7 +91,7 @@ impl CosmosClient { let options = options.unwrap_or_default(); Ok(Self { databases_link: ResourceLink::root(ResourceType::Databases), - pipeline: CosmosPipeline::new( + connection: CosmosConnection::new( endpoint.parse()?, AuthorizationPolicy::from_shared_key(key), options.client_options, @@ -131,12 +134,12 @@ impl CosmosClient { /// # Arguments /// * `id` - The ID of the database. pub fn database_client(&self, id: &str) -> DatabaseClient { - DatabaseClient::new(self.pipeline.clone(), id) + DatabaseClient::new(self.connection.clone(), id) } /// Gets the endpoint of the database account this client is connected to. pub fn endpoint(&self) -> &Url { - &self.pipeline.endpoint + &self.connection.endpoint } /// Executes a query against databases in the account. @@ -168,9 +171,9 @@ impl CosmosClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.databases_link); + let url = self.connection.url(&self.databases_link); - self.pipeline.send_query_request( + self.connection.send_query_request( options.method_options.context, query.into(), url, @@ -198,13 +201,13 @@ impl CosmosClient { id: &'a str, } - let url = self.pipeline.url(&self.databases_link); + let url = self.connection.url(&self.databases_link); let mut req = Request::new(url, Method::Post); req.insert_headers(&options.throughput)?; req.insert_headers(&ContentType::APPLICATION_JSON)?; req.set_json(&RequestBody { id })?; - self.pipeline + self.connection .send( options.method_options.context, &mut req, diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs index c9dc6889f0..38aa22e958 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs @@ -1,11 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +use std::sync::Arc; + use crate::{ clients::ContainerClient, + connection::CosmosConnection, models::{ContainerProperties, DatabaseProperties, ThroughputProperties}, options::ReadDatabaseOptions, - pipeline::CosmosPipeline, resource_context::{ResourceLink, ResourceType}, CreateContainerOptions, DeleteDatabaseOptions, FeedPager, Query, QueryContainersOptions, ThroughputOptions, @@ -20,15 +22,16 @@ use azure_core::http::{ /// A client for working with a specific database in a Cosmos DB account. /// /// You can get a `DatabaseClient` by calling [`CosmosClient::database_client()`](crate::CosmosClient::database_client()). +#[derive(Clone)] pub struct DatabaseClient { link: ResourceLink, containers_link: ResourceLink, database_id: String, - pipeline: CosmosPipeline, + connection: CosmosConnection, } impl DatabaseClient { - pub(crate) fn new(pipeline: CosmosPipeline, database_id: &str) -> Self { + pub(crate) fn new(connection: CosmosConnection, database_id: &str) -> Self { let database_id = database_id.to_string(); let link = ResourceLink::root(ResourceType::Databases).item(&database_id); let containers_link = link.feed(ResourceType::Containers); @@ -37,7 +40,7 @@ impl DatabaseClient { link, containers_link, database_id, - pipeline, + connection, } } @@ -46,7 +49,7 @@ impl DatabaseClient { /// # Arguments /// * `name` - The name of the container. pub fn container_client(&self, name: &str) -> ContainerClient { - ContainerClient::new(self.pipeline.clone(), &self.link, name) + ContainerClient::new(self.connection.clone(), &self.link, name) } /// Returns the identifier of the Cosmos database. @@ -76,9 +79,9 @@ impl DatabaseClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.link); + let url = self.connection.url(&self.link); let mut req = Request::new(url, Method::Get); - self.pipeline + self.connection .send(options.method_options.context, &mut req, self.link.clone()) .await } @@ -112,9 +115,9 @@ impl DatabaseClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.containers_link); + let url = self.connection.url(&self.containers_link); - self.pipeline.send_query_request( + self.connection.send_query_request( options.method_options.context, query.into(), url, @@ -136,13 +139,13 @@ impl DatabaseClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.containers_link); + let url = self.connection.url(&self.containers_link); let mut req = Request::new(url, Method::Post); req.insert_headers(&options.throughput)?; req.insert_headers(&ContentType::APPLICATION_JSON)?; req.set_json(&properties)?; - self.pipeline + self.connection .send( options.method_options.context, &mut req, @@ -162,9 +165,9 @@ impl DatabaseClient { options: Option>, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let url = self.pipeline.url(&self.link); + let url = self.connection.url(&self.link); let mut req = Request::new(url, Method::Delete); - self.pipeline + self.connection .send(options.method_options.context, &mut req, self.link.clone()) .await } @@ -188,7 +191,7 @@ impl DatabaseClient { .resource_id .expect("service should always return a '_rid' for a database"); - self.pipeline + self.connection .read_throughput_offer(options.method_options.context, &resource_id) .await } @@ -212,7 +215,7 @@ impl DatabaseClient { .resource_id .expect("service should always return a '_rid' for a database"); - self.pipeline + self.connection .replace_throughput_offer(options.method_options.context, &resource_id, throughput) .await } diff --git a/sdk/cosmos/azure_data_cosmos/src/connection/authorization_policy.rs b/sdk/cosmos/azure_data_cosmos/src/connection/authorization_policy.rs index 19cf51c2a0..7444323e52 100644 --- a/sdk/cosmos/azure_data_cosmos/src/connection/authorization_policy.rs +++ b/sdk/cosmos/azure_data_cosmos/src/connection/authorization_policy.rs @@ -21,7 +21,7 @@ use azure_core::{ use std::sync::Arc; use tracing::trace; -use crate::{pipeline::signature_target::SignatureTarget, resource_context::ResourceLink}; +use crate::{connection::signature_target::SignatureTarget, resource_context::ResourceLink}; use crate::utils::url_encode; @@ -153,7 +153,7 @@ mod tests { use url::Url; use crate::{ - pipeline::{ + connection::{ authorization_policy::{generate_authorization, scope_from_url, Credential}, signature_target::SignatureTarget, }, diff --git a/sdk/cosmos/azure_data_cosmos/src/connection/mod.rs b/sdk/cosmos/azure_data_cosmos/src/connection/mod.rs index fc3f2a17d6..7f194a812e 100644 --- a/sdk/cosmos/azure_data_cosmos/src/connection/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/src/connection/mod.rs @@ -18,16 +18,21 @@ use serde::de::DeserializeOwned; use url::Url; use crate::{ + cache::ContainerMetadataCache, constants, models::ThroughputProperties, resource_context::{ResourceLink, ResourceType}, - FeedPage, FeedPager, Query, + FeedPage, FeedPager, Query, ResourceId, }; -/// Newtype that wraps an Azure Core pipeline to provide a Cosmos-specific pipeline which configures our authorization policy and enforces that a [`ResourceType`] is set on the context. -#[derive(Debug, Clone)] +/// Represents a connection to a specific Cosmos account. +/// +/// The [`CosmosConnection`] holds all the shared state for a connection to a Cosmos DB account. +/// A connection is cheap to clone, and all clones share the same underlying HTTP pipeline and metadata cache. +#[derive(Clone)] pub struct CosmosConnection { pub endpoint: Url, + pub cache: ContainerMetadataCache, pipeline: azure_core::http::Pipeline, } @@ -39,6 +44,7 @@ impl CosmosConnection { ) -> Self { CosmosConnection { endpoint, + cache: ContainerMetadataCache::new(), pipeline: azure_core::http::Pipeline::new( option_env!("CARGO_PKG_NAME"), option_env!("CARGO_PKG_VERSION"), @@ -124,7 +130,7 @@ impl CosmosConnection { pub async fn read_throughput_offer( &self, context: Context<'_>, - resource_id: &str, + resource_id: &ResourceId, ) -> azure_core::Result>> { // We only have to into_owned here in order to call send_query_request below, // since it returns `Pager` which must own it's data. @@ -164,7 +170,7 @@ impl CosmosConnection { pub async fn replace_throughput_offer( &self, context: Context<'_>, - resource_id: &str, + resource_id: &ResourceId, throughput: ThroughputProperties, ) -> azure_core::Result> { let response = self diff --git a/sdk/cosmos/azure_data_cosmos/src/connection/signature_target.rs b/sdk/cosmos/azure_data_cosmos/src/connection/signature_target.rs index 6db3089903..a3ad1558da 100644 --- a/sdk/cosmos/azure_data_cosmos/src/connection/signature_target.rs +++ b/sdk/cosmos/azure_data_cosmos/src/connection/signature_target.rs @@ -70,7 +70,7 @@ mod tests { use azure_core::{http::Method, time}; use crate::{ - pipeline::signature_target::SignatureTarget, + connection::signature_target::SignatureTarget, resource_context::{ResourceLink, ResourceType}, }; diff --git a/sdk/cosmos/azure_data_cosmos/src/lib.rs b/sdk/cosmos/azure_data_cosmos/src/lib.rs index b0ceef2f0e..f9dca088f3 100644 --- a/sdk/cosmos/azure_data_cosmos/src/lib.rs +++ b/sdk/cosmos/azure_data_cosmos/src/lib.rs @@ -10,25 +10,22 @@ #![cfg_attr(docsrs, feature(doc_auto_cfg))] #![cfg_attr(docsrs, feature(doc_cfg_hide))] +mod cache; pub mod clients; +mod connection; mod connection_string; pub mod constants; mod feed; +mod location_cache; +pub mod models; mod options; mod partition_key; -pub(crate) mod pipeline; pub mod query; -pub(crate) mod resource_context; -pub(crate) mod routing; -pub(crate) mod utils; - -pub mod models; - -mod cache; -mod location_cache; +mod resource_context; mod types; +mod utils; -pub(crate) use types::*; +pub use types::ResourceId; #[doc(inline)] pub use clients::CosmosClient; diff --git a/sdk/cosmos/azure_data_cosmos/src/models/mod.rs b/sdk/cosmos/azure_data_cosmos/src/models/mod.rs index 92421a39d5..dd0c6222c1 100644 --- a/sdk/cosmos/azure_data_cosmos/src/models/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/src/models/mod.rs @@ -18,7 +18,7 @@ pub use partition_key_definition::*; pub use patch_operations::*; pub use throughput_properties::*; -use crate::types::{EffectivePartitionKey, PartitionKeyRangeId}; +use crate::ResourceId; fn deserialize_cosmos_timestamp<'de, D>(deserializer: D) -> Result, D::Error> where @@ -84,7 +84,7 @@ pub struct SystemProperties { // Some APIs do expect the "_rid" to be provided (Replace Offer, for example), so we do want to serialize it if it's provided. #[serde(skip_serializing_if = "Option::is_none")] #[serde(rename = "_rid")] - pub resource_id: Option, + pub resource_id: Option, /// A [`OffsetDateTime`] representing the last modified time of the resource. #[serde(default)] diff --git a/sdk/cosmos/azure_data_cosmos/src/query/executor.rs b/sdk/cosmos/azure_data_cosmos/src/query/executor.rs index f0e93b326d..aebe448737 100644 --- a/sdk/cosmos/azure_data_cosmos/src/query/executor.rs +++ b/sdk/cosmos/azure_data_cosmos/src/query/executor.rs @@ -1,16 +1,18 @@ +use std::sync::Arc; + use azure_core::http::{headers::Headers, Context, Method, RawResponse, Request}; use serde::de::DeserializeOwned; use crate::{ + connection::CosmosConnection, constants, - pipeline::{self, CosmosPipeline}, query::{OwnedQueryPipeline, QueryEngineRef, QueryResult}, resource_context::{ResourceLink, ResourceType}, FeedPage, FeedPager, Query, QueryOptions, }; pub struct QueryExecutor { - http_pipeline: CosmosPipeline, + connection: CosmosConnection, container_link: ResourceLink, items_link: ResourceLink, context: Context<'static>, @@ -29,7 +31,7 @@ pub struct QueryExecutor { impl QueryExecutor { pub fn new( - http_pipeline: CosmosPipeline, + connection: CosmosConnection, container_link: ResourceLink, query: Query, options: QueryOptions<'_>, @@ -38,7 +40,7 @@ impl QueryExecutor { let items_link = container_link.feed(ResourceType::Items); let context = options.method_options.context.into_owned(); Ok(Self { - http_pipeline, + connection, container_link, items_link, context, @@ -77,7 +79,7 @@ impl QueryExecutor { None => { // Initialize the pipeline. let query_plan = get_query_plan( - &self.http_pipeline, + &self.connection, &self.items_link, self.context.to_borrowed(), &self.query, @@ -86,7 +88,7 @@ impl QueryExecutor { .await? .into_body(); let pkranges = get_pkranges( - &self.http_pipeline, + &self.connection, &self.container_link, self.context.to_borrowed(), ) @@ -97,8 +99,8 @@ impl QueryExecutor { self.query_engine .create_pipeline(&self.query.text, &query_plan, &pkranges)?; self.query.text = pipeline.query().into(); - self.base_request = Some(crate::pipeline::create_base_query_request( - self.http_pipeline.url(&self.items_link), + self.base_request = Some(crate::connection::create_base_query_request( + self.connection.url(&self.items_link), &self.query, )?); self.pipeline = Some(pipeline); @@ -139,7 +141,7 @@ impl QueryExecutor { } let resp = self - .http_pipeline + .connection .send_raw( self.context.to_borrowed(), &mut query_request, @@ -172,14 +174,14 @@ impl QueryExecutor { // This isn't an inherent method on QueryExecutor because that would force the whole executor to be Sync, which would force the pipeline to be Sync. #[tracing::instrument(skip_all)] async fn get_query_plan( - http_pipeline: &CosmosPipeline, + http_pipeline: &CosmosConnection, items_link: &ResourceLink, context: Context<'_>, query: &Query, supported_features: &str, ) -> azure_core::Result { let url = http_pipeline.url(items_link); - let mut request = pipeline::create_base_query_request(url, query)?; + let mut request = crate::connection::create_base_query_request(url, query)?; request.insert_header(constants::QUERY_ENABLE_CROSS_PARTITION, "True"); request.insert_header(constants::IS_QUERY_PLAN_REQUEST, "True"); request.insert_header( @@ -195,7 +197,7 @@ async fn get_query_plan( // This isn't an inherent method on QueryExecutor because that would force the whole executor to be Sync, which would force the pipeline to be Sync. #[tracing::instrument(skip_all)] async fn get_pkranges( - http_pipeline: &CosmosPipeline, + http_pipeline: &CosmosConnection, container_link: &ResourceLink, context: Context<'_>, ) -> azure_core::Result { diff --git a/sdk/cosmos/azure_data_cosmos/src/resource_context.rs b/sdk/cosmos/azure_data_cosmos/src/resource_context.rs index 69c61fe403..9e84871a99 100644 --- a/sdk/cosmos/azure_data_cosmos/src/resource_context.rs +++ b/sdk/cosmos/azure_data_cosmos/src/resource_context.rs @@ -5,7 +5,7 @@ use url::Url; use crate::utils::url_encode; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[allow(dead_code)] // For the variants. Can be removed when we have them all implemented. pub enum ResourceType { Databases, @@ -41,7 +41,7 @@ impl ResourceType { /// /// This value is URL encoded, and can be [`Url::join`]ed to the endpoint root to produce the full absolute URL for a Cosmos DB resource. /// It's also intended for use by the signature algorithm used when authenticating with a primary key. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ResourceLink { parent: Option, item_id: Option, diff --git a/sdk/cosmos/azure_data_cosmos/src/routing.rs b/sdk/cosmos/azure_data_cosmos/src/routing.rs deleted file mode 100644 index 2b2a4e9623..0000000000 --- a/sdk/cosmos/azure_data_cosmos/src/routing.rs +++ /dev/null @@ -1,339 +0,0 @@ -use std::ops::Range; - -use crate::types::{EffectivePartitionKey, PartitionKeyRangeId}; -use serde::{Deserialize, Serialize}; - -/// Represents a partition key range within a Cosmos DB collection. -#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)] -pub(crate) struct PartitionKeyRange { - // The ID of the partition key range. - pub id: PartitionKeyRangeId, - - /// The minimum inclusive value of the partition key range. - pub min_inclusive: EffectivePartitionKey, - - /// The maximum exclusive value of the partition key range. - pub max_exclusive: EffectivePartitionKey, - - /// The parents of the partition key range, if they still exist. - /// - /// During a split or merge, this will contain the ID of the partition key range(s) that were split or merged to create this range. - #[serde(default)] - pub parents: Vec, -} - -impl PartitionKeyRange { - /// Checks if the given effective partition key falls within this partition key range. - pub fn contains(&self, epk: &EffectivePartitionKey) -> bool { - self.min_inclusive <= *epk && *epk < self.max_exclusive - } - - /// Compares the given effective partition key to this partition key range, returning an [`Ordering`](std::cmp::Ordering) that can be used for searching a sorted list of ranges for the range containing the key. - pub fn compare_to(&self, epk: &EffectivePartitionKey) -> std::cmp::Ordering { - if self.contains(epk) { - std::cmp::Ordering::Equal - } else if self.min_inclusive > *epk { - std::cmp::Ordering::Greater - } else { - std::cmp::Ordering::Less - } - } - - pub fn overlaps(&self, range: &Range) -> bool { - // Empty ranges don't overlap with anything - if range.start >= range.end { - return false; - } - !(self.max_exclusive <= range.start || self.min_inclusive >= range.end) - } -} - -#[derive(Debug)] -pub struct ContainerRoutingMap { - /// The list of partition key ranges for the container, sorted by their minimum inclusive value. - pk_ranges: Vec, -} - -impl ContainerRoutingMap { - /// Creates a new `ContainerRoutingMap` from the provided partition key ranges. - pub fn new(mut pk_ranges: Vec) -> Self { - normalize_ranges(&mut pk_ranges); - Self { pk_ranges } - } - - /// Replaces the partition key ranges with a new set, normalizing them in the process. - pub fn with_ranges(mut self, mut pk_ranges: Vec) -> Self { - normalize_ranges(&mut pk_ranges); - self.pk_ranges = pk_ranges; - self - } - - /// Gets the list of partition key ranges for the container, sorted by their minimum inclusive value. - pub fn ranges(&self) -> &[PartitionKeyRange] { - &self.pk_ranges - } - - /// Finds the partition key range that contains the given effective partition key, if any. - pub fn range_containing(&self, epk: &EffectivePartitionKey) -> Option<&PartitionKeyRange> { - // It's critical that pk_ranges is sorted by min_inclusive for this to work correctly, so assert that in debug builds - debug_assert!(self.pk_ranges.is_sorted_by_key(|pkr| &pkr.min_inclusive)); - self.pk_ranges - .binary_search_by(|pkr| pkr.compare_to(epk)) - .ok() - .map(|idx| &self.pk_ranges[idx]) - } - - /// Finds the partition key range with the given ID, if any. - pub fn range(&self, id: &PartitionKeyRangeId) -> Option<&PartitionKeyRange> { - // We don't know that IDs are sorted, so just do a linear search - self.pk_ranges.iter().find(|pkr| &pkr.id == id) - } - - pub fn overlapping_ranges( - &self, - ranges: &[Range], - ) -> Vec<&PartitionKeyRange> { - // TODO: Could be optimized further later - self.pk_ranges - .iter() - .filter(|pkr| ranges.iter().any(|r| pkr.overlaps(r))) - .collect() - } -} - -/// Discards any [`PartitionKeyRange`] that is the parent of another range, leaving only the leaf ranges. -/// Also sorts the ranges by their minimum inclusive value. -fn normalize_ranges(pk_ranges: &mut Vec) { - let parent_ids: std::collections::HashSet<_> = pk_ranges - .iter() - .flat_map(|pkr| pkr.parents.iter().cloned()) - .collect(); - pk_ranges.retain(|pkr| !parent_ids.contains(&pkr.id)); - pk_ranges.sort_by(|a, b| a.min_inclusive.cmp(&b.min_inclusive)); -} - -#[cfg(test)] -mod tests { - use super::*; - - fn create_range(id: &str, min: &str, max: &str, parents: Vec<&str>) -> PartitionKeyRange { - PartitionKeyRange { - id: id.into(), - min_inclusive: min.into(), - max_exclusive: max.into(), - parents: parents.into_iter().map(|p| p.into()).collect(), - } - } - - #[test] - fn partition_key_range_contains() { - let range = create_range("1", "00", "33", vec![]); - - assert!(range.contains(&"00".into())); - assert!(range.contains(&"15".into())); - assert!(range.contains(&"32".into())); - assert!(!range.contains(&"33".into())); - assert!(!range.contains(&"34".into())); - assert!(!range.contains(&"FF".into())); - } - - #[test] - fn partition_key_range_overlaps() { - let range = create_range("1", "20", "60", vec![]); - - assert!(range.overlaps(&("30".into().."40".into()))); - assert!(range.overlaps(&("10".into().."30".into()))); - assert!(range.overlaps(&("50".into().."70".into()))); - assert!(range.overlaps(&("10".into().."70".into()))); - assert!(range.overlaps(&("20".into().."60".into()))); - assert!(!range.overlaps(&("00".into().."20".into()))); - assert!(!range.overlaps(&("60".into().."80".into()))); - assert!(!range.overlaps(&("30".into().."30".into()))); - } - - #[test] - fn container_routing_map_construction() { - let ranges = vec![ - create_range("2", "33", "66", vec![]), - create_range("1", "00", "33", vec![]), - create_range("3", "66", "FF", vec![]), - ]; - - let map = ContainerRoutingMap::new(ranges); - assert_eq!(map.ranges().len(), 3); - assert_eq!(map.ranges()[0].id.value(), "1"); - assert_eq!(map.ranges()[1].id.value(), "2"); - assert_eq!(map.ranges()[2].id.value(), "3"); - - let new_ranges = vec![ - create_range("2", "33", "66", vec![]), - create_range("1", "00", "33", vec![]), - ]; - let map = map.with_ranges(new_ranges); - assert_eq!(map.ranges().len(), 2); - assert_eq!(map.ranges()[0].id.value(), "1"); - assert_eq!(map.ranges()[1].id.value(), "2"); - } - - #[test] - fn container_routing_map_range_containing() { - let ranges = vec![ - create_range("1", "00", "33", vec![]), - create_range("2", "33", "66", vec![]), - create_range("3", "66", "FF", vec![]), - ]; - - let map = ContainerRoutingMap::new(ranges); - - assert_eq!(map.range_containing(&"00".into()).unwrap().id.value(), "1"); - assert_eq!(map.range_containing(&"32".into()).unwrap().id.value(), "1"); - assert_eq!(map.range_containing(&"33".into()).unwrap().id.value(), "2"); - assert_eq!(map.range_containing(&"50".into()).unwrap().id.value(), "2"); - assert_eq!(map.range_containing(&"66".into()).unwrap().id.value(), "3"); - assert_eq!(map.range_containing(&"AA".into()).unwrap().id.value(), "3"); - - let single_range_map = - ContainerRoutingMap::new(vec![create_range("1", "33", "66", vec![])]); - assert!(single_range_map.range_containing(&"00".into()).is_none()); - assert!(single_range_map.range_containing(&"99".into()).is_none()); - } - - #[test] - fn container_routing_map_range_by_id() { - let ranges = vec![ - create_range("1", "00", "33", vec![]), - create_range("2", "33", "66", vec![]), - create_range("3", "66", "FF", vec![]), - ]; - - let map = ContainerRoutingMap::new(ranges); - - assert_eq!(map.range(&"1".into()).unwrap().min_inclusive.value(), "00"); - assert_eq!(map.range(&"2".into()).unwrap().min_inclusive.value(), "33"); - assert_eq!(map.range(&"3".into()).unwrap().min_inclusive.value(), "66"); - assert!(map.range(&"99".into()).is_none()); - } - - #[test] - fn container_routing_map_overlapping_ranges() { - let ranges = vec![ - create_range("1", "00", "33", vec![]), - create_range("2", "33", "66", vec![]), - create_range("3", "66", "FF", vec![]), - ]; - - let map = ContainerRoutingMap::new(ranges); - - let overlapping = map.overlapping_ranges(&["00".into().."20".into()]); - assert_eq!(overlapping.len(), 1); - assert_eq!(overlapping[0].id.value(), "1"); - - let overlapping = map.overlapping_ranges(&["20".into().."50".into()]); - assert_eq!(overlapping.len(), 2); - assert_eq!(overlapping[0].id.value(), "1"); - assert_eq!(overlapping[1].id.value(), "2"); - - let overlapping = map.overlapping_ranges(&["00".into().."FF".into()]); - assert_eq!(overlapping.len(), 3); - - let overlapping = map.overlapping_ranges(&[]); - assert_eq!(overlapping.len(), 0); - - let overlapping = - map.overlapping_ranges(&["10".into().."20".into(), "70".into().."80".into()]); - assert_eq!(overlapping.len(), 2); - assert_eq!(overlapping[0].id.value(), "1"); - assert_eq!(overlapping[1].id.value(), "3"); - } - - #[test] - fn normalize_ranges_removes_parents() { - let mut ranges = vec![ - create_range("parent1", "00", "50", vec![]), - create_range("parent2", "50", "FF", vec![]), - create_range("child1", "00", "25", vec!["parent1"]), - create_range("child2", "25", "50", vec!["parent1"]), - create_range("child3", "50", "FF", vec!["parent2"]), - ]; - - normalize_ranges(&mut ranges); - - assert_eq!(ranges.len(), 3); - assert_eq!(ranges[0].id.value(), "child1"); - assert_eq!(ranges[1].id.value(), "child2"); - assert_eq!(ranges[2].id.value(), "child3"); - assert!(ranges[0].min_inclusive <= ranges[1].min_inclusive); - assert!(ranges[1].min_inclusive <= ranges[2].min_inclusive); - } - - #[test] - fn normalize_ranges_sorts_by_min_inclusive() { - let mut ranges = vec![ - create_range("3", "66", "99", vec![]), - create_range("1", "00", "33", vec![]), - create_range("2", "33", "66", vec![]), - ]; - - normalize_ranges(&mut ranges); - assert_eq!(ranges[0].id.value(), "1"); - assert_eq!(ranges[1].id.value(), "2"); - assert_eq!(ranges[2].id.value(), "3"); - } - - #[test] - fn normalize_ranges_handles_empty_list() { - let mut ranges: Vec = vec![]; - normalize_ranges(&mut ranges); - assert_eq!(ranges.len(), 0); - } - - #[test] - fn normalize_ranges_handles_no_parents() { - let mut ranges = vec![ - create_range("2", "33", "66", vec![]), - create_range("1", "00", "33", vec![]), - ]; - - normalize_ranges(&mut ranges); - assert_eq!(ranges.len(), 2); - assert_eq!(ranges[0].id.value(), "1"); - assert_eq!(ranges[1].id.value(), "2"); - } - - #[test] - fn container_routing_map_full_test() { - let ranges = vec![ - create_range("parent1", "00", "80", vec![]), - create_range("parent2", "80", "FF", vec![]), - create_range("child1", "00", "40", vec!["parent1"]), - create_range("child2", "40", "80", vec!["parent1"]), - create_range("child3", "80", "C0", vec!["parent2"]), - create_range("child4", "C0", "FF", vec!["parent2"]), - ]; - - let map = ContainerRoutingMap::new(ranges); - - assert_eq!(map.ranges().len(), 4); - assert_eq!( - map.range_containing(&"10".into()).unwrap().id.value(), - "child1" - ); - assert_eq!( - map.range_containing(&"50".into()).unwrap().id.value(), - "child2" - ); - assert_eq!( - map.range_containing(&"90".into()).unwrap().id.value(), - "child3" - ); - assert_eq!( - map.range_containing(&"D0".into()).unwrap().id.value(), - "child4" - ); - - let overlapping = map.overlapping_ranges(&["30".into().."70".into()]); - assert_eq!(overlapping.len(), 2); - assert_eq!(overlapping[0].id.value(), "child1"); - assert_eq!(overlapping[1].id.value(), "child2"); - } -} diff --git a/sdk/cosmos/azure_data_cosmos/src/types.rs b/sdk/cosmos/azure_data_cosmos/src/types.rs index 08db8ba505..4e3d8d9539 100644 --- a/sdk/cosmos/azure_data_cosmos/src/types.rs +++ b/sdk/cosmos/azure_data_cosmos/src/types.rs @@ -37,5 +37,3 @@ macro_rules! string_newtype { } string_newtype!(ResourceId); -string_newtype!(PartitionKeyRangeId); -string_newtype!(EffectivePartitionKey); diff --git a/sdk/cosmos/azure_data_cosmos/tests/cosmos_containers.rs b/sdk/cosmos/azure_data_cosmos/tests/cosmos_containers.rs index c655b7da4c..d4fdf8e655 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/cosmos_containers.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/cosmos_containers.rs @@ -2,8 +2,9 @@ mod framework; -use std::error::Error; +use std::{error::Error, sync::Arc}; +use azure_core::http::Method; use azure_core_test::{recorded, TestContext}; use azure_data_cosmos::{ models::{ @@ -14,7 +15,7 @@ use azure_data_cosmos::{ }; use futures::TryStreamExt; -use framework::{test_data, TestAccount}; +use framework::{test_data, LocalRecorder, TestAccount, TestAccountOptions}; #[recorded::test] pub async fn container_crud(context: TestContext) -> Result<(), Box> { @@ -237,3 +238,71 @@ pub async fn container_crud_hierarchical_pk(context: TestContext) -> Result<(), Ok(()) } + +#[recorded::test] +pub async fn container_read_throughput_twice(context: TestContext) -> Result<(), Box> { + use azure_core::http::StatusCode; + + let recorder = Arc::new(LocalRecorder::new()); + let account = TestAccount::from_env( + context, + Some(TestAccountOptions { + recorder: Some(recorder.clone()), + ..Default::default() + }), + ) + .await?; + + let cosmos_client = account.connect_with_key(None)?; + let db_client = test_data::create_database(&account, &cosmos_client).await?; + + // Create the container with manual throughput + let properties = ContainerProperties { + id: "ThroughputTestContainer".into(), + partition_key: "/id".into(), + ..Default::default() + }; + let throughput = ThroughputProperties::manual(600); + + db_client + .create_container( + properties.clone(), + Some(CreateContainerOptions { + throughput: Some(throughput), + ..Default::default() + }), + ) + .await? + .into_body()?; + let container_client = db_client.container_client(&properties.id); + + let first_throughput = container_client + .read_throughput(None) + .await? + .expect("throughput should be present") + .into_body()?; + assert_eq!(Some(600), first_throughput.throughput()); + + let second_throughput = container_client + .read_throughput(None) + .await? + .expect("throughput should be present") + .into_body()?; + assert_eq!(Some(600), second_throughput.throughput()); + + // Check the recorder to ensure only one request was made to read the container metadata + let txs = recorder.to_transactions().await; + assert_eq!( + 1, + txs.iter() + .filter(|t| t.request.method() == Method::Get + && t.request + .url() + .path() + .ends_with("/colls/ThroughputTestContainer")) + .count() + ); + + account.cleanup().await?; + Ok(()) +} diff --git a/sdk/cosmos/azure_data_cosmos/tests/framework/local_recorder.rs b/sdk/cosmos/azure_data_cosmos/tests/framework/local_recorder.rs new file mode 100644 index 0000000000..595e70970d --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos/tests/framework/local_recorder.rs @@ -0,0 +1,70 @@ +use std::sync::Arc; + +use azure_core::http::{ + policies::{Policy, PolicyResult}, + BufResponse, Context, Method, RawResponse, Request, StatusCode, +}; + +#[derive(Debug, Clone)] +pub struct Transaction { + pub request: Request, + pub response: Option, +} + +/// A policy that can be used to capture a simple local recording of requests for validation purposes +pub struct LocalRecorder { + transactions: tokio::sync::RwLock>, +} + +impl LocalRecorder { + pub fn new() -> Self { + Self { + transactions: tokio::sync::RwLock::new(Vec::new()), + } + } + + pub async fn to_transaction_summary(&self) -> Vec<(Method, String, Option)> { + self.transactions + .read() + .await + .iter() + .map(|t| { + ( + t.request.method(), + t.request.url().to_string(), + t.response.as_ref().map(|r| r.status()), + ) + }) + .collect() + } + + pub async fn to_transactions(&self) -> Vec { + self.transactions.write().await.clone() + } +} + +impl std::fmt::Debug for LocalRecorder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LocalRecorder").finish() + } +} + +#[async_trait::async_trait] +impl Policy for LocalRecorder { + async fn send( + &self, + ctx: &Context, + request: &mut Request, + next: &[Arc], + ) -> PolicyResult { + let response = next[0].send(ctx, request, &next[1..]).await?; + let (status, headers, body) = response.deconstruct(); + let body = body.collect().await?; + let raw_response = RawResponse::from_bytes(status, headers.clone(), body.clone()); + self.transactions.write().await.push(Transaction { + request: request.clone(), + response: Some(raw_response.clone()), + }); + Ok(BufResponse::from_bytes(status, headers, body)) + } +} diff --git a/sdk/cosmos/azure_data_cosmos/tests/framework/mod.rs b/sdk/cosmos/azure_data_cosmos/tests/framework/mod.rs index 8b204aafb3..9402e24a70 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/framework/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/framework/mod.rs @@ -8,13 +8,15 @@ //! //! The framework allows tests to easily run against real Cosmos DB instances, the local emulator, or a mock server using test-proxy. +mod local_recorder; mod test_account; pub mod test_data; #[cfg(feature = "preview_query_engine")] pub mod query_engine; -pub use test_account::TestAccount; +pub use local_recorder::{LocalRecorder, Transaction}; +pub use test_account::{TestAccount, TestAccountOptions}; use serde::{Deserialize, Serialize}; diff --git a/sdk/cosmos/azure_data_cosmos/tests/framework/test_account.rs b/sdk/cosmos/azure_data_cosmos/tests/framework/test_account.rs index 6cd8eeb215..55d52386e5 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/framework/test_account.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/framework/test_account.rs @@ -9,6 +9,8 @@ use azure_core_test::TestContext; use azure_data_cosmos::{ConnectionString, CosmosClientOptions, Query}; use reqwest::ClientBuilder; +use crate::framework::LocalRecorder; + /// Represents a Cosmos DB account for testing purposes. /// /// A [`TestAccount`] serves two main purposes: @@ -25,6 +27,7 @@ pub struct TestAccount { #[derive(Default)] pub struct TestAccountOptions { pub allow_invalid_certificates: Option, + pub recorder: Option>, } const CONNECTION_STRING_ENV_VAR: &str = "AZURE_COSMOS_CONNECTION_STRING"; @@ -113,6 +116,13 @@ impl TestAccount { .recording() .instrument(&mut options.client_options); + if let Some(recorder) = &self.options.recorder { + options + .client_options + .per_try_policies + .push(recorder.clone()); + } + Ok(azure_data_cosmos::CosmosClient::with_key( &self.endpoint, self.key.clone(), From 3e12377a284e1e30cc70d6e933e937211c67794e Mon Sep 17 00:00:00 2001 From: Ashley Stanton-Nurse Date: Fri, 3 Oct 2025 21:34:12 +0000 Subject: [PATCH 5/8] update test recordings --- sdk/cosmos/assets.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/assets.json b/sdk/cosmos/assets.json index 0e743ffd6e..c6fbb68160 100644 --- a/sdk/cosmos/assets.json +++ b/sdk/cosmos/assets.json @@ -1,6 +1,6 @@ { "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "rust", - "Tag": "rust/azure_data_cosmos_a39b424a5b", + "Tag": "rust/azure_data_cosmos_69ad1e4995", "TagPrefix": "rust/azure_data_cosmos" } \ No newline at end of file From c8f88c9c32b9cf189d802aba77d50b073d8d9da0 Mon Sep 17 00:00:00 2001 From: Ashley Stanton-Nurse Date: Fri, 3 Oct 2025 21:43:16 +0000 Subject: [PATCH 6/8] final review tidy up --- Cargo.lock | 1 - sdk/cosmos/azure_data_cosmos/Cargo.toml | 1 - sdk/cosmos/azure_data_cosmos/src/cache.rs | 10 +++------- .../src/clients/container_client.rs | 9 +++++---- .../src/clients/cosmos_client.rs | 2 +- .../azure_data_cosmos/src/connection/mod.rs | 12 ++++++++++-- sdk/cosmos/azure_data_cosmos/src/types.rs | 12 ++++++++++-- .../tests/cosmos_containers.rs | 3 --- .../tests/framework/local_recorder.rs | 18 ++---------------- 9 files changed, 31 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a4b5736e15..e711648a09 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -322,7 +322,6 @@ dependencies = [ "clap", "futures", "moka", - "pin-project", "reqwest", "serde", "serde_json", diff --git a/sdk/cosmos/azure_data_cosmos/Cargo.toml b/sdk/cosmos/azure_data_cosmos/Cargo.toml index 12d137e488..bb3266792a 100644 --- a/sdk/cosmos/azure_data_cosmos/Cargo.toml +++ b/sdk/cosmos/azure_data_cosmos/Cargo.toml @@ -17,7 +17,6 @@ categories = ["api-bindings"] async-trait.workspace = true azure_core.workspace = true futures.workspace = true -pin-project.workspace = true serde_json.workspace = true serde.workspace = true tracing.workspace = true diff --git a/sdk/cosmos/azure_data_cosmos/src/cache.rs b/sdk/cosmos/azure_data_cosmos/src/cache.rs index afd72e86ba..18e1a21d23 100644 --- a/sdk/cosmos/azure_data_cosmos/src/cache.rs +++ b/sdk/cosmos/azure_data_cosmos/src/cache.rs @@ -95,7 +95,7 @@ impl ContainerMetadata { /// The cache can be cloned cheaply, and all clones share the same underlying cache data. #[derive(Clone)] pub struct ContainerMetadataCache { - /// Caches stable container metadata, mapping from container link and RID to metadata. + /// Caches stable container metadata, mapping from container link to metadata. container_properties_cache: Cache>, } @@ -133,15 +133,11 @@ impl ContainerMetadataCache { ) -> Result, CacheError> { // TODO: Background refresh. We can do background refresh by storing an expiry time in the cache entry. // Then, if the entry is stale, we can return the stale entry and spawn a background task to refresh it. - // There's a little trickiness here in that + // There's a little trickiness here in that we can't directly spawn a task because that depends on a specific Async Runtime (tokio, smol, etc). + // The core SDK has an AsyncRuntime abstraction that we can use to spawn the task. Ok(self .container_properties_cache .try_get_with_by_ref(key, async { init.await.map(Arc::new) }) .await?) } - - /// Clears the cached container metadata for the specified key, so that the next request will fetch fresh data. - pub async fn clear_container_metadata(&self, key: &ResourceLink) { - self.container_properties_cache.invalidate(key).await; - } } diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs index e6f249bcc2..6e3eabf2ef 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs @@ -26,7 +26,6 @@ use serde::{de::DeserializeOwned, Serialize}; /// You can get a `Container` by calling [`DatabaseClient::container_client()`](crate::clients::DatabaseClient::container_client()). #[derive(Clone)] pub struct ContainerClient { - id: String, link: ResourceLink, items_link: ResourceLink, connection: CosmosConnection, @@ -44,7 +43,6 @@ impl ContainerClient { let items_link = link.feed(ResourceType::Items); Self { - id: container_id.to_string(), link, items_link, connection, @@ -77,7 +75,10 @@ impl ContainerClient { // TODO: Replace with `response.body().json()` when that becomes borrowing. let properties = serde_json::from_slice::(response.body())?; let metadata = ContainerMetadata::from_properties(&properties, self.link.clone())?; - self.connection.cache.set_container_metadata(metadata).await; + self.connection + .cache() + .set_container_metadata(metadata) + .await; Ok(response.into()) } @@ -705,7 +706,7 @@ impl ContainerClient { async fn metadata(&self) -> azure_core::Result> { Ok(self .connection - .cache + .cache() .get_container_metadata(&self.link, async { let properties = self.read_properties(None).await?.into_body()?; ContainerMetadata::from_properties(&properties, self.link.clone()) diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs index b02600d61a..418b0e82b2 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs @@ -139,7 +139,7 @@ impl CosmosClient { /// Gets the endpoint of the database account this client is connected to. pub fn endpoint(&self) -> &Url { - &self.connection.endpoint + self.connection.endpoint() } /// Executes a query against databases in the account. diff --git a/sdk/cosmos/azure_data_cosmos/src/connection/mod.rs b/sdk/cosmos/azure_data_cosmos/src/connection/mod.rs index 7f194a812e..1debfe2faa 100644 --- a/sdk/cosmos/azure_data_cosmos/src/connection/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/src/connection/mod.rs @@ -31,8 +31,8 @@ use crate::{ /// A connection is cheap to clone, and all clones share the same underlying HTTP pipeline and metadata cache. #[derive(Clone)] pub struct CosmosConnection { - pub endpoint: Url, - pub cache: ContainerMetadataCache, + endpoint: Url, + cache: ContainerMetadataCache, pipeline: azure_core::http::Pipeline, } @@ -56,6 +56,14 @@ impl CosmosConnection { } } + pub fn endpoint(&self) -> &Url { + &self.endpoint + } + + pub fn cache(&self) -> &ContainerMetadataCache { + &self.cache + } + /// Creates a [`Url`] out of the provided [`ResourceLink`] /// /// This is a little backwards, ideally we'd accept [`ResourceLink`] in the [`CosmosPipeline::send`] method, diff --git a/sdk/cosmos/azure_data_cosmos/src/types.rs b/sdk/cosmos/azure_data_cosmos/src/types.rs index 4e3d8d9539..4ea1e871f4 100644 --- a/sdk/cosmos/azure_data_cosmos/src/types.rs +++ b/sdk/cosmos/azure_data_cosmos/src/types.rs @@ -1,16 +1,19 @@ //! Internal module to define several newtypes used in the SDK. macro_rules! string_newtype { - ($name:ident) => { + ($(#[$attr:meta])* $name:ident) => { + $(#[$attr])* #[derive(serde::Deserialize, serde::Serialize, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] #[serde(transparent)] pub struct $name(String); impl $name { + #[doc = concat!("Creates a new `", stringify!($name), "` from a `String`.")] pub fn new(value: String) -> Self { Self(value) } + #[doc = concat!("Returns a reference to the inner `str` of the `", stringify!($name), "`.")] pub fn value(&self) -> &str { &self.0 } @@ -36,4 +39,9 @@ macro_rules! string_newtype { }; } -string_newtype!(ResourceId); +string_newtype!( + /// Represents a Resource ID, which is a unique identifier for a resource within a Cosmos DB account. + /// + /// In most cases, you don't need to use this type directly, as the SDK will handle resource IDs for you. + ResourceId +); diff --git a/sdk/cosmos/azure_data_cosmos/tests/cosmos_containers.rs b/sdk/cosmos/azure_data_cosmos/tests/cosmos_containers.rs index d4fdf8e655..91dc14e0e0 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/cosmos_containers.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/cosmos_containers.rs @@ -241,8 +241,6 @@ pub async fn container_crud_hierarchical_pk(context: TestContext) -> Result<(), #[recorded::test] pub async fn container_read_throughput_twice(context: TestContext) -> Result<(), Box> { - use azure_core::http::StatusCode; - let recorder = Arc::new(LocalRecorder::new()); let account = TestAccount::from_env( context, @@ -256,7 +254,6 @@ pub async fn container_read_throughput_twice(context: TestContext) -> Result<(), let cosmos_client = account.connect_with_key(None)?; let db_client = test_data::create_database(&account, &cosmos_client).await?; - // Create the container with manual throughput let properties = ContainerProperties { id: "ThroughputTestContainer".into(), partition_key: "/id".into(), diff --git a/sdk/cosmos/azure_data_cosmos/tests/framework/local_recorder.rs b/sdk/cosmos/azure_data_cosmos/tests/framework/local_recorder.rs index 595e70970d..63ac8559b1 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/framework/local_recorder.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/framework/local_recorder.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use azure_core::http::{ policies::{Policy, PolicyResult}, - BufResponse, Context, Method, RawResponse, Request, StatusCode, + BufResponse, Context, RawResponse, Request, }; #[derive(Debug, Clone)] @@ -23,21 +23,7 @@ impl LocalRecorder { } } - pub async fn to_transaction_summary(&self) -> Vec<(Method, String, Option)> { - self.transactions - .read() - .await - .iter() - .map(|t| { - ( - t.request.method(), - t.request.url().to_string(), - t.response.as_ref().map(|r| r.status()), - ) - }) - .collect() - } - + /// Returns a copy of all recorded transactions pub async fn to_transactions(&self) -> Vec { self.transactions.write().await.clone() } From d1a0cfdc9c7451b0f792c6739fef99a98b773b68 Mon Sep 17 00:00:00 2001 From: Ashley Stanton-Nurse Date: Fri, 3 Oct 2025 21:44:54 +0000 Subject: [PATCH 7/8] removed unused metadata from cache for now --- sdk/cosmos/azure_data_cosmos/src/cache.rs | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/sdk/cosmos/azure_data_cosmos/src/cache.rs b/sdk/cosmos/azure_data_cosmos/src/cache.rs index 18e1a21d23..bb9bac768b 100644 --- a/sdk/cosmos/azure_data_cosmos/src/cache.rs +++ b/sdk/cosmos/azure_data_cosmos/src/cache.rs @@ -48,9 +48,7 @@ impl std::error::Error for CacheError { /// A subset of container properties that are stable and suitable for caching. pub(crate) struct ContainerMetadata { - pub self_link: String, pub resource_id: ResourceId, - pub partition_key: PartitionKeyDefinition, pub container_link: ResourceLink, } @@ -60,17 +58,6 @@ impl ContainerMetadata { properties: &ContainerProperties, container_link: ResourceLink, ) -> azure_core::Result { - let self_link = properties - .system_properties - .self_link - .as_ref() - .ok_or_else(|| { - azure_core::Error::new( - azure_core::error::ErrorKind::Other, - "container properties is missing expected value 'self_link'", - ) - })? - .clone(); let resource_id = properties .system_properties .resource_id @@ -82,9 +69,7 @@ impl ContainerMetadata { ) })?; Ok(Self { - self_link, resource_id, - partition_key: properties.partition_key.clone(), container_link, }) } From 1eafb4361914e5d110dcc18eb1ecd5b4723e9899 Mon Sep 17 00:00:00 2001 From: Ashley Stanton-Nurse Date: Fri, 3 Oct 2025 22:58:43 +0000 Subject: [PATCH 8/8] clippy lints --- sdk/cosmos/azure_data_cosmos/src/cache.rs | 6 +----- sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs | 2 -- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/sdk/cosmos/azure_data_cosmos/src/cache.rs b/sdk/cosmos/azure_data_cosmos/src/cache.rs index bb9bac768b..d1c0106021 100644 --- a/sdk/cosmos/azure_data_cosmos/src/cache.rs +++ b/sdk/cosmos/azure_data_cosmos/src/cache.rs @@ -2,11 +2,7 @@ use std::sync::Arc; use moka::future::Cache; -use crate::{ - models::{ContainerProperties, PartitionKeyDefinition}, - resource_context::ResourceLink, - ResourceId, -}; +use crate::{models::ContainerProperties, resource_context::ResourceLink, ResourceId}; #[derive(Debug)] pub enum CacheError { diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs index 38aa22e958..8a7fd04bad 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -use std::sync::Arc; - use crate::{ clients::ContainerClient, connection::CosmosConnection,