Skip to content

Commit 91e8ca9

Browse files
authored
Add a fake __eq__ method to relationship() to avoid false positives with --strict-equality (dropbox#92)
1 parent f6c1105 commit 91e8ca9

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

sqlalchemy-stubs/orm/relationships.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ class RelationshipProperty(StrategizedProperty, Generic[_T_co]):
8787
def __ne__(self, other): ...
8888
@property
8989
def property(self): ...
90+
# This doesn't exist at runtime, and Comparator is used instead, but it is hard to explain to mypy.
91+
def __eq__(self, other: Any) -> Any: ...
9092
def merge(self, session, source_state, source_dict, dest_state, dest_dict,
9193
load, _recursive, _resolve_conflict_map): ...
9294
def cascade_iterator(self, *args, **kwargs): ...

test/test-data/sqlalchemy-basics.test

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,26 @@ user = User()
120120
reveal_type(user.id) # N: Revealed type is 'builtins.int*'
121121
reveal_type(User.name) # N: Revealed type is 'sqlalchemy.sql.schema.Column[builtins.unicode*]'
122122
[out]
123+
124+
[case testRelationshipStrictEquality]
125+
# flags: --strict-equality
126+
from sqlalchemy.ext.declarative import declarative_base
127+
from sqlalchemy import Column, Integer, String
128+
from sqlalchemy.orm import relationship
129+
from sqlalchemy.orm import Session
130+
131+
Base = declarative_base()
132+
session = Session()
133+
134+
class User(Base):
135+
__tablename__ = 'users'
136+
id = Column(Integer(), primary_key=True)
137+
other = relationship('Other')
138+
139+
class Other(Base):
140+
__tablename__ = 'other'
141+
id = Column(Integer(), primary_key=True)
142+
143+
other: Other
144+
session.query(User).filter(User.other == other)
145+
[out]

test/testsql.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import os.path
55
import sys
6+
import re
67

78
import pytest # type: ignore # no pytest in typeshed
89

@@ -54,6 +55,12 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
5455
version = sys.version_info[:2]
5556
mypy_cmdline.append('--python-version={}'.format('.'.join(map(str, version))))
5657

58+
program_text = '\n'.join(testcase.input)
59+
flags = re.search('# flags: (.*)$', program_text, flags=re.MULTILINE)
60+
if flags:
61+
flag_list = flags.group(1).split()
62+
mypy_cmdline.extend(flag_list)
63+
5764
# Write the program to a file.
5865
program_path = os.path.join(test_temp_dir, 'main.py')
5966
mypy_cmdline.append(program_path)

0 commit comments

Comments
 (0)