@@ -89,6 +89,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
89
89
pick_categories: Restrict the dataset to the given list of categories.
90
90
pick_sequences: A Sequence of sequence names to restrict the dataset to.
91
91
exclude_sequences: A Sequence of the names of the sequences to exclude.
92
+ limit_sequences_per_category_to: Limit the dataset to the first up to N
93
+ sequences within each category (applies after all other sequence filters
94
+ but before `limit_sequences_to`).
92
95
limit_sequences_to: Limit the dataset to the first `limit_sequences_to`
93
96
sequences (after other sequence filters have been applied but before
94
97
frame-based filters).
@@ -115,6 +118,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
115
118
116
119
pick_sequences : Tuple [str , ...] = ()
117
120
exclude_sequences : Tuple [str , ...] = ()
121
+ limit_sequences_per_category_to : int = 0
118
122
limit_sequences_to : int = 0
119
123
limit_to : int = 0
120
124
n_frames_per_sequence : int = - 1
@@ -373,27 +377,46 @@ def is_filtered(self) -> bool:
373
377
self .remove_empty_masks
374
378
or self .limit_to > 0
375
379
or self .limit_sequences_to > 0
380
+ or self .limit_sequences_per_category_to > 0
376
381
or len (self .pick_sequences ) > 0
377
382
or len (self .exclude_sequences ) > 0
378
383
or len (self .pick_categories ) > 0
379
384
or self .n_frames_per_sequence > 0
380
385
)
381
386
382
387
def _get_filtered_sequences_if_any (self ) -> Optional [pd .Series ]:
383
- # maximum possible query: WHERE category IN 'self.pick_categories'
388
+ # maximum possible filter (if limit_sequences_per_category_to == 0):
389
+ # WHERE category IN 'self.pick_categories'
384
390
# AND sequence_name IN 'self.pick_sequences'
385
391
# AND sequence_name NOT IN 'self.exclude_sequences'
386
392
# LIMIT 'self.limit_sequence_to'
387
393
388
- stmt = sa .select (SqlSequenceAnnotation .sequence_name )
389
-
390
394
where_conditions = [
391
395
* self ._get_category_filters (),
392
396
* self ._get_pick_filters (),
393
397
* self ._get_exclude_filters (),
394
398
]
395
- if where_conditions :
396
- stmt = stmt .where (* where_conditions )
399
+
400
+ def add_where (stmt ):
401
+ return stmt .where (* where_conditions ) if where_conditions else stmt
402
+
403
+ if self .limit_sequences_per_category_to <= 0 :
404
+ stmt = add_where (sa .select (SqlSequenceAnnotation .sequence_name ))
405
+ else :
406
+ subquery = sa .select (
407
+ SqlSequenceAnnotation .sequence_name ,
408
+ sa .func .row_number ()
409
+ .over (
410
+ order_by = sa .text ("ROWID" ), # NOTE: ROWID is SQLite-specific
411
+ partition_by = SqlSequenceAnnotation .category ,
412
+ )
413
+ .label ("row_number" ),
414
+ )
415
+
416
+ subquery = add_where (subquery ).subquery ()
417
+ stmt = sa .select (subquery .c .sequence_name ).where (
418
+ subquery .c .row_number <= self .limit_sequences_per_category_to
419
+ )
397
420
398
421
if self .limit_sequences_to > 0 :
399
422
logger .info (
@@ -402,7 +425,11 @@ def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
402
425
# NOTE: ROWID is SQLite-specific
403
426
stmt = stmt .order_by (sa .text ("ROWID" )).limit (self .limit_sequences_to )
404
427
405
- if not where_conditions and self .limit_sequences_to <= 0 :
428
+ if (
429
+ not where_conditions
430
+ and self .limit_sequences_to <= 0
431
+ and self .limit_sequences_per_category_to <= 0
432
+ ):
406
433
# we will not need to filter by sequences
407
434
return None
408
435
0 commit comments