Skip to content

✨ Add support for SQLAlchemy polymorphic models #1226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
6d93a46
support sqlalchemy polymorphic
Nov 26, 2024
589237b
improve docs
Nov 26, 2024
4071b0f
fix polymorphic_on check
Nov 26, 2024
48f2a88
fix polymorphic_on check
Nov 26, 2024
e6ad74d
fix lint
Nov 26, 2024
277953a
fix pydantic v1 support
Nov 26, 2024
4aade03
fix type hint for <3.10
Nov 26, 2024
a3044bb
add needs_pydanticv2 mark to test
Nov 26, 2024
015601c
improve code structure
Dec 3, 2024
66c1d93
lint
Dec 3, 2024
0efd1bf
remove effort of pydantic v1
Dec 3, 2024
7173b44
Merge branch 'main' into sqlalchemy_polymorphic_support
PaleNeutron Dec 10, 2024
84d739e
Update sqlmodel/_compat.py
PaleNeutron Dec 12, 2024
dbd0101
fix default value is InstrumentedAttribute in inherit
Feb 5, 2025
88670a5
Merge branch 'sqlalchemy_polymorphic_support' of https://github.com/P…
Feb 5, 2025
b1ed8c3
fix inherit order
Feb 5, 2025
5d1bf5c
support python < 3.9
Feb 5, 2025
ccbb92a
🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2025
d0d0288
skip polymorphic in pydantic v1
Feb 5, 2025
525373d
Merge branch 'sqlalchemy_polymorphic_support' of https://github.com/P…
Feb 5, 2025
ab965f0
Merge branch 'main' into sqlalchemy_polymorphic_support
PaleNeutron Feb 5, 2025
68963e7
Merge branch 'main' into sqlalchemy_polymorphic_support
svlandeg Feb 20, 2025
95c6a1e
Merge branch 'main' into sqlalchemy_polymorphic_support
svlandeg Feb 24, 2025
c1dff79
disable pydantic warning during polymorphic
Mar 12, 2025
7d333fe
fix relationship problem of parent class during polymorphic inherit
May 15, 2025
32ebd6f
avoid add ClassVar multiple times
May 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from pydantic import VERSION as P_VERSION
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from sqlalchemy import inspect
from sqlalchemy.orm import Mapper
from typing_extensions import Annotated, get_args, get_origin

# Reassign variable to make it reexported for mypy
Expand Down Expand Up @@ -64,6 +66,35 @@ def _is_union_type(t: Any) -> bool:
finish_init: ContextVar[bool] = ContextVar("finish_init", default=True)


def set_polymorphic_default_value(
self_instance: _TSQLModel,
values: Dict[str, Any],
) -> bool:
"""By default, when init a model, pydantic will set the polymorphic_on
value to field default value. But when inherit a model, the polymorphic_on
should be set to polymorphic_identity value by default."""
cls = type(self_instance)
mapper = inspect(cls)
ret = False
if isinstance(mapper, Mapper):
polymorphic_on = mapper.polymorphic_on
if polymorphic_on is not None:
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
field_info = get_model_fields(cls).get(polymorphic_property.key)
if field_info:
v = values.get(polymorphic_property.key)
# if model is inherited or polymorphic_on is not explicitly set
# set the polymorphic_on by default
if mapper.inherits or v is None:
setattr(
self_instance,
polymorphic_property.key,
mapper.polymorphic_identity,
)
ret = True
return ret


