@@ -66,17 +66,17 @@ def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext
66
66
return model_hook
67
67
return None
68
68
69
- def get_dynamic_class_hook (self , fullname : str ):
69
+ def get_dynamic_class_hook (self , fullname : str ) -> CB [ DynamicClassDefContext ] :
70
70
if fullname == 'sqlalchemy.ext.declarative.api.declarative_base' :
71
71
return decl_info_hook
72
72
return None
73
73
74
- def get_class_decorator_hook (self , fullname : str ):
74
+ def get_class_decorator_hook (self , fullname : str ) -> CB [ ClassDefContext ] :
75
75
if fullname == 'sqlalchemy.ext.declarative.api.as_declarative' :
76
76
return decl_deco_hook
77
77
return None
78
78
79
- def get_base_class_hook (self , fullname : str ):
79
+ def get_base_class_hook (self , fullname : str ) -> CB [ ClassDefContext ] :
80
80
sym = self .lookup_fully_qualified (fullname )
81
81
if sym and isinstance (sym .node , TypeInfo ):
82
82
if is_declarative (sym .node ):
@@ -119,7 +119,8 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
119
119
if stmt .lvalues [0 ].name == "__tablename__" and isinstance (stmt .rvalue , StrExpr ):
120
120
ctx .cls .info .metadata .setdefault ('sqlalchemy' , {})['table_name' ] = stmt .rvalue .value
121
121
122
- if isinstance (stmt .rvalue , CallExpr ) and stmt .rvalue .callee .fullname == COLUMN_NAME :
122
+ if (isinstance (stmt .rvalue , CallExpr ) and isinstance (stmt .rvalue .callee , NameExpr )
123
+ and stmt .rvalue .callee .fullname == COLUMN_NAME ):
123
124
# Save columns. The name of a column on the db side can be different from the one inside the SA model.
124
125
sa_column_name = stmt .lvalues [0 ].name
125
126
@@ -138,7 +139,8 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
138
139
139
140
# Save foreign keys.
140
141
for arg in stmt .rvalue .args :
141
- if isinstance (arg , CallExpr ) and arg .callee .fullname == FOREIGN_KEY_NAME and len (arg .args ) >= 1 :
142
+ if (isinstance (arg , CallExpr ) and isinstance (arg .callee , NameExpr )
143
+ and arg .callee .fullname == FOREIGN_KEY_NAME and len (arg .args ) >= 1 ):
142
144
fk = arg .args [0 ]
143
145
if isinstance (fk , StrExpr ):
144
146
* r , parent_table_name , parent_db_col_name = fk .value .split ("." )
@@ -149,7 +151,7 @@ def add_model_init_hook(ctx: ClassDefContext) -> None:
149
151
"table_name" : parent_table_name ,
150
152
"schema" : r [0 ] if r else None
151
153
}
152
- elif isinstance (fk , MemberExpr ):
154
+ elif isinstance (fk , MemberExpr ) and isinstance ( fk . expr , NameExpr ) :
153
155
ctx .cls .info .metadata .setdefault ('sqlalchemy' , {}).setdefault ('foreign_keys' ,
154
156
{})[sa_column_name ] = {
155
157
"sa_name" : fk .name ,
@@ -463,7 +465,8 @@ class User(Base):
463
465
# Something complex, stay silent for now.
464
466
new_arg = AnyType (TypeOfAny .special_form )
465
467
466
- current_model = ctx .api .scope .active_class ()
468
+ # use private api
469
+ current_model = ctx .api .scope .active_class () # type: ignore # type: TypeInfo
467
470
assert current_model is not None
468
471
469
472
# TODO: handle backref relationships
@@ -472,7 +475,7 @@ class User(Base):
472
475
if uselist_arg :
473
476
if parse_bool (uselist_arg ):
474
477
new_arg = ctx .api .named_generic_type ('typing.Iterable' , [new_arg ])
475
- elif not isinstance (new_arg , AnyType ) and is_relationship_iterable (ctx , current_model , new_arg .type ):
478
+ elif isinstance (new_arg , Instance ) and is_relationship_iterable (ctx , current_model , new_arg .type ):
476
479
new_arg = ctx .api .named_generic_type ('typing.Iterable' , [new_arg ])
477
480
else :
478
481
if has_annotation :
0 commit comments