Skip to content

Commit 25c3ef3

Browse files
author
Mehdi
committed
mypy fixes
1 parent ee7baf3 commit 25c3ef3

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

sqlmypy.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,17 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext
6666
return model_hook
6767
return None
6868

69-
def get_dynamic_class_hook(self, fullname: str):
69+
def get_dynamic_class_hook(self, fullname: str) -> CB[DynamicClassDefContext]:
7070
if fullname == 'sqlalchemy.ext.declarative.api.declarative_base':
7171
return decl_info_hook
7272
return None
7373

74-
def get_class_decorator_hook(self, fullname: str):
74+
def get_class_decorator_hook(self, fullname: str) -> CB[ClassDefContext]:
7575
if fullname == 'sqlalchemy.ext.declarative.api.as_declarative':
7676
return decl_deco_hook
7777
return None
7878

79-
def get_base_class_hook(self, fullname: str):
79+
def get_base_class_hook(self, fullname: str) -> CB[ClassDefContext]:
8080
sym = self.lookup_fully_qualified(fullname)
8181
if sym and isinstance(sym.node, TypeInfo):
8282
if is_declarative(sym.node):
@@ -119,7 +119,8 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
119119
if stmt.lvalues[0].name == "__tablename__" and isinstance(stmt.rvalue, StrExpr):
120120
ctx.cls.info.metadata.setdefault('sqlalchemy', {})['table_name'] = stmt.rvalue.value
121121

122-
if isinstance(stmt.rvalue, CallExpr) and stmt.rvalue.callee.fullname == COLUMN_NAME:
122+
if (isinstance(stmt.rvalue, CallExpr) and isinstance(stmt.rvalue.callee, NameExpr)
123+
and stmt.rvalue.callee.fullname == COLUMN_NAME):
123124
# Save columns. The name of a column on the db side can be different from the one inside the SA model.
124125
sa_column_name = stmt.lvalues[0].name
125126

@@ -138,7 +139,8 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
138139

139140
# Save foreign keys.
140141
for arg in stmt.rvalue.args:
141-
if isinstance(arg, CallExpr) and arg.callee.fullname == FOREIGN_KEY_NAME and len(arg.args) >= 1:
142+
if (isinstance(arg, CallExpr) and isinstance(arg.callee, NameExpr)
143+
and arg.callee.fullname == FOREIGN_KEY_NAME and len(arg.args) >= 1):
142144
fk = arg.args[0]
143145
if isinstance(fk, StrExpr):
144146
*r, parent_table_name, parent_db_col_name = fk.value.split(".")
@@ -149,7 +151,7 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
149151
"table_name": parent_table_name,
150152
"schema": r[0] if r else None
151153
}
152-
elif isinstance(fk, MemberExpr):
154+
elif isinstance(fk, MemberExpr) and isinstance(fk.expr, NameExpr):
153155
ctx.cls.info.metadata.setdefault('sqlalchemy', {}).setdefault('foreign_keys',
154156
{})[sa_column_name] = {
155157
"sa_name": fk.name,
@@ -463,7 +465,8 @@ class User(Base):
463465
# Something complex, stay silent for now.
464466
new_arg = AnyType(TypeOfAny.special_form)
465467

466-
current_model = ctx.api.scope.active_class()
468+
# use private api
469+
current_model = ctx.api.scope.active_class() # type: ignore # type: TypeInfo
467470
assert current_model is not None
468471

469472
# TODO: handle backref relationships
@@ -472,7 +475,7 @@ class User(Base):
472475
if uselist_arg:
473476
if parse_bool(uselist_arg):
474477
new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg])
475-
elif not isinstance(new_arg, AnyType) and is_relationship_iterable(ctx, current_model, new_arg.type):
478+
elif isinstance(new_arg, Instance) and is_relationship_iterable(ctx, current_model, new_arg.type):
476479
new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg])
477480
else:
478481
if has_annotation:

0 commit comments

Comments
 (0)