diff --git a/src/metadata.rs b/src/metadata.rs index 79cec19e5..a5014d4dc 100644 --- a/src/metadata.rs +++ b/src/metadata.rs @@ -74,7 +74,7 @@ pub enum SchemaVersion { V1, } -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct LogStreamMetadata { pub schema_version: SchemaVersion, pub schema: HashMap>, diff --git a/src/parseable/mod.rs b/src/parseable/mod.rs index f8ffa1ca9..f4b220db3 100644 --- a/src/parseable/mod.rs +++ b/src/parseable/mod.rs @@ -171,7 +171,7 @@ impl Parseable { } // Gets write privileges only for creating the stream when it doesn't already exist. - self.streams.create( + self.streams.get_or_create( self.options.clone(), stream_name.to_owned(), LogStreamMetadata::default(), @@ -342,7 +342,7 @@ impl Parseable { schema_version, log_source, ); - self.streams.create( + self.streams.get_or_create( self.options.clone(), stream_name.to_string(), metadata, @@ -652,7 +652,7 @@ impl Parseable { SchemaVersion::V1, // New stream log_source, ); - self.streams.create( + self.streams.get_or_create( self.options.clone(), stream_name.to_string(), metadata, diff --git a/src/parseable/streams.rs b/src/parseable/streams.rs index 784658c1f..b8f441912 100644 --- a/src/parseable/streams.rs +++ b/src/parseable/streams.rs @@ -737,17 +737,22 @@ pub struct Streams(RwLock>); // 4. When first event is sent to stream (update the schema) // 5. When set alert API is called (update the alert) impl Streams { - pub fn create( + /// Checks after getting an exclusive lock whether the stream already exists, else creates it. + /// NOTE: This is done to ensure we don't have contention among threads. + pub fn get_or_create( &self, options: Arc, stream_name: String, metadata: LogStreamMetadata, ingestor_id: Option, ) -> StreamRef { + let mut guard = self.write().expect(LOCK_EXPECT); + if let Some(stream) = guard.get(&stream_name) { + return stream.clone(); + } + let stream = Stream::new(options, &stream_name, metadata, ingestor_id); - self.write() - .expect(LOCK_EXPECT) - .insert(stream_name, stream.clone()); + guard.insert(stream_name, stream.clone()); stream } @@ -812,7 +817,7 @@ impl Streams { #[cfg(test)] mod tests { - use std::time::Duration; + use std::{sync::Barrier, thread::spawn, time::Duration}; use arrow_array::{Int32Array, StringArray, TimestampMillisecondArray}; use arrow_schema::{DataType, Field, TimeUnit}; @@ -1187,4 +1192,113 @@ mod tests { assert_eq!(staging.parquet_files().len(), 2); assert_eq!(staging.arrow_files().len(), 1); } + + #[test] + fn get_or_create_returns_existing_stream() { + let streams = Streams::default(); + let options = Arc::new(Options::default()); + let stream_name = "test_stream"; + let metadata = LogStreamMetadata::default(); + let ingestor_id = Some("test_ingestor".to_owned()); + + // Create the stream first + let stream1 = streams.get_or_create( + options.clone(), + stream_name.to_owned(), + metadata.clone(), + ingestor_id.clone(), + ); + + // Call get_or_create again with the same stream_name + let stream2 = streams.get_or_create( + options.clone(), + stream_name.to_owned(), + metadata.clone(), + ingestor_id.clone(), + ); + + // Assert that both references point to the same stream + assert!(Arc::ptr_eq(&stream1, &stream2)); + + // Verify the map contains only one entry + let guard = streams.read().expect("Failed to acquire read lock"); + assert_eq!(guard.len(), 1); + } + + #[test] + fn create_and_return_new_stream_when_name_does_not_exist() { + let streams = Streams::default(); + let options = Arc::new(Options::default()); + let stream_name = "new_stream"; + let metadata = LogStreamMetadata::default(); + let ingestor_id = Some("new_ingestor".to_owned()); + + // Assert the stream doesn't exist already + let guard = streams.read().expect("Failed to acquire read lock"); + assert_eq!(guard.len(), 0); + assert!(!guard.contains_key(stream_name)); + drop(guard); + + // Call get_or_create with a new stream_name + let stream = streams.get_or_create( + options.clone(), + stream_name.to_owned(), + metadata.clone(), + ingestor_id.clone(), + ); + + // verify created stream has the same ingestor_id + assert_eq!(stream.ingestor_id, ingestor_id); + + // Assert that the stream is created + let guard = streams.read().expect("Failed to acquire read lock"); + assert_eq!(guard.len(), 1); + assert!(guard.contains_key(stream_name)); + } + + #[test] + fn get_or_create_stream_concurrently() { + let streams = Arc::new(Streams::default()); + let options = Arc::new(Options::default()); + let stream_name = String::from("concurrent_stream"); + let metadata = LogStreamMetadata::default(); + let ingestor_id = Some(String::from("concurrent_ingestor")); + + // Barrier to synchronize threads + let barrier = Arc::new(Barrier::new(2)); + + // Clones for the first thread + let streams1 = Arc::clone(&streams); + let options1 = Arc::clone(&options); + let barrier1 = Arc::clone(&barrier); + let stream_name1 = stream_name.clone(); + let metadata1 = metadata.clone(); + let ingestor_id1 = ingestor_id.clone(); + + // First thread + let handle1 = spawn(move || { + barrier1.wait(); + streams1.get_or_create(options1, stream_name1, metadata1, ingestor_id1) + }); + + // Cloned for the second thread + let streams2 = Arc::clone(&streams); + + // Second thread + let handle2 = spawn(move || { + barrier.wait(); + streams2.get_or_create(options, stream_name, metadata, ingestor_id) + }); + + // Wait for both threads to complete and get their results + let stream1 = handle1.join().expect("Thread 1 panicked"); + let stream2 = handle2.join().expect("Thread 2 panicked"); + + // Assert that both references point to the same stream + assert!(Arc::ptr_eq(&stream1, &stream2)); + + // Verify the map contains only one entry + let guard = streams.read().expect("Failed to acquire read lock"); + assert_eq!(guard.len(), 1); + } }