diff --git a/sqlmypy.py b/sqlmypy.py index 9b66b7f..55a01b4 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -3,10 +3,10 @@ Plugin, FunctionContext, ClassDefContext, DynamicClassDefContext, SemanticAnalyzerPluginInterface ) -from mypy.plugins.common import add_method +from mypy.plugins.common import add_method, _get_argument from mypy.nodes import ( NameExpr, Expression, StrExpr, TypeInfo, ClassDef, Block, SymbolTable, SymbolTableNode, GDEF, - Argument, Var, ARG_STAR2, MDEF, TupleExpr, RefExpr + Argument, Var, ARG_STAR2, MDEF, TupleExpr, RefExpr, AssignmentStmt, CallExpr, MemberExpr ) from mypy.types import ( UnionType, NoneTyp, Instance, Type, AnyType, TypeOfAny, UninhabitedType, CallableType @@ -25,6 +25,7 @@ COLUMN_ELEMENT_NAME = 'sqlalchemy.sql.elements.ColumnElement' # type: Final GROUPING_NAME = 'sqlalchemy.sql.elements.Grouping' # type: Final RELATIONSHIP_NAME = 'sqlalchemy.orm.relationships.RelationshipProperty' # type: Final +FOREIGN_KEY_NAME = 'sqlalchemy.sql.schema.ForeignKey' # type: Final def is_declarative(info: TypeInfo) -> bool: @@ -110,6 +111,53 @@ def add_model_init_hook(ctx: ClassDefContext) -> None: add_method(ctx, '__init__', [kw_arg], NoneTyp()) ctx.cls.info.metadata.setdefault('sqlalchemy', {})['generated_init'] = True + for stmt in ctx.cls.defs.body: + if not (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1 and isinstance(stmt.lvalues[0], NameExpr)): + continue + + # We currently only handle setting __tablename__ as a class attribute, and not through a property. + if stmt.lvalues[0].name == "__tablename__" and isinstance(stmt.rvalue, StrExpr): + ctx.cls.info.metadata.setdefault('sqlalchemy', {})['table_name'] = stmt.rvalue.value + + if (isinstance(stmt.rvalue, CallExpr) and isinstance(stmt.rvalue.callee, NameExpr) + and stmt.rvalue.callee.fullname == COLUMN_NAME): + # Save columns. The name of a column on the db side can be different from the one inside the SA model. + sa_column_name = stmt.lvalues[0].name + + db_column_name = None # type: Optional[str] + if 'name' in stmt.rvalue.arg_names: + name_str_expr = stmt.rvalue.args[stmt.rvalue.arg_names.index('name')] + assert isinstance(name_str_expr, StrExpr) + db_column_name = name_str_expr.value + else: + if len(stmt.rvalue.args) >= 1 and isinstance(stmt.rvalue.args[0], StrExpr): + db_column_name = stmt.rvalue.args[0].value + + ctx.cls.info.metadata.setdefault('sqlalchemy', {}).setdefault('columns', []).append( + {"sa_name": sa_column_name, "db_name": db_column_name or sa_column_name} + ) + + # Save foreign keys. + for arg in stmt.rvalue.args: + if (isinstance(arg, CallExpr) and isinstance(arg.callee, NameExpr) + and arg.callee.fullname == FOREIGN_KEY_NAME and len(arg.args) >= 1): + fk = arg.args[0] + if isinstance(fk, StrExpr): + *r, parent_table_name, parent_db_col_name = fk.value.split(".") + assert len(r) <= 1 + ctx.cls.info.metadata.setdefault('sqlalchemy', {}).setdefault('foreign_keys', + {})[sa_column_name] = { + "db_name": parent_db_col_name, + "table_name": parent_table_name, + "schema": r[0] if r else None + } + elif isinstance(fk, MemberExpr) and isinstance(fk.expr, NameExpr): + ctx.cls.info.metadata.setdefault('sqlalchemy', {}).setdefault('foreign_keys', + {})[sa_column_name] = { + "sa_name": fk.name, + "model_fullname": fk.expr.fullname + } + # Also add a selection of auto-generated attributes. sym = ctx.api.lookup_fully_qualified_or_none('sqlalchemy.sql.schema.Table') if sym: @@ -317,6 +365,55 @@ def grouping_hook(ctx: FunctionContext) -> Type: return ctx.default_return_type +class IncompleteModelMetadata(Exception): + pass + + +def has_foreign_keys(local_model: TypeInfo, remote_model: TypeInfo) -> bool: + """Tells if `local_model` has a fk to `remote_model`. + Will raise an `IncompleteModelMetadata` if some mandatory metadata is missing. + """ + local_metadata = local_model.metadata.get("sqlalchemy", {}) + remote_metadata = remote_model.metadata.get("sqlalchemy", {}) + + for fk in local_metadata.get("foreign_keys", {}).values(): + if 'model_fullname' in fk and remote_model.fullname() == fk['model_fullname']: + return True + if 'table_name' in fk: + if 'table_name' not in remote_metadata: + raise IncompleteModelMetadata + # TODO: handle different schemas. + # It's not straightforward because schema can be specified in `__table_args__` or in metadata for example + if remote_metadata['table_name'] == fk['table_name']: + return True + + return False + + +def is_relationship_iterable(ctx: FunctionContext, local_model: TypeInfo, remote_model: TypeInfo) -> bool: + """Tries to guess if the relationship is onetoone/onetomany/manytoone. + + Currently we handle the most current case, where a model relates to the other one through a relationship. + We also handle cases where secondaryjoin argument is provided. + We don't handle advanced usecases (foreign keys on both sides, primaryjoin, etc.). + """ + secondaryjoin = get_argument_by_name(ctx, 'secondaryjoin') + + if secondaryjoin is not None: + return True + + try: + can_be_many_to_one = has_foreign_keys(local_model, remote_model) + can_be_one_to_many = has_foreign_keys(remote_model, local_model) + + if not can_be_many_to_one and can_be_one_to_many: + return True + except IncompleteModelMetadata: + pass + + return False # Assume relationship is not iterable, if we weren't able to guess better. + + def relationship_hook(ctx: FunctionContext) -> Type: """Support basic use cases for relationships. @@ -369,10 +466,18 @@ class User(Base): # Something complex, stay silent for now. new_arg = AnyType(TypeOfAny.special_form) + # use private api + current_model = ctx.api.scope.active_class() # type: ignore # type: TypeInfo + assert current_model is not None + + # TODO: handle backref relationships + # We figured out, the model type. Now check if we need to wrap it in Iterable if uselist_arg: if parse_bool(uselist_arg): new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg]) + elif isinstance(new_arg, Instance) and is_relationship_iterable(ctx, current_model, new_arg.type): + new_arg = ctx.api.named_generic_type('typing.Iterable', [new_arg]) else: if has_annotation: # If there is an annotation we use it as a source of truth. @@ -387,10 +492,10 @@ class User(Base): # We really need to add this to TypeChecker API def parse_bool(expr: Expression) -> Optional[bool]: if isinstance(expr, NameExpr): - if expr.fullname == 'builtins.True': - return True - if expr.fullname == 'builtins.False': - return False + if expr.fullname == 'builtins.True': + return True + if expr.fullname == 'builtins.False': + return False return None diff --git a/test/test-data/sqlalchemy-plugin-features.test b/test/test-data/sqlalchemy-plugin-features.test index bcbbb3e..5f0b0f6 100644 --- a/test/test-data/sqlalchemy-plugin-features.test +++ b/test/test-data/sqlalchemy-plugin-features.test @@ -280,3 +280,93 @@ class M2(M1): Base = declarative_base(cls=(M1, M2)) # E: Not able to calculate MRO for declarative base reveal_type(Base) # E: Revealed type is 'Any' [out] + +[case testRelationshipIsGuessed] +from sqlalchemy import Column, Integer, String, ForeignKey +from sqlalchemy.orm import relationship +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + +class Parent(Base): + __tablename__ = 'parents' + id = Column(Integer, primary_key=True) + name = Column(String) + + children = relationship("Child") + +class Child(Base): + __tablename__ = 'children' + id = Column(Integer, primary_key=True) + name = Column(String) + parent_id = Column(Integer, ForeignKey(Parent.id)) + + parent = relationship(Parent) + +child: Child +parent: Parent + +reveal_type(child.parent) # E: Revealed type is 'main.Parent*' +reveal_type(parent.children) # E: Revealed type is 'typing.Iterable*[main.Child]' + +[out] + +[case testRelationshipIsGuessed2] +from sqlalchemy import Column, Integer, String, ForeignKey +from sqlalchemy.orm import relationship +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + +class Parent(Base): + __tablename__ = 'parents' + id = Column(Integer, primary_key=True) + name = Column(String) + + children = relationship("Child") + +class Child(Base): + __tablename__ = 'children' + id = Column(Integer, primary_key=True) + name = Column(String) + parent_id = Column(Integer, ForeignKey("parents.id")) + + parent = relationship(Parent) + +child: Child +parent: Parent + +reveal_type(child.parent) # E: Revealed type is 'main.Parent*' +reveal_type(parent.children) # E: Revealed type is 'typing.Iterable*[main.Child]' + +[out] + +[case testRelationshipIsGuessed3] +from sqlalchemy import Column, Integer, String, ForeignKey +from sqlalchemy.orm import relationship +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + +class Parent(Base): + __tablename__ = 'parents' + id = Column(Integer, primary_key=True) + name = Column(String) + + children = relationship("Child") + +class Child(Base): + __tablename__ = 'children' + id = Column(Integer, primary_key=True) + name = Column(String) + parent_id = Column(Integer, ForeignKey("other_parents.id")) + + parent = relationship(Parent) + +child: Child +parent: Parent + +reveal_type(child.parent) # E: Revealed type is 'main.Parent*' +reveal_type(parent.children) # E: Revealed type is 'main.Child*' + +[out] \ No newline at end of file