@@ -93,9 +93,9 @@ def from_clause(self):
93
93
using = "" if not a else " USING (%s)" % "," .join ('`%s`' % _ for _ in a ))
94
94
return clause
95
95
96
- @property
97
96
def where_clause (self ):
98
- return '' if not self .restriction else ' WHERE(%s)' % ')AND(' .join (str (s ) for s in self .restriction )
97
+ return '' if not self .restriction else ' WHERE(%s)' % ')AND(' .join (
98
+ str (s ) for s in self .restriction )
99
99
100
100
def make_sql (self , fields = None ):
101
101
"""
@@ -106,7 +106,7 @@ def make_sql(self, fields=None):
106
106
return 'SELECT {distinct}{fields} FROM {from_}{where}' .format (
107
107
distinct = "DISTINCT " if distinct else "" ,
108
108
fields = self .heading .as_sql (fields or self .heading .names ),
109
- from_ = self .from_clause (), where = self .where_clause )
109
+ from_ = self .from_clause (), where = self .where_clause () )
110
110
111
111
# --------- query operators -----------
112
112
def make_subquery (self ):
@@ -180,7 +180,7 @@ def restrict(self, restriction):
180
180
result = self .make_subquery ()
181
181
else :
182
182
result = copy .copy (self )
183
- result ._restriction = AndList (self .restriction ) # make a copy to protect the original
183
+ result ._restriction = AndList (self .restriction ) # copy to preserve the original
184
184
result .restriction .append (new_condition )
185
185
result .restriction_attributes .update (attributes )
186
186
return result
@@ -421,25 +421,27 @@ def tail(self, limit=25, **fetch_kwargs):
421
421
422
422
def __len__ (self ):
423
423
""" :return: number of elements in the result set """
424
- what = '*' if set (self .heading .names ) != set (self .primary_key ) else 'DISTINCT %s' % ',' .join (
425
- (self .heading [k ].attribute_expression or '`%s`' % k for k in self .primary_key ))
426
424
return self .connection .query (
427
- 'SELECT count({what }) FROM {from_}{where}' .format (
428
- what = what ,
425
+ 'SELECT count(DISTINCT {fields }) FROM {from_}{where}' .format (
426
+ fields = self . heading . as_sql ( self . primary_key , include_aliases = False ) ,
429
427
from_ = self .from_clause (),
430
- where = self .where_clause )).fetchone ()[0 ]
428
+ where = self .where_clause () )).fetchone ()[0 ]
431
429
432
430
def __bool__ (self ):
433
431
"""
434
- :return: True if the result is not empty. Equivalent to len(rel)> 0 but may be more efficient .
432
+ :return: True if the result is not empty. Equivalent to len(self) > 0 but often faster .
435
433
"""
436
- return len (self ) > 0
434
+ return bool (self .connection .query (
435
+ 'SELECT EXISTS(SELECT 1 FROM {from_}{where})' .format (
436
+ from_ = self .from_clause (),
437
+ where = self .where_clause ())).fetchone ()[0 ])
437
438
438
439
def __contains__ (self , item ):
439
440
"""
440
441
returns True if item is found in the .
441
442
:param item: any restriction
442
- (item in query_expression) is equivalent to bool(query_expression & item) but may be executed more efficiently.
443
+ (item in query_expression) is equivalent to bool(query_expression & item) but may be
444
+ executed more efficiently.
443
445
"""
444
446
return bool (self & item ) # May be optimized e.g. using an EXISTS query
445
447
@@ -453,7 +455,8 @@ def __next__(self):
453
455
key = self ._iter_keys .pop (0 )
454
456
except AttributeError :
455
457
# self._iter_keys is missing because __iter__ has not been called.
456
- raise TypeError ("'QueryExpression' object is not an iterator. Use iter(obj) to create an iterator." )
458
+ raise TypeError ("A QueryExpression object is not an iterator. "
459
+ "Use iter(obj) to create an iterator." )
457
460
except IndexError :
458
461
raise StopIteration
459
462
else :
@@ -463,7 +466,8 @@ def __next__(self):
463
466
try :
464
467
return (self & key ).fetch1 ()
465
468
except DataJointError :
466
- # The data may have been deleted since the moment the keys were fetched -- move on to next entry.
469
+ # The data may have been deleted since the moment the keys were fetched
470
+ # -- move on to next entry.
467
471
return next (self )
468
472
469
473
def cursor (self , offset = 0 , limit = None , order_by = None , as_dict = False ):
@@ -495,8 +499,10 @@ def _repr_html_(self):
495
499
496
500
class Aggregation (QueryExpression ):
497
501
"""
498
- Aggregation(rel, comp1='expr1', ..., compn='exprn') yields an entity set with the primary key specified by rel.heading.
499
- The computed arguments comp1, ..., compn use aggregation operators on the attributes of rel.
502
+ Aggregation.create(arg, group, comp1='calc1', ..., compn='calcn') yields an entity set
503
+ with primary key from arg.
504
+ The computed arguments comp1, ..., compn use aggregation calculations on the attributes of
505
+ group or simple projections and calculations on the attributes of arg.
500
506
Aggregation is used QueryExpression.aggr and U.aggr.
501
507
Aggregation is a private class in DataJoint, not exposed to users.
502
508
"""
@@ -517,31 +523,34 @@ def create(cls, arg, group, keep_all_rows=False):
517
523
result ._support = join .support
518
524
result ._join_attributes = join ._join_attributes
519
525
result ._left = join ._left
520
- result .initial_restriction = join .restriction # WHERE clause applied before GROUP BY
526
+ result ._left_restrict = join .restriction # WHERE clause applied before GROUP BY
521
527
result ._grouping_attributes = result .primary_key
522
528
return result
523
529
530
+ def where_clause (self ):
531
+ return '' if not self ._left_restrict else ' WHERE (%s)' % ')AND(' .join (
532
+ str (s ) for s in self ._left_restrict )
533
+
524
534
def make_sql (self , fields = None ):
525
- where = '' if not self ._left_restrict else ' WHERE (%s)' % ')AND(' .join (self ._left_restrict )
526
535
fields = self .heading .as_sql (fields or self .heading .names )
527
536
assert self ._grouping_attributes or not self .restriction
528
537
distinct = set (self .heading .names ) == set (self .primary_key )
529
538
return 'SELECT {distinct}{fields} FROM {from_}{where}{group_by}' .format (
530
539
distinct = "DISTINCT " if distinct else "" ,
531
540
fields = fields ,
532
541
from_ = self .from_clause (),
533
- where = where ,
542
+ where = self . where_clause () ,
534
543
group_by = "" if not self .primary_key else (
535
544
" GROUP BY `%s`" % '`,`' .join (self ._grouping_attributes ) +
536
545
("" if not self .restriction else ' HAVING (%s)' % ')AND(' .join (self .restriction ))))
537
546
538
547
def __len__ (self ):
539
- what = '*' if set (self .heading .names ) != set (self .primary_key ) else 'DISTINCT `%s`' % '`,`' .join (self .primary_key )
540
548
return self .connection .query (
541
- 'SELECT count({what}) FROM ({subquery}) as `_r{alias:x}`' .format (
542
- what = what ,
543
- subquery = self .make_sql (),
544
- alias = next (self .__subquery_alias_count ))).fetchone ()[0 ]
549
+ 'SELECT count(1) FROM ({sql}) `$sub`' .format (sql = self .make_sql ())).fetchone ()[0 ]
550
+
551
+ def __bool__ (self ):
552
+ return bool (self .connection .query (
553
+ 'SELECT EXISTS({sql})' .format (sql = self .make_sql ())))
545
554
546
555
547
556
class Union (QueryExpression ):
@@ -553,36 +562,52 @@ def create(cls, arg1, arg2):
553
562
if inspect .isclass (arg2 ) and issubclass (arg2 , QueryExpression ):
554
563
arg2 = arg2 () # instantiate if a class
555
564
if not isinstance (arg2 , QueryExpression ):
556
- raise DataJointError ('A QueryExpression can only be unioned with another QueryExpression' )
565
+ raise DataJointError (
566
+ "A QueryExpression can only be unioned with another QueryExpression" )
557
567
if arg1 .connection != arg2 .connection :
558
- raise DataJointError ("Cannot operate on QueryExpressions originating from different connections." )
568
+ raise DataJointError (
569
+ "Cannot operate on QueryExpressions originating from different connections." )
559
570
if set (arg1 .primary_key ) != set (arg2 .primary_key ):
560
571
raise DataJointError ("The operands of a union must share the same primary key." )
561
572
if set (arg1 .heading .secondary_attributes ) & set (arg2 .heading .secondary_attributes ):
562
- raise DataJointError ("The operands of a union must not share any secondary attributes." )
573
+ raise DataJointError (
574
+ "The operands of a union must not share any secondary attributes." )
563
575
result = cls ()
564
576
result ._connection = arg1 .connection
565
577
result ._heading = arg1 .heading .join (arg2 .heading )
566
578
result ._support = [arg1 , arg2 ]
567
579
return result
568
580
569
- def make_sql (self , select_fields = None ):
581
+ def make_sql (self ):
570
582
arg1 , arg2 = self ._support
571
- if not arg1 .heading .secondary_attributes and not arg2 .heading .secondary_attributes : # use UNION DISTINCT
572
- fields = select_fields or arg1 .primary_key
573
- return "({sql1}) UNION ({sql2})" .format (sql1 = arg1 .make_sql (fields ), sql2 = arg2 .make_sql (fields ))
574
- fields = select_fields or self .heading .names
583
+ if not arg1 .heading .secondary_attributes and not arg2 .heading .secondary_attributes :
584
+ # no secondary attributes: use UNION DISTINCT
585
+ fields = arg1 .primary_key
586
+ return "({sql1}) UNION ({sql2})" .format (
587
+ sql1 = arg1 .make_sql (fields ),
588
+ sql2 = arg2 .make_sql (fields ))
589
+ # with secondary attributes, use union of left join with antijoin
590
+ fields = self .heading .names
575
591
sql1 = arg1 .join (arg2 , left = True ).make_sql (fields )
576
- sql2 = (arg2 - arg1 ).proj (..., ** {k : 'NULL' for k in arg1 .heading .secondary_attributes }).make_sql (fields )
592
+ sql2 = (arg2 - arg1 ).proj (
593
+ ..., ** {k : 'NULL' for k in arg1 .heading .secondary_attributes }).make_sql (fields )
577
594
return "({sql1}) UNION ({sql2})" .format (sql1 = sql1 , sql2 = sql2 )
578
595
579
596
def from_clause (self ):
580
- """In Union, the select clause can be used as the WHERE clause and make_sql() does not call from_clause"""
581
- return self .make_sql ()
597
+ """ The union does not use a FROM clause """
598
+ assert False
599
+
600
+ def where_clause (self ):
601
+ """ The union does not use a WHERE clause """
602
+ assert False
582
603
583
604
def __len__ (self ):
584
- return self .connection .query (
585
- 'SELECT count(*) FROM ({sql}) `$sub`' .format (sql = self .make_sql ())).fetchone ()[0 ]
605
+ return self .connection .query ('SELECT count(1) FROM ({sql}) as `$sub`' .format (
606
+ sql = self .make_sql ())).fetchone ()[0 ]
607
+
608
+ def __bool__ (self ):
609
+ return bool (self .connection .query (
610
+ 'SELECT EXISTS({sql})' .format (sql = self .make_sql ())))
586
611
587
612
588
613
class U :
@@ -688,7 +713,8 @@ def aggr(self, group, **named_attributes):
688
713
:return: The derived query expression
689
714
"""
690
715
if named_attributes .get ('keep_all_rows' , False ):
691
- raise DataJointError ('Cannot set keep_all_rows=True when aggregating on a universal set.' )
716
+ raise DataJointError (
717
+ 'Cannot set keep_all_rows=True when aggregating on a universal set.' )
692
718
return Aggregation .create (self , group = group , keep_all_rows = False ).proj (** named_attributes )
693
719
694
720
aggregate = aggr # alias for aggr
0 commit comments