@contextmanager
def partial_init() -> Generator[None, None, None]:
token = finish_init.set(False)
Expand Down Expand Up @@ -290,6 +321,8 @@ def sqlmodel_table_construct(
if value is not Undefined:
setattr(self_instance, key, value)
# End SQLModel override
# Override polymorphic_on default value
set_polymorphic_default_value(self_instance, values)
return self_instance

def sqlmodel_validate(
Expand Down
71 changes: 61 additions & 10 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ipaddress
import uuid
import warnings
import weakref
from datetime import date, datetime, time, timedelta
from decimal import Decimal
Expand Down Expand Up @@ -41,9 +42,10 @@
)
from sqlalchemy import Enum as sa_Enum
from sqlalchemy.orm import (
InstrumentedAttribute,
Mapped,
MappedColumn,
RelationshipProperty,
declared_attr,
registry,
relationship,
)
Expand Down Expand Up @@ -538,7 +540,42 @@ def __new__(
config_kwargs = {
key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs
}
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
is_polymorphic = False
if IS_PYDANTIC_V2:
base_fields = {}
base_annotations = {}
for base in bases[::-1]:
if issubclass(base, BaseModel):
base_fields.update(get_model_fields(base))
base_annotations.update(base.__annotations__)
if hasattr(base, "__sqlmodel_relationships__"):
for k in base.__sqlmodel_relationships__:
# create a dummy attribute to avoid inherit
# pydantic will treat it as class variables, and will not become fields on model instances
anno = base_annotations.get(k, Any)
if get_origin(anno) is not ClassVar:
dummy_anno = ClassVar[anno]
dict_used["__annotations__"][k] = dummy_anno

if hasattr(base, "__tablename__"):
is_polymorphic = True
# use base_fields overwriting the ones from the class for inherit
# if base is a sqlalchemy model, it's attributes will be an InstrumentedAttribute
# thus pydantic will use the value of the attribute as the default value
base_annotations.update(dict_used["__annotations__"])
dict_used["__annotations__"] = base_annotations
base_fields.update(dict_used)
dict_used = base_fields
# if is_polymorphic, disable pydantic `shadows an attribute` warning
if is_polymorphic:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Field name .+ shadows an attribute in parent.+",
)
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
else:
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
new_cls.__annotations__ = {
**relationship_annotations,
**pydantic_annotations,
Expand All @@ -558,9 +595,22 @@ def get_config(name: str) -> Any:

config_table = get_config("table")
if config_table is True:
# sqlalchemy mark a class as table by check if it has __tablename__ attribute
# or if __tablename__ is in __annotations__. Only set __tablename__ if it's
# a table model
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
setattr(new_cls, "__tablename__", new_cls.__name__.lower()) # noqa: B010
# If it was passed by kwargs, ensure it's also set in config
set_config_value(model=new_cls, parameter="table", value=config_table)
for k, v in get_model_fields(new_cls).items():
original_v = getattr(new_cls, k, None)
if (
isinstance(original_v, InstrumentedAttribute)
and k not in class_dict
):
# The attribute was already set by SQLAlchemy, don't override it
# Needed for polymorphic models, see #36
continue
col = get_column_from_field(v)
setattr(new_cls, k, col)
# Set a config flag to tell FastAPI that this should be read with a field
Expand Down Expand Up @@ -594,7 +644,13 @@ def __init__(
# trying to create a new SQLAlchemy, for a new table, with the same name, that
# triggers an error
base_is_table = any(is_table_model_class(base) for base in bases)
if is_table_model_class(cls) and not base_is_table:
polymorphic_identity = dict_.get("__mapper_args__", {}).get(
"polymorphic_identity"
)
has_polymorphic = polymorphic_identity is not None

# allow polymorphic models inherit from table models
if is_table_model_class(cls) and (not base_is_table or has_polymorphic):
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_info.sa_relationship:
# There's a SQLAlchemy relationship declared, that takes precedence
Expand Down Expand Up @@ -702,13 +758,13 @@ def get_sqlalchemy_type(field: Any) -> Any:
raise ValueError(f"{type_} has no matching SQLAlchemy type")


def get_column_from_field(field: Any) -> Column: # type: ignore
def get_column_from_field(field: Any) -> Union[Column, MappedColumn]: # type: ignore
if IS_PYDANTIC_V2:
field_info = field
else:
field_info = field.field_info
sa_column = getattr(field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn):
return sa_column
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field_info, "primary_key", Undefined)
Expand Down Expand Up @@ -772,7 +828,6 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]]
__name__: ClassVar[str]
metadata: ClassVar[MetaData]
Expand Down Expand Up @@ -836,10 +891,6 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
if not (isinstance(k, str) and k.startswith("_sa_"))
]

@declared_attr # type: ignore
def __tablename__(cls) -> str:
return cls.__name__.lower()

