diff --git a/sqlmypy.py b/sqlmypy.py index 9b66b7f..4482389 100644 --- a/sqlmypy.py +++ b/sqlmypy.py @@ -207,12 +207,13 @@ def model_hook(ctx: FunctionContext) -> Type: # Collect column names and types defined in the model # TODO: cache this? expected_types = {} # type: Dict[str, Type] - for name, sym in model.names.items(): - if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): - tp = sym.node.type - if tp.type.fullname() in (COLUMN_NAME, RELATIONSHIP_NAME): - assert len(tp.args) == 1 - expected_types[name] = tp.args[0] + for cls in model.mro[::-1]: + for name, sym in cls.names.items(): + if isinstance(sym.node, Var) and isinstance(sym.node.type, Instance): + tp = sym.node.type + if tp.type.fullname() in (COLUMN_NAME, RELATIONSHIP_NAME): + assert len(tp.args) == 1 + expected_types[name] = tp.args[0] assert len(ctx.arg_names) == 1 # only **kwargs in generated __init__ assert len(ctx.arg_types) == 1 diff --git a/test/test-data/sqlalchemy-plugin-features.test b/test/test-data/sqlalchemy-plugin-features.test index bcbbb3e..3eaab8d 100644 --- a/test/test-data/sqlalchemy-plugin-features.test +++ b/test/test-data/sqlalchemy-plugin-features.test @@ -50,6 +50,47 @@ class Base: ... [out] +[case testModelInitMixin] + +from sqlalchemy import Column, Integer, String, DateTime +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + +class HasId: + id = Column(Integer, primary_key=True) + +class User(Base, HasId): + __tablename__ = 'users' + + name = Column(String, nullable=False) + +user = User(id=123, name="John Doe") +reveal_type(user.id) # N: Revealed type is 'builtins.int*' +[out] + +[case testModelInitProperMro] + +from sqlalchemy import Column, Integer, String, DateTime +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + +class Defaults: + id = Column(Integer, primary_key=True) + +class User(Base, Defaults): + __tablename__ = 'users' + + # By default mypy will complain about Column[str] not being compatible with Column[int]. + # Adding "ignore" should allow us to override column type. + id = Column(String, primary_key=True) # type: ignore + name = Column(String, nullable=False) + +User(id='stringish-id') +User(id=123) # E: Incompatible type for "id" of "User" (got "int", expected "str") +[out] + [case testModelInitRelationship] from typing import TYPE_CHECKING, List