diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 7c916f79af..271bf5ddeb 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -1,6 +1,9 @@ from __future__ import annotations import ipaddress +import sys +import types +import typing import uuid import weakref from datetime import date, datetime, time, timedelta @@ -27,6 +30,7 @@ overload, ) +import typing_extensions from pydantic import BaseModel, EmailStr from pydantic.fields import FieldInfo as PydanticFieldInfo from sqlalchemy import ( @@ -519,7 +523,7 @@ def __new__( if k in relationships: relationship_annotations[k] = v else: - pydantic_annotations[k] = v + pydantic_annotations[k] = resolve_type_alias(v) dict_used = { **dict_for_pydantic, "__weakref__": None, @@ -763,6 +767,54 @@ def get_column_from_field(field: Any) -> Column: # type: ignore return Column(sa_type, *args, **kwargs) # type: ignore +def _is_typing_type_instance(annotation: Any, type_name: str) -> bool: + check_type = [] + if hasattr(typing, type_name): + check_type.append(getattr(typing, type_name)) + if hasattr(typing_extensions, type_name): + check_type.append(getattr(typing_extensions, type_name)) + + return bool(check_type) and isinstance(annotation, tuple(check_type)) + + +def _is_new_type_instance(annotation: Any) -> bool: + if sys.version_info >= (3, 10): + return _is_typing_type_instance(annotation, "NewType") + else: + return hasattr(annotation, "__supertype__") + + +def _is_type_var_instance(annotation: Any) -> bool: + return _is_typing_type_instance(annotation, "TypeVar") + + +def _is_type_alias_type_instance(annotation: Any) -> bool: + if sys.version_info[:2] == (3, 10): + if type(annotation) is types.GenericAlias: + # In Python 3.10, GenericAlias instances are of type TypeAliasType + return False + + return _is_typing_type_instance(annotation, "TypeAliasType") + + +def resolve_type_alias(annotation: Any) -> Any: + if _is_type_var_instance(annotation): + resolution = annotation.__bound__ + if not annotation: + raise ValueError( + "TypeVars without a bound type cannot be converted to SQLAlchemy types" + ) + # annotations.__constraints__ could be used and defined Union[*constraints], but ORM does not support it + elif _is_new_type_instance(annotation): + resolution = annotation.__supertype__ + elif _is_type_alias_type_instance(annotation): + resolution = annotation.__value__ + else: + resolution = annotation + + return resolution + + class_registry = weakref.WeakValueDictionary() # type: ignore default_registry = registry() diff --git a/tests/conftest.py b/tests/conftest.py index 98a4d2b7e6..204c28b9ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -93,3 +93,6 @@ def print_mock_fixture() -> Generator[PrintMock, None, None]: needs_py310 = pytest.mark.skipif( sys.version_info < (3, 10), reason="requires python3.10+" ) +needs_py312 = pytest.mark.skipif( + sys.version_info < (3, 12), reason="requires python3.12+" +) diff --git a/tests/test_field_sa_type.py b/tests/test_field_sa_type.py new file mode 100644 index 0000000000..8f3bc46b85 --- /dev/null +++ b/tests/test_field_sa_type.py @@ -0,0 +1,178 @@ +import typing as t +from textwrap import dedent + +import pytest +import typing_extensions as te +from sqlmodel import Field, SQLModel + +from tests.conftest import needs_py312 + + +def test_sa_type_typing_1() -> None: + Type1_t = str + + class Hero1(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type1_t = "sword" + + +if hasattr(t, "Annotated"): + + def test_sa_type_typing_2() -> None: + Type2_t = t.Annotated[str, "Just a comment"] + + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type2_t = "sword" + + +if hasattr(t, "TypeAlias"): + Type3_t: t.TypeAlias = str + + def test_sa_type_typing_3() -> None: + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type3_t = "sword" + + if hasattr(t, "Annotated"): + Type4_t: t.TypeAlias = t.Annotated[str, "Just a comment"] + + def test_sa_type_typing_4() -> None: + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type4_t = "sword" + + +@needs_py312 +def test_sa_type_typing_5() -> None: + test_code = dedent(""" + type Type5_t = str + + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type5_t = "sword" + """) + exec(test_code, globals()) + + +@needs_py312 +def test_sa_type_typing_6() -> None: + test_code = dedent(""" + type Type6_t = t.Annotated[str, "Just a comment"] + + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type6_t = "sword" + """) + exec(test_code, globals()) + + +def test_sa_type_typing_7() -> None: + Type7_t = t.NewType("Type7_t", str) + + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type7_t = "sword" + + +def test_sa_type_typing_8() -> None: + Type8_t = t.TypeVar("Type8_t", bound=str) + + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type8_t = "sword" + + +def test_sa_type_typing_9() -> None: + Type9_t = t.TypeVar("Type9_t", str, bytes) + + with pytest.raises(ValueError): + + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type9_t = "sword" + + +def test_sa_type_typing_extensions_1() -> None: + Type1_te = str + + class Hero1(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type1_te = "sword" + + +if hasattr(te, "Annotated"): + + def test_sa_type_typing_extensions_2() -> None: + Type2_te = te.Annotated[str, "Just a comment"] + + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type2_te = "sword" + + +if hasattr(te, "TypeAlias"): + Type3_te: te.TypeAlias = str + + def test_sa_type_typing_extensions_3() -> None: + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type3_te = "sword" + + if hasattr(te, "Annotated"): + Type4_te: te.TypeAlias = te.Annotated[str, "Just a comment"] + + def test_sa_type_typing_extensions_4() -> None: + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type4_te = "sword" + + +@needs_py312 +def test_sa_type_typing_extensions_5() -> None: + test_code = dedent(""" + type Type5_te = str + + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type5_te = "sword" + """) + exec(test_code, globals()) + + +@needs_py312 +def test_sa_type_typing_extensions_6() -> None: + test_code = dedent(""" + type Type6_te = te.Annotated[str, "Just a comment"] + + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type6_te = "sword" + """) + exec(test_code, globals()) + + +def test_sa_type_typing_extensions_7() -> None: + Type7_te = te.NewType("Type7_te", str) + + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type7_te = "sword" + + +def test_sa_type_typing_extensions_8() -> None: + Type8_te = te.TypeVar("Type8_te", bound=str) + + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type8_te = "sword" + + +def test_sa_type_typing_extensions_9() -> None: + Type9_te = te.TypeVar("Type9_te", str, bytes) + + with pytest.raises(ValueError): + + class Hero(SQLModel, table=True): + pk: int = Field(primary_key=True) + weapon: Type9_te = "sword"