@classmethod
def model_validate(
cls: Type[_TSQLModel],
Expand Down
177 changes: 177 additions & 0 deletions tests/test_polymorphic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
from typing import Optional

from sqlalchemy import ForeignKey
from sqlalchemy.orm import mapped_column
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select

from tests.conftest import needs_pydanticv2


@needs_pydanticv2
def test_polymorphic_joined_table(clear_sqlmodel) -> None:
class Hero(SQLModel, table=True):
__tablename__ = "hero"
id: Optional[int] = Field(default=None, primary_key=True)
hero_type: str = Field(default="hero")

__mapper_args__ = {
"polymorphic_on": "hero_type",
"polymorphic_identity": "normal_hero",
}

class DarkHero(Hero):
__tablename__ = "dark_hero"
id: Optional[int] = Field(
default=None,
sa_column=mapped_column(ForeignKey("hero.id"), primary_key=True),
)
dark_power: str = Field(
default="dark",
sa_column=mapped_column(
nullable=False, use_existing_column=True, default="dark"
),
)

__mapper_args__ = {
"polymorphic_identity": "dark",
}

engine = create_engine("sqlite:///:memory:", echo=True)
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
hero = Hero()
db.add(hero)
dark_hero = DarkHero()
db.add(dark_hero)
db.commit()
statement = select(DarkHero)
result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(result[0].dark_power, str)


@needs_pydanticv2
def test_polymorphic_joined_table_with_sqlmodel_field(clear_sqlmodel) -> None:
class Hero(SQLModel, table=True):
__tablename__ = "hero"
id: Optional[int] = Field(default=None, primary_key=True)
hero_type: str = Field(default="hero")

__mapper_args__ = {
"polymorphic_on": "hero_type",
"polymorphic_identity": "normal_hero",
}

class DarkHero(Hero):
__tablename__ = "dark_hero"
id: Optional[int] = Field(
default=None,
primary_key=True,
foreign_key="hero.id",
)
dark_power: str = Field(
default="dark",
sa_column=mapped_column(
nullable=False, use_existing_column=True, default="dark"
),
)

__mapper_args__ = {
"polymorphic_identity": "dark",
}

engine = create_engine("sqlite:///:memory:", echo=True)
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
hero = Hero()
db.add(hero)
dark_hero = DarkHero()
db.add(dark_hero)
db.commit()
statement = select(DarkHero)
result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(result[0].dark_power, str)


@needs_pydanticv2
def test_polymorphic_single_table(clear_sqlmodel) -> None:
class Hero(SQLModel, table=True):
__tablename__ = "hero"
id: Optional[int] = Field(default=None, primary_key=True)
hero_type: str = Field(default="hero")

__mapper_args__ = {
"polymorphic_on": "hero_type",
"polymorphic_identity": "normal_hero",
}

class DarkHero(Hero):
dark_power: str = Field(
default="dark",
sa_column=mapped_column(
nullable=False, use_existing_column=True, default="dark"
),
)

__mapper_args__ = {
"polymorphic_identity": "dark",
}

engine = create_engine("sqlite:///:memory:", echo=True)
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
hero = Hero()
db.add(hero)
dark_hero = DarkHero(dark_power="pokey")
db.add(dark_hero)
db.commit()
statement = select(DarkHero)
result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(result[0].dark_power, str)


@needs_pydanticv2
def test_polymorphic_relationship(clear_sqlmodel) -> None:
class Tool(SQLModel, table=True):
__tablename__ = "tool_table"

id: int = Field(primary_key=True)

name: str

class Person(SQLModel, table=True):
__tablename__ = "person_table"

id: int = Field(primary_key=True)

discriminator: str
name: str

tool_id: int = Field(foreign_key="tool_table.id")
tool: Tool = Relationship()

__mapper_args__ = {
"polymorphic_on": "discriminator",
"polymorphic_identity": "simple_person",
}

class Worker(Person):
__mapper_args__ = {
"polymorphic_identity": "worker",
}

engine = create_engine("sqlite:///:memory:", echo=True)
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
tool = Tool(id=1, name="Hammer")
db.add(tool)
worker = Worker(id=2, name="Bob", tool_id=1)
db.add(worker)
db.commit()

statement = select(Worker).where(Worker.tool_id == 1)
result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(result[0].tool, Tool)