Skip to content

Commit 403c364

Browse files
fix relational restrictions applied before group by
1 parent 2523cc6 commit 403c364

File tree

4 files changed

+81
-45
lines changed

4 files changed

+81
-45
lines changed

datajoint/expression.py

Lines changed: 64 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ def from_clause(self):
9393
using="" if not a else " USING (%s)" % ",".join('`%s`' % _ for _ in a))
9494
return clause
9595

96-
@property
9796
def where_clause(self):
98-
return '' if not self.restriction else ' WHERE(%s)' % ')AND('.join(str(s) for s in self.restriction)
97+
return '' if not self.restriction else ' WHERE(%s)' % ')AND('.join(
98+
str(s) for s in self.restriction)
9999

100100
def make_sql(self, fields=None):
101101
"""
@@ -106,7 +106,7 @@ def make_sql(self, fields=None):
106106
return 'SELECT {distinct}{fields} FROM {from_}{where}'.format(
107107
distinct="DISTINCT " if distinct else "",
108108
fields=self.heading.as_sql(fields or self.heading.names),
109-
from_=self.from_clause(), where=self.where_clause)
109+
from_=self.from_clause(), where=self.where_clause())
110110

111111
# --------- query operators -----------
112112
def make_subquery(self):
@@ -180,7 +180,7 @@ def restrict(self, restriction):
180180
result = self.make_subquery()
181181
else:
182182
result = copy.copy(self)
183-
result._restriction = AndList(self.restriction) # make a copy to protect the original
183+
result._restriction = AndList(self.restriction) # copy to preserve the original
184184
result.restriction.append(new_condition)
185185
result.restriction_attributes.update(attributes)
186186
return result
@@ -421,25 +421,27 @@ def tail(self, limit=25, **fetch_kwargs):
421421

422422
def __len__(self):
423423
""" :return: number of elements in the result set """
424-
what = '*' if set(self.heading.names) != set(self.primary_key) else 'DISTINCT %s' % ','.join(
425-
(self.heading[k].attribute_expression or '`%s`' % k for k in self.primary_key))
426424
return self.connection.query(
427-
'SELECT count({what}) FROM {from_}{where}'.format(
428-
what=what,
425+
'SELECT count(DISTINCT {fields}) FROM {from_}{where}'.format(
426+
fields=self.heading.as_sql(self.primary_key, include_aliases=False),
429427
from_=self.from_clause(),
430-
where=self.where_clause)).fetchone()[0]
428+
where=self.where_clause())).fetchone()[0]
431429

432430
def __bool__(self):
433431
"""
434-
:return: True if the result is not empty. Equivalent to len(rel)>0 but may be more efficient.
432+
:return: True if the result is not empty. Equivalent to len(self) > 0 but often faster.
435433
"""
436-
return len(self) > 0
434+
return bool(self.connection.query(
435+
'SELECT EXISTS(SELECT 1 FROM {from_}{where})'.format(
436+
from_=self.from_clause(),
437+
where=self.where_clause())).fetchone()[0])
437438

438439
def __contains__(self, item):
439440
"""
440441
returns True if item is found in the query expression.
441442
:param item: any restriction
442-
(item in query_expression) is equivalent to bool(query_expression & item) but may be executed more efficiently.
443+
(item in query_expression) is equivalent to bool(query_expression & item) but may be
444+
executed more efficiently.
443445
"""
444446
return bool(self & item) # May be optimized e.g. using an EXISTS query
445447

@@ -453,7 +455,8 @@ def __next__(self):
453455
key = self._iter_keys.pop(0)
454456
except AttributeError:
455457
# self._iter_keys is missing because __iter__ has not been called.
456-
raise TypeError("'QueryExpression' object is not an iterator. Use iter(obj) to create an iterator.")
458+
raise TypeError("A QueryExpression object is not an iterator. "
459+
"Use iter(obj) to create an iterator.")
457460
except IndexError:
458461
raise StopIteration
459462
else:
@@ -463,7 +466,8 @@ def __next__(self):
463466
try:
464467
return (self & key).fetch1()
465468
except DataJointError:
466-
# The data may have been deleted since the moment the keys were fetched -- move on to next entry.
469+
# The data may have been deleted since the moment the keys were fetched
470+
# -- move on to next entry.
467471
return next(self)
468472

469473
def cursor(self, offset=0, limit=None, order_by=None, as_dict=False):
@@ -495,8 +499,10 @@ def _repr_html_(self):
495499

496500
class Aggregation(QueryExpression):
497501
"""
498-
Aggregation(rel, comp1='expr1', ..., compn='exprn') yields an entity set with the primary key specified by rel.heading.
499-
The computed arguments comp1, ..., compn use aggregation operators on the attributes of rel.
502+
Aggregation.create(arg, group, comp1='calc1', ..., compn='calcn') yields an entity set
503+
with primary key from arg.
504+
The computed arguments comp1, ..., compn use aggregation calculations on the attributes of
505+
group or simple projections and calculations on the attributes of arg.
500506
Aggregation is used by QueryExpression.aggr and U.aggr.
501507
Aggregation is a private class in DataJoint, not exposed to users.
502508
"""
@@ -517,31 +523,34 @@ def create(cls, arg, group, keep_all_rows=False):
517523
result._support = join.support
518524
result._join_attributes = join._join_attributes
519525
result._left = join._left
520-
result.initial_restriction = join.restriction # WHERE clause applied before GROUP BY
526+
result._left_restrict = join.restriction # WHERE clause applied before GROUP BY
521527
result._grouping_attributes = result.primary_key
522528
return result
523529

530+
def where_clause(self):
    """
    :return: the SQL WHERE clause built from self._left_restrict (the restriction
        applied before GROUP BY), or an empty string when there are no conditions.
    """
    conditions = self._left_restrict
    if not conditions:
        return ''
    # each condition is converted with str() and the conditions are AND-ed together
    return ' WHERE (%s)' % ')AND('.join(map(str, conditions))
533+
524534
def make_sql(self, fields=None):
525-
where = '' if not self._left_restrict else ' WHERE (%s)' % ')AND('.join(self._left_restrict)
526535
fields = self.heading.as_sql(fields or self.heading.names)
527536
assert self._grouping_attributes or not self.restriction
528537
distinct = set(self.heading.names) == set(self.primary_key)
529538
return 'SELECT {distinct}{fields} FROM {from_}{where}{group_by}'.format(
530539
distinct="DISTINCT " if distinct else "",
531540
fields=fields,
532541
from_=self.from_clause(),
533-
where=where,
542+
where=self.where_clause(),
534543
group_by="" if not self.primary_key else (
535544
" GROUP BY `%s`" % '`,`'.join(self._grouping_attributes) +
536545
("" if not self.restriction else ' HAVING (%s)' % ')AND('.join(self.restriction))))
537546

538547
def __len__(self):
539-
what = '*' if set(self.heading.names) != set(self.primary_key) else 'DISTINCT `%s`' % '`,`'.join(self.primary_key)
540548
return self.connection.query(
541-
'SELECT count({what}) FROM ({subquery}) as `_r{alias:x}`'.format(
542-
what=what,
543-
subquery=self.make_sql(),
544-
alias=next(self.__subquery_alias_count))).fetchone()[0]
549+
'SELECT count(1) FROM ({sql}) `$sub`'.format(sql=self.make_sql())).fetchone()[0]
550+
551+
def __bool__(self):
    """
    :return: True if the aggregation yields at least one row.

    Wraps the SQL produced by self.make_sql() in SELECT EXISTS(...) and converts the
    single returned value to bool.
    """
    # The object returned by connection.query() is not the answer itself: the EXISTS
    # flag is the first field of its only row, so fetch it before converting to bool
    # (same pattern as the count queries in __len__).
    return bool(self.connection.query(
        'SELECT EXISTS({sql})'.format(sql=self.make_sql())).fetchone()[0])
545554

546555

547556
class Union(QueryExpression):
@@ -553,36 +562,52 @@ def create(cls, arg1, arg2):
553562
if inspect.isclass(arg2) and issubclass(arg2, QueryExpression):
554563
arg2 = arg2() # instantiate if a class
555564
if not isinstance(arg2, QueryExpression):
556-
raise DataJointError('A QueryExpression can only be unioned with another QueryExpression')
565+
raise DataJointError(
566+
"A QueryExpression can only be unioned with another QueryExpression")
557567
if arg1.connection != arg2.connection:
558-
raise DataJointError("Cannot operate on QueryExpressions originating from different connections.")
568+
raise DataJointError(
569+
"Cannot operate on QueryExpressions originating from different connections.")
559570
if set(arg1.primary_key) != set(arg2.primary_key):
560571
raise DataJointError("The operands of a union must share the same primary key.")
561572
if set(arg1.heading.secondary_attributes) & set(arg2.heading.secondary_attributes):
562-
raise DataJointError("The operands of a union must not share any secondary attributes.")
573+
raise DataJointError(
574+
"The operands of a union must not share any secondary attributes.")
563575
result = cls()
564576
result._connection = arg1.connection
565577
result._heading = arg1.heading.join(arg2.heading)
566578
result._support = [arg1, arg2]
567579
return result
568580

569-
def make_sql(self, select_fields=None):
581+
def make_sql(self):
570582
arg1, arg2 = self._support
571-
if not arg1.heading.secondary_attributes and not arg2.heading.secondary_attributes: # use UNION DISTINCT
572-
fields = select_fields or arg1.primary_key
573-
return "({sql1}) UNION ({sql2})".format(sql1=arg1.make_sql(fields), sql2=arg2.make_sql(fields))
574-
fields = select_fields or self.heading.names
583+
if not arg1.heading.secondary_attributes and not arg2.heading.secondary_attributes:
584+
# no secondary attributes: use UNION DISTINCT
585+
fields = arg1.primary_key
586+
return "({sql1}) UNION ({sql2})".format(
587+
sql1=arg1.make_sql(fields),
588+
sql2=arg2.make_sql(fields))
589+
# with secondary attributes, use union of left join with antijoin
590+
fields = self.heading.names
575591
sql1 = arg1.join(arg2, left=True).make_sql(fields)
576-
sql2 = (arg2 - arg1).proj(..., **{k: 'NULL' for k in arg1.heading.secondary_attributes}).make_sql(fields)
592+
sql2 = (arg2 - arg1).proj(
593+
..., **{k: 'NULL' for k in arg1.heading.secondary_attributes}).make_sql(fields)
577594
return "({sql1}) UNION ({sql2})".format(sql1=sql1, sql2=sql2)
578595

579596
def from_clause(self):
580-
"""In Union, the select clause can be used as the WHERE clause and make_sql() does not call from_clause"""
581-
return self.make_sql()
597+
""" The union does not use a FROM clause """
598+
assert False
599+
600+
def where_clause(self):
    """ The union does not use a WHERE clause """
    # Raise explicitly instead of `assert False`: assert statements are removed when
    # Python runs with -O, in which case this method would silently return None.
    # AssertionError is kept so the exception type seen by callers does not change.
    raise AssertionError('The union does not use a WHERE clause')
582603

583604
def __len__(self):
584-
return self.connection.query(
585-
'SELECT count(*) FROM ({sql}) `$sub`'.format(sql=self.make_sql())).fetchone()[0]
605+
return self.connection.query('SELECT count(1) FROM ({sql}) as `$sub`'.format(
606+
sql=self.make_sql())).fetchone()[0]
607+
608+
def __bool__(self):
    """
    :return: True if the union yields at least one row.

    Wraps the SQL produced by self.make_sql() in SELECT EXISTS(...) and converts the
    single returned value to bool.
    """
    # The object returned by connection.query() is not the answer itself: the EXISTS
    # flag is the first field of its only row, so fetch it before converting to bool
    # (same pattern as the count query in __len__).
    return bool(self.connection.query(
        'SELECT EXISTS({sql})'.format(sql=self.make_sql())).fetchone()[0])
586611

587612

588613
class U:
@@ -688,7 +713,8 @@ def aggr(self, group, **named_attributes):
688713
:return: The derived query expression
689714
"""
690715
if named_attributes.get('keep_all_rows', False):
691-
raise DataJointError('Cannot set keep_all_rows=True when aggregating on a universal set.')
716+
raise DataJointError(
717+
'Cannot set keep_all_rows=True when aggregating on a universal set.')
692718
return Aggregation.create(self, group=group, keep_all_rows=False).proj(**named_attributes)
693719

