@@ -166,20 +166,36 @@ def close_partition(self, partition: Partition) -> None:
166166
def _check_and_update_parent_state(self) -> None:
    """
    Advance ``_parent_state`` past every leading parent-state entry that is safe to commit.

    Entries of ``_partition_parent_state_map`` are examined in insertion order. An entry
    is popped only when every partition in ``_semaphore_per_partition`` up to and
    including the entry's key is fully finished, i.e. the key is in
    ``_finished_partitions`` and its semaphore's ``_value`` is 0. An entry whose key has
    no semaphore is popped only when every tracked partition is finished.

    ``_parent_state`` is set to the state of the last popped entry and is left untouched
    when nothing is popped (or when the last popped state is ``None``).

    NOTE(review): assumes the containers are not mutated by another thread while this
    runs (presumably the caller holds ``self._lock``) -- confirm against the callers.
    """
    # Scan the semaphores once and collect the leading run of finished partitions.
    # The outcome does not depend on which parent-state entry is being examined, so it
    # is computed a single time rather than rescanned from the left for every entry.
    finished_prefix = set()
    all_finished = True
    for p_key, sem in self._semaphore_per_partition.items():
        # ``_value`` is a private attribute of the semaphore; 0 is treated as
        # "no pending slices" together with membership in ``_finished_partitions``.
        if p_key not in self._finished_partitions or sem._value != 0:
            all_finished = False
            break
        finished_prefix.add(p_key)

    last_closed_state = None
    while self._partition_parent_state_map:
        # Look at the earliest entry in creation order.
        earliest_key = next(iter(self._partition_parent_state_map))

        # Stop as soon as an unfinished partition sits at or before this entry.
        if not all_finished and earliest_key not in finished_prefix:
            break

        _, last_closed_state = self._partition_parent_state_map.popitem(last=False)

    # Update _parent_state only if at least one entry was actually popped.
    if last_closed_state is not None:
        self._parent_state = last_closed_state
185201
@@ -228,11 +244,13 @@ def stream_slices(self) -> Iterable[StreamSlice]:
228244 slices = self ._partition_router .stream_slices ()
229245 self ._timer .start ()
230246 for partition , last , parent_state in iterate_with_last_flag_and_state (
231- slices , self ._partition_router .get_stream_state
247+ slices , self ._partition_router .get_stream_state
232248 ):
233249 yield from self ._generate_slices_from_partition (partition , parent_state )
234250
235- def _generate_slices_from_partition (self , partition : StreamSlice , parent_state : Mapping [str , Any ]) -> Iterable [StreamSlice ]:
251+ def _generate_slices_from_partition (
252+ self , partition : StreamSlice , parent_state : Mapping [str , Any ]
253+ ) -> Iterable [StreamSlice ]:
236254 # Ensure the maximum number of partitions is not exceeded
237255 self ._ensure_partition_limit ()
238256
@@ -247,12 +265,17 @@ def _generate_slices_from_partition(self, partition: StreamSlice, parent_state:
247265 with self ._lock :
248266 self ._number_of_partitions += 1
249267 self ._cursor_per_partition [partition_key ] = cursor
250- self ._semaphore_per_partition [partition_key ] = (
251- threading .Semaphore (0 )
252- )
268+ self ._semaphore_per_partition [partition_key ] = threading .Semaphore (0 )
253269
254270 with self ._lock :
255- self ._partition_parent_state_map [partition_key ] = deepcopy (parent_state )
271+ if (
272+ len (self ._partition_parent_state_map ) == 0
273+ or self ._partition_parent_state_map [
274+ next (reversed (self ._partition_parent_state_map ))
275+ ]
276+ != parent_state
277+ ):
278+ self ._partition_parent_state_map [partition_key ] = deepcopy (parent_state )
256279
257280 for cursor_slice , is_last_slice , _ in iterate_with_last_flag_and_state (
258281 cursor .stream_slices (),
@@ -287,7 +310,6 @@ def _ensure_partition_limit(self) -> None:
287310 self ._use_global_cursor = True
288311
289312 with self ._lock :
290- self ._number_of_partitions += 1
291313 while len (self ._cursor_per_partition ) > self .DEFAULT_MAX_PARTITIONS_NUMBER - 1 :
292314 # Try removing finished partitions first
293315 for partition_key in list (self ._cursor_per_partition .keys ()):
@@ -372,9 +394,6 @@ def _set_initial_state(self, stream_state: StreamState) -> None:
372394 self ._cursor_per_partition [self ._to_partition_key (state ["partition" ])] = (
373395 self ._create_cursor (state ["cursor" ])
374396 )
375- self ._semaphore_per_partition [self ._to_partition_key (state ["partition" ])] = (
376- threading .Semaphore (0 )
377- )
378397
379398 # set default state for missing partitions if it is per partition with fallback to global
380399 if self ._GLOBAL_STATE_KEY in stream_state :
0 commit comments