Skip to content

Commit 09a99f2

Browse files
shapovalov authored and facebook-github-bot committed
Support limiting num sequences per category.
Summary: Adds stratified sampling of sequences within categories, applied after the category / sequence filters but before the overall sequence-count limit (`limit_sequences_to`). It respects the insertion order into the sequence_annots table, i.e. it takes the top N sequences within each category.
Reviewed By: bottler
Differential Revision: D46724002
fbshipit-source-id: 597cb2a795c3f3bc07f838fc51b4e95a4f981ad3
1 parent 5ffeb4d commit 09a99f2

File tree

2 files changed

+57
-6
lines changed

2 files changed

+57
-6
lines changed

pytorch3d/implicitron/dataset/sql_dataset.py

+33-6
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
8989
pick_categories: Restrict the dataset to the given list of categories.
9090
pick_sequences: A Sequence of sequence names to restrict the dataset to.
9191
exclude_sequences: A Sequence of the names of the sequences to exclude.
92+
limit_sequences_per_category_to: Limit the dataset to at most the first N
93+
sequences within each category (applies after all other sequence filters
94+
but before `limit_sequences_to`).
9295
limit_sequences_to: Limit the dataset to the first `limit_sequences_to`
9396
sequences (after other sequence filters have been applied but before
9497
frame-based filters).
@@ -115,6 +118,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
115118

116119
pick_sequences: Tuple[str, ...] = ()
117120
exclude_sequences: Tuple[str, ...] = ()
121+
limit_sequences_per_category_to: int = 0
118122
limit_sequences_to: int = 0
119123
limit_to: int = 0
120124
n_frames_per_sequence: int = -1
@@ -373,27 +377,46 @@ def is_filtered(self) -> bool:
373377
self.remove_empty_masks
374378
or self.limit_to > 0
375379
or self.limit_sequences_to > 0
380+
or self.limit_sequences_per_category_to > 0
376381
or len(self.pick_sequences) > 0
377382
or len(self.exclude_sequences) > 0
378383
or len(self.pick_categories) > 0
379384
or self.n_frames_per_sequence > 0
380385
)
381386

382387
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
383-
# maximum possible query: WHERE category IN 'self.pick_categories'
388+
# maximum possible filter (if limit_sequences_per_category_to == 0):
389+
# WHERE category IN 'self.pick_categories'
384390
# AND sequence_name IN 'self.pick_sequences'
385391
# AND sequence_name NOT IN 'self.exclude_sequences'
386392
# LIMIT 'self.limit_sequences_to'
387393

388-
stmt = sa.select(SqlSequenceAnnotation.sequence_name)
389-
390394
where_conditions = [
391395
*self._get_category_filters(),
392396
*self._get_pick_filters(),
393397
*self._get_exclude_filters(),
394398
]
395-
if where_conditions:
396-
stmt = stmt.where(*where_conditions)
399+
400+
def add_where(stmt):
401+
return stmt.where(*where_conditions) if where_conditions else stmt
402+
403+
if self.limit_sequences_per_category_to <= 0:
404+
stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
405+
else:
406+
subquery = sa.select(
407+
SqlSequenceAnnotation.sequence_name,
408+
sa.func.row_number()
409+
.over(
410+
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
411+
partition_by=SqlSequenceAnnotation.category,
412+
)
413+
.label("row_number"),
414+
)
415+
416+
subquery = add_where(subquery).subquery()
417+
stmt = sa.select(subquery.c.sequence_name).where(
418+
subquery.c.row_number <= self.limit_sequences_per_category_to
419+
)
397420

398421
if self.limit_sequences_to > 0:
399422
logger.info(
@@ -402,7 +425,11 @@ def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
402425
# NOTE: ROWID is SQLite-specific
403426
stmt = stmt.order_by(sa.text("ROWID")).limit(self.limit_sequences_to)
404427

405-
if not where_conditions and self.limit_sequences_to <= 0:
428+
if (
429+
not where_conditions
430+
and self.limit_sequences_to <= 0
431+
and self.limit_sequences_per_category_to <= 0
432+
):
406433
# we will not need to filter by sequences
407434
return None
408435

tests/implicitron/test_sql_dataset.py

+24
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,30 @@ def test_limit_frames_per_sequence(self, num_frames=2):
222222
)
223223
self.assertEqual(len(dataset), 100)
224224

225+
def test_limit_sequence_per_category(self, num_sequences=2):
226+
dataset = SqlIndexDataset(
227+
sqlite_metadata_file=METADATA_FILE,
228+
remove_empty_masks=False,
229+
limit_sequences_per_category_to=num_sequences,
230+
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
231+
)
232+
233+
self.assertEqual(len(dataset), num_sequences * 10 * 2)
234+
seq_names = list(dataset.sequence_names())
235+
self.assertEqual(len(seq_names), num_sequences * 2)
236+
# check that we respect the row order
237+
for seq_name in seq_names:
238+
self.assertLess(int(seq_name[-1]), num_sequences)
239+
240+
# test when the limit is not binding
241+
dataset = SqlIndexDataset(
242+
sqlite_metadata_file=METADATA_FILE,
243+
remove_empty_masks=False,
244+
limit_sequences_per_category_to=13,
245+
frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
246+
)
247+
self.assertEqual(len(dataset), 100)
248+
225249
def test_filter_medley(self):
226250
dataset = SqlIndexDataset(
227251
sqlite_metadata_file=METADATA_FILE,

0 commit comments

Comments
 (0)