694720
aggregate = aggr # alias for aggr

datajoint/heading.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,14 @@ def as_dtype(self):
148148
names=self.names,
149149
formats=[v.dtype for v in self.attributes.values()]))
150150

151-
def as_sql(self, fields):
151+
def as_sql(self, fields, include_aliases=True):
152152
"""
153153
represent heading as the SQL SELECT clause.
154154
"""
155-
return ','.join('`%s`' % name if self.attributes[name].attribute_expression is None
156-
else '%s as `%s`' % (self.attributes[name].attribute_expression, name)
157-
for name in fields)
155+
return ','.join(
156+
'`%s`' % name if self.attributes[name].attribute_expression is None
157+
else self.attributes[name].attribute_expression + (' as `%s`' % name if include_aliases else '')
158+
for name in fields)
158159

159160
def __iter__(self):
160161
return iter(self.attributes)

datajoint/table.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def delete_quick(self, get_count=False):
320320
Deletes the table without cascading and without user prompt.
321321
If this table has populated dependent tables, this will fail.
322322
"""
323-
query = 'DELETE FROM ' + self.full_table_name + self.where_clause
323+
query = 'DELETE FROM ' + self.full_table_name + self.where_clause()
324324
self.connection.query(query)
325325
count = self.connection.query("SELECT ROW_COUNT()").fetchone()[0] if get_count else None
326326
self._log(query[:255])
@@ -564,7 +564,7 @@ def _update(self, attrname, value=None):
564564
full_table_name=self.from_clause(),
565565
attrname=attrname,
566566
placeholder=placeholder,
567-
where_clause=self.where_clause)
567+
where_clause=self.where_clause())
568568
self.connection.query(command, args=(value, ) if value is not None else ())
569569

570570
# --- private helper functions ----

tests/test_university.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,16 @@ def test_aggr():
8383
assert_true(len(avg_grade_per_course) == 45)
8484

8585
# GPA
86-
student_gpa = Student.aggr(Course * Grade * LetterGrade, gpa='round(sum(points*credits)/sum(credits), 2)')
86+
student_gpa = Student.aggr(
87+
Course * Grade * LetterGrade,
88+
gpa='round(sum(points*credits)/sum(credits), 2)')
8789
gpa = student_gpa.fetch('gpa')
8890
assert_true(len(gpa) == 261)
8991
assert_true(2 < gpa.mean() < 3)
92+
93+
# Sections in biology department with zero students in them
94+
section = (Section & {"dept": "BIOL"}).aggr(
95+
Enroll, n='count(student_id)', keep_all_rows=True) & 'n=0'
96+
assert_true(len(set(section.fetch('dept'))) == 1)
97+
assert_true(len(section) == 17)
98+
assert_true(bool(section))

0 commit comments

Comments
 (0)