From c95b5efcac419c89c7cfd55a935dfffd918e671c Mon Sep 17 00:00:00 2001
From: Robert Holt
Date: Thu, 11 Mar 2021 09:31:05 -0500
Subject: [PATCH 1/3] Manually edit python example code

The generated example code is edited just to demonstrate potential
improvements to the python codegen. The main changes are using sqlalchemy
as the execution engine and using stdlib dataclasses instead of pydantic.
This also eliminates the runtime package.
---
 examples/python/requirements.txt          |  3 +-
 examples/python/src/authors/models.py     |  6 +-
 examples/python/src/authors/query.py      | 95 ++++++++++-------
 examples/python/src/dbtest/migrations.py  | 15 ++--
 examples/python/src/tests/conftest.py     | 32 +++++++-
 examples/python/src/tests/test_authors.py | 33 ++++----
 6 files changed, 96 insertions(+), 88 deletions(-)

diff --git a/examples/python/requirements.txt b/examples/python/requirements.txt
index c22530eaa9..ab7698b01a 100644
--- a/examples/python/requirements.txt
+++ b/examples/python/requirements.txt
@@ -2,5 +2,4 @@ pytest~=6.2.2
 pytest-asyncio~=0.14.0
 psycopg2-binary~=2.8.6
 asyncpg~=0.21.0
-pydantic~=1.7.3
-sqlc-python-runtime~=1.0.0
+sqlalchemy==1.4.0b3
diff --git a/examples/python/src/authors/models.py b/examples/python/src/authors/models.py
index b282d77dac..c1b6cfaf87 100644
--- a/examples/python/src/authors/models.py
+++ b/examples/python/src/authors/models.py
@@ -1,13 +1,13 @@
 # Code generated by sqlc. DO NOT EDIT.
 from typing import Optional
-
-import pydantic
+import dataclasses
 
 
 # Enums
 
 # Models
-class Author(pydantic.BaseModel):
+@dataclasses.dataclass()
+class Author:
     id: int
     name: str
     bio: Optional[str]
diff --git a/examples/python/src/authors/query.py b/examples/python/src/authors/query.py
index 947e535309..0ec4839085 100644
--- a/examples/python/src/authors/query.py
+++ b/examples/python/src/authors/query.py
@@ -1,92 +1,77 @@
 # Code generated by sqlc. DO NOT EDIT.
-from typing import AsyncIterator, Awaitable, Iterator, Optional, overload +from typing import AsyncIterator, Iterator, Optional -import sqlc_runtime as sqlc +import sqlalchemy +import sqlalchemy.ext.asyncio from authors import models -CREATE_AUTHOR = """-- name: create_author :one +CREATE_AUTHOR = """-- name: create_author \\:one INSERT INTO authors ( name, bio ) VALUES ( - $1, $2 + :p1, :p2 ) RETURNING id, name, bio """ -DELETE_AUTHOR = """-- name: delete_author :exec +DELETE_AUTHOR = """-- name: delete_author \\:exec DELETE FROM authors -WHERE id = $1 +WHERE id = :p1 """ -GET_AUTHOR = """-- name: get_author :one +GET_AUTHOR = """-- name: get_author \\:one SELECT id, name, bio FROM authors -WHERE id = $1 LIMIT 1 +WHERE id = :p1 LIMIT 1 """ -LIST_AUTHORS = """-- name: list_authors :many +LIST_AUTHORS = """-- name: list_authors \\:many SELECT id, name, bio FROM authors ORDER BY name """ -@overload -def create_author(conn: sqlc.Connection, name: str, bio: Optional[str]) -> Optional[models.Author]: - pass +class Query: + def __init__(self, conn: sqlalchemy.engine.Connection): + self._conn = conn + def create_author(self, name: str, bio: Optional[str]) -> Optional[models.Author]: + result = self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio}) + return models.Author(**dict(result.first())) -@overload -def create_author(conn: sqlc.AsyncConnection, name: str, bio: Optional[str]) -> Awaitable[Optional[models.Author]]: - pass + def delete_author(self, id: int) -> None: + self._conn.execute(sqlalchemy.text(DELETE_AUTHOR), {"p1": id}) + def get_author(self, id: int) -> Optional[models.Author]: + result = self._conn.execute(sqlalchemy.text(GET_AUTHOR), {"p1": id}) + return models.Author(**dict(result.first())) -def create_author(conn: sqlc.GenericConnection, name: str, bio: Optional[str]) -> sqlc.ReturnType[Optional[models.Author]]: - return conn.execute_one_model(models.Author, CREATE_AUTHOR, name, bio) + def list_authors(self) -> Iterator[models.Author]: + result = self._conn.execute(sqlalchemy.text(LIST_AUTHORS)) + for row in result: + yield models.Author(**dict(row)) -@overload -def delete_author(conn: sqlc.Connection, id: int) -> None: - pass +class AsyncQuery: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + async def create_author(self, name: str, bio: Optional[str]) -> Optional[models.Author]: + result = await self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio}) + return models.Author(**dict(result.first())) -@overload -def delete_author(conn: sqlc.AsyncConnection, id: int) -> Awaitable[None]: - pass - - -def delete_author(conn: sqlc.GenericConnection, id: int) -> sqlc.ReturnType[None]: - return conn.execute_none(DELETE_AUTHOR, id) - - -@overload -def get_author(conn: sqlc.Connection, id: int) -> Optional[models.Author]: - pass - - -@overload -def get_author(conn: sqlc.AsyncConnection, id: int) -> Awaitable[Optional[models.Author]]: - pass - - -def get_author(conn: sqlc.GenericConnection, id: int) -> sqlc.ReturnType[Optional[models.Author]]: - return conn.execute_one_model(models.Author, GET_AUTHOR, id) - - -@overload -def list_authors(conn: sqlc.Connection) -> Iterator[models.Author]: - pass - - -@overload -def list_authors(conn: sqlc.AsyncConnection) -> AsyncIterator[models.Author]: - pass - - -def list_authors(conn: sqlc.GenericConnection) -> sqlc.IteratorReturn[models.Author]: - return conn.execute_many_model(models.Author, LIST_AUTHORS) + async def delete_author(self, id: int) -> None: + await 
self._conn.execute(sqlalchemy.text(DELETE_AUTHOR), {"p1": id}) + async def get_author(self, id: int) -> Optional[models.Author]: + result = await self._conn.execute(sqlalchemy.text(GET_AUTHOR), {"p1": id}) + return models.Author(**dict(result.first())) + async def list_authors(self) -> AsyncIterator[models.Author]: + result = await self._conn.stream(sqlalchemy.text(LIST_AUTHORS)) + async for row in result: + yield models.Author(**dict(row)) diff --git a/examples/python/src/dbtest/migrations.py b/examples/python/src/dbtest/migrations.py index c3c72b78a6..8882c63169 100644 --- a/examples/python/src/dbtest/migrations.py +++ b/examples/python/src/dbtest/migrations.py @@ -1,29 +1,26 @@ import os from typing import List -import asyncpg -import psycopg2.extensions +import sqlalchemy +import sqlalchemy.ext.asyncio -def apply_migrations(db: psycopg2.extensions.connection, paths: List[str]): +def apply_migrations(conn: sqlalchemy.engine.Connection, paths: List[str]): files = _find_sql_files(paths) for file in files: with open(file, "r") as fd: blob = fd.read() - cur = db.cursor() - cur.execute(blob) - cur.close() - db.commit() + conn.execute(blob) -async def apply_migrations_async(db: asyncpg.Connection, paths: List[str]): +async def apply_migrations_async(conn: sqlalchemy.ext.asyncio.AsyncConnection, paths: List[str]): files = _find_sql_files(paths) for file in files: with open(file, "r") as fd: blob = fd.read() - await db.execute(blob) + await conn.execute(sqlalchemy.text(blob)) def _find_sql_files(paths: List[str]) -> List[str]: diff --git a/examples/python/src/tests/conftest.py b/examples/python/src/tests/conftest.py index e3df5f77dc..6b926935af 100644 --- a/examples/python/src/tests/conftest.py +++ b/examples/python/src/tests/conftest.py @@ -6,6 +6,8 @@ import psycopg2 import psycopg2.extensions import pytest +import sqlalchemy +import sqlalchemy.ext.asyncio @pytest.fixture(scope="session") @@ -16,7 +18,34 @@ def postgres_uri() -> str: pg_password = os.environ.get("PG_PASSWORD", "mysecretpassword") pg_db = os.environ.get("PG_DATABASE", "dinotest") - return f"postgres://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_db}?sslmode=disable" + return f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_db}" + + +@pytest.fixture(scope="session") +def sqlalchemy_connection(postgres_uri) -> sqlalchemy.engine.Connection: + schema_name = f"sqltest_{random.randint(0, 1000)}" + engine = sqlalchemy.create_engine(postgres_uri) + with engine.connect() as conn: + conn.execute(f"CREATE SCHEMA {schema_name}") + conn.execute(f"SET search_path TO {schema_name}") + yield conn + conn.execute(f"DROP SCHEMA {schema_name} CASCADE") + conn.execute("SET search_path TO public") + + +@pytest.fixture(scope="session") +async def async_sqlalchemy_connection(postgres_uri) -> sqlalchemy.ext.asyncio.AsyncConnection: + postgres_uri = postgres_uri.replace("postgresql", "postgresql+asyncpg") + schema_name = f"sqltest_{random.randint(0, 1000)}" + engine = sqlalchemy.ext.asyncio.create_async_engine(postgres_uri) + async with engine.connect() as conn: + await conn.execute(sqlalchemy.text(f"CREATE SCHEMA {schema_name}")) + await conn.execute(sqlalchemy.text(f"SET search_path TO {schema_name}")) + await conn.commit() + yield conn + await conn.rollback() + await conn.execute(sqlalchemy.text(f"DROP SCHEMA {schema_name} CASCADE")) + await conn.execute(sqlalchemy.text("SET search_path TO public")) @pytest.fixture(scope="session") @@ -29,6 +58,7 @@ def postgres_connection(postgres_uri) -> psycopg2.extensions.connection: 
@pytest.fixture() def postgres_db(postgres_connection) -> psycopg2.extensions.connection: schema_name = f"sqltest_{random.randint(0, 1000)}" + # schema_name = "sqltest_1" cur = postgres_connection.cursor() cur.execute(f"CREATE SCHEMA {schema_name}") cur.execute(f"SET search_path TO {schema_name}") diff --git a/examples/python/src/tests/test_authors.py b/examples/python/src/tests/test_authors.py index bc01f3133b..8cef1a5651 100644 --- a/examples/python/src/tests/test_authors.py +++ b/examples/python/src/tests/test_authors.py @@ -1,59 +1,56 @@ import os -import asyncpg -import psycopg2.extensions import pytest -from sqlc_runtime.psycopg2 import build_psycopg2_connection -from sqlc_runtime.asyncpg import build_asyncpg_connection +import sqlalchemy.ext.asyncio from authors import query from dbtest.migrations import apply_migrations, apply_migrations_async -def test_authors(postgres_db: psycopg2.extensions.connection): - apply_migrations(postgres_db, [os.path.dirname(__file__) + "/../../../authors/postgresql/schema.sql"]) +def test_authors(sqlalchemy_connection: sqlalchemy.engine.Connection): + apply_migrations(sqlalchemy_connection, [os.path.dirname(__file__) + "/../../../authors/postgresql/schema.sql"]) - db = build_psycopg2_connection(postgres_db) + db = query.Query(sqlalchemy_connection) - authors = list(query.list_authors(db)) + authors = list(db.list_authors()) assert authors == [] author_name = "Brian Kernighan" author_bio = "Co-author of The C Programming Language and The Go Programming Language" - new_author = query.create_author(db, name=author_name, bio=author_bio) + new_author = db.create_author(name=author_name, bio=author_bio) assert new_author.id > 0 assert new_author.name == author_name assert new_author.bio == author_bio - db_author = query.get_author(db, new_author.id) + db_author = db.get_author(new_author.id) assert db_author == new_author - author_list = list(query.list_authors(db)) + author_list = list(db.list_authors()) assert len(author_list) == 1 assert author_list[0] == new_author @pytest.mark.asyncio -async def test_authors_async(async_postgres_db: asyncpg.Connection): - await apply_migrations_async(async_postgres_db, [os.path.dirname(__file__) + "/../../../authors/postgresql/schema.sql"]) +async def test_authors_async(async_sqlalchemy_connection: sqlalchemy.ext.asyncio.AsyncConnection): + await apply_migrations_async(async_sqlalchemy_connection, [os.path.dirname(__file__) + "/../../../authors/postgresql/schema.sql"]) - db = build_asyncpg_connection(async_postgres_db) + db = query.AsyncQuery(async_sqlalchemy_connection) - async for _ in query.list_authors(db): + async for _ in db.list_authors(): assert False, "No authors should exist" author_name = "Brian Kernighan" author_bio = "Co-author of The C Programming Language and The Go Programming Language" - new_author = await query.create_author(db, name=author_name, bio=author_bio) + new_author = await db.create_author(name=author_name, bio=author_bio) assert new_author.id > 0 assert new_author.name == author_name assert new_author.bio == author_bio - db_author = await query.get_author(db, new_author.id) + db_author = await db.get_author(new_author.id) assert db_author == new_author author_list = [] - async for author in query.list_authors(db): + async for author in db.list_authors(): author_list.append(author) assert len(author_list) == 1 assert author_list[0] == new_author From d9c77db3e1e3cf6c3478ab8be833f156be76eda2 Mon Sep 17 00:00:00 2001 From: Robert Holt Date: Thu, 11 Mar 2021 09:34:41 -0500 Subject: [PATCH 2/3] 
Remove leftover debug code in python tests --- examples/python/src/tests/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/python/src/tests/conftest.py b/examples/python/src/tests/conftest.py index 6b926935af..b02ae01ad1 100644 --- a/examples/python/src/tests/conftest.py +++ b/examples/python/src/tests/conftest.py @@ -58,7 +58,6 @@ def postgres_connection(postgres_uri) -> psycopg2.extensions.connection: @pytest.fixture() def postgres_db(postgres_connection) -> psycopg2.extensions.connection: schema_name = f"sqltest_{random.randint(0, 1000)}" - # schema_name = "sqltest_1" cur = postgres_connection.cursor() cur.execute(f"CREATE SCHEMA {schema_name}") cur.execute(f"SET search_path TO {schema_name}") From 5f73dbdaca13f5e5484d60d035e5ab00df741e21 Mon Sep 17 00:00:00 2001 From: Robert Holt Date: Mon, 15 Mar 2021 16:14:58 -0400 Subject: [PATCH 3/3] Python: implement new codegen Generated python code now only depends on sqlalchemy. Query functions are now inside classes as well. --- examples/python/requirements.txt | 2 +- examples/python/sqlc.json | 13 +- examples/python/src/authors/models.py | 4 +- examples/python/src/authors/query.py | 74 +++-- examples/python/src/booktest/models.py | 12 +- examples/python/src/booktest/query.py | 321 +++++++++------------ examples/python/src/dbtest/migrations.py | 9 +- examples/python/src/jets/models.py | 20 +- examples/python/src/jets/query-building.py | 68 ++--- examples/python/src/ondeck/city.py | 110 +++---- examples/python/src/ondeck/models.py | 12 +- examples/python/src/ondeck/venue.py | 226 ++++++--------- examples/python/src/tests/conftest.py | 79 ++--- examples/python/src/tests/test_authors.py | 28 +- examples/python/src/tests/test_booktest.py | 114 ++++---- examples/python/src/tests/test_ondeck.py | 30 +- internal/codegen/python/gen.go | 238 ++++++++------- internal/codegen/python/imports.go | 19 +- internal/config/config.go | 13 +- internal/config/v_two.go | 11 +- 20 files changed, 677 insertions(+), 726 deletions(-) diff --git a/examples/python/requirements.txt b/examples/python/requirements.txt index ab7698b01a..26f645a169 100644 --- a/examples/python/requirements.txt +++ b/examples/python/requirements.txt @@ -2,4 +2,4 @@ pytest~=6.2.2 pytest-asyncio~=0.14.0 psycopg2-binary~=2.8.6 asyncpg~=0.21.0 -sqlalchemy==1.4.0b3 +sqlalchemy==1.4.0 diff --git a/examples/python/sqlc.json b/examples/python/sqlc.json index ba987b5af1..583513a184 100644 --- a/examples/python/sqlc.json +++ b/examples/python/sqlc.json @@ -8,7 +8,9 @@ "gen": { "python": { "out": "src/authors", - "package": "authors" + "package": "authors", + "emit_sync_querier": true, + "emit_async_querier": true } } }, @@ -19,7 +21,8 @@ "gen": { "python": { "out": "src/booktest", - "package": "booktest" + "package": "booktest", + "emit_async_querier": true } } }, @@ -30,7 +33,8 @@ "gen": { "python": { "out": "src/jets", - "package": "jets" + "package": "jets", + "emit_async_querier": true } } }, @@ -41,7 +45,8 @@ "gen": { "python": { "out": "src/ondeck", - "package": "ondeck" + "package": "ondeck", + "emit_async_querier": true } } } diff --git a/examples/python/src/authors/models.py b/examples/python/src/authors/models.py index c1b6cfaf87..42d945abe1 100644 --- a/examples/python/src/authors/models.py +++ b/examples/python/src/authors/models.py @@ -1,11 +1,11 @@ # Code generated by sqlc. DO NOT EDIT. 
from typing import Optional + import dataclasses -# Enums -# Models + @dataclasses.dataclass() class Author: id: int diff --git a/examples/python/src/authors/query.py b/examples/python/src/authors/query.py index 0ec4839085..5571b49364 100644 --- a/examples/python/src/authors/query.py +++ b/examples/python/src/authors/query.py @@ -1,3 +1,4 @@ + # Code generated by sqlc. DO NOT EDIT. from typing import AsyncIterator, Iterator, Optional @@ -35,43 +36,76 @@ """ -class Query: +class Querier: def __init__(self, conn: sqlalchemy.engine.Connection): self._conn = conn - def create_author(self, name: str, bio: Optional[str]) -> Optional[models.Author]: - result = self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio}) - return models.Author(**dict(result.first())) - - def delete_author(self, id: int) -> None: + def create_author(self, *, name: str, bio: Optional[str]) -> Optional[models.Author]: + row = self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio}).first() + if row is None: + return None + return models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) + + def delete_author(self, *, id: int) -> None: self._conn.execute(sqlalchemy.text(DELETE_AUTHOR), {"p1": id}) - def get_author(self, id: int) -> Optional[models.Author]: - result = self._conn.execute(sqlalchemy.text(GET_AUTHOR), {"p1": id}) - return models.Author(**dict(result.first())) + def get_author(self, *, id: int) -> Optional[models.Author]: + row = self._conn.execute(sqlalchemy.text(GET_AUTHOR), {"p1": id}).first() + if row is None: + return None + return models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) def list_authors(self) -> Iterator[models.Author]: result = self._conn.execute(sqlalchemy.text(LIST_AUTHORS)) for row in result: - yield models.Author(**dict(row)) + yield models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) -class AsyncQuery: +class AsyncQuerier: def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): self._conn = conn - async def create_author(self, name: str, bio: Optional[str]) -> Optional[models.Author]: - result = await self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio}) - return models.Author(**dict(result.first())) - - async def delete_author(self, id: int) -> None: + async def create_author(self, *, name: str, bio: Optional[str]) -> Optional[models.Author]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio})).first() + if row is None: + return None + return models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) + + async def delete_author(self, *, id: int) -> None: await self._conn.execute(sqlalchemy.text(DELETE_AUTHOR), {"p1": id}) - async def get_author(self, id: int) -> Optional[models.Author]: - result = await self._conn.execute(sqlalchemy.text(GET_AUTHOR), {"p1": id}) - return models.Author(**dict(result.first())) + async def get_author(self, *, id: int) -> Optional[models.Author]: + row = (await self._conn.execute(sqlalchemy.text(GET_AUTHOR), {"p1": id})).first() + if row is None: + return None + return models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) async def list_authors(self) -> AsyncIterator[models.Author]: result = await self._conn.stream(sqlalchemy.text(LIST_AUTHORS)) async for row in result: - yield models.Author(**dict(row)) + yield models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) + diff --git a/examples/python/src/booktest/models.py b/examples/python/src/booktest/models.py index 9c7d1a1970..7367c14720 100644 --- 
a/examples/python/src/booktest/models.py +++ b/examples/python/src/booktest/models.py @@ -3,22 +3,24 @@ import datetime import enum -import pydantic +import dataclasses + -# Enums class BookType(str, enum.Enum): FICTION = "FICTION" NONFICTION = "NONFICTION" -# Models -class Author(pydantic.BaseModel): +@dataclasses.dataclass() +class Author: author_id: int name: str -class Book(pydantic.BaseModel): + +@dataclasses.dataclass() +class Book: book_id: int author_id: int isbn: str diff --git a/examples/python/src/booktest/query.py b/examples/python/src/booktest/query.py index 3081b2a598..6bc73be5fb 100644 --- a/examples/python/src/booktest/query.py +++ b/examples/python/src/booktest/query.py @@ -1,14 +1,16 @@ + # Code generated by sqlc. DO NOT EDIT. -from typing import AsyncIterator, Awaitable, Iterator, List, Optional, overload +from typing import AsyncIterator, List, Optional import datetime -import pydantic -import sqlc_runtime as sqlc +import dataclasses +import sqlalchemy +import sqlalchemy.ext.asyncio from booktest import models -BOOKS_BY_TAGS = """-- name: books_by_tags :many +BOOKS_BY_TAGS = """-- name: books_by_tags \\:many SELECT book_id, title, @@ -17,11 +19,12 @@ tags FROM books LEFT JOIN authors ON books.author_id = authors.author_id -WHERE tags && $1::varchar[] +WHERE tags && :p1\\:\\:varchar[] """ -class BooksByTagsRow(pydantic.BaseModel): +@dataclasses.dataclass() +class BooksByTagsRow: book_id: int title: str name: str @@ -29,30 +32,19 @@ class BooksByTagsRow(pydantic.BaseModel): tags: List[str] -BOOKS_BY_TITLE_YEAR = """-- name: books_by_title_year :many +BOOKS_BY_TITLE_YEAR = """-- name: books_by_title_year \\:many SELECT book_id, author_id, isbn, book_type, title, year, available, tags FROM books -WHERE title = $1 AND year = $2 +WHERE title = :p1 AND year = :p2 """ -class BooksByTitleYearRow(pydantic.BaseModel): - book_id: int - author_id: int - isbn: str - book_type: models.BookType - title: str - year: int - available: datetime.datetime - tags: List[str] - - -CREATE_AUTHOR = """-- name: create_author :one -INSERT INTO authors (name) VALUES ($1) +CREATE_AUTHOR = """-- name: create_author \\:one +INSERT INTO authors (name) VALUES (:p1) RETURNING author_id, name """ -CREATE_BOOK = """-- name: create_book :one +CREATE_BOOK = """-- name: create_book \\:one INSERT INTO books ( author_id, isbn, @@ -62,30 +54,20 @@ class BooksByTitleYearRow(pydantic.BaseModel): available, tags ) VALUES ( - $1, - $2, - $3, - $4, - $5, - $6, - $7 + :p1, + :p2, + :p3, + :p4, + :p5, + :p6, + :p7 ) RETURNING book_id, author_id, isbn, book_type, title, year, available, tags """ -class CreateBookParams(pydantic.BaseModel): - author_id: int - isbn: str - book_type: models.BookType - title: str - year: int - available: datetime.datetime - tags: List[str] - - -class CreateBookRow(pydantic.BaseModel): - book_id: int +@dataclasses.dataclass() +class CreateBookParams: author_id: int isbn: str book_type: models.BookType @@ -95,172 +77,135 @@ class CreateBookRow(pydantic.BaseModel): tags: List[str] -DELETE_BOOK = """-- name: delete_book :exec +DELETE_BOOK = """-- name: delete_book \\:exec DELETE FROM books -WHERE book_id = $1 +WHERE book_id = :p1 """ -GET_AUTHOR = """-- name: get_author :one +GET_AUTHOR = """-- name: get_author \\:one SELECT author_id, name FROM authors -WHERE author_id = $1 +WHERE author_id = :p1 """ -GET_BOOK = """-- name: get_book :one +GET_BOOK = """-- name: get_book \\:one SELECT book_id, author_id, isbn, book_type, title, year, available, tags FROM books -WHERE book_id = $1 +WHERE book_id 
= :p1 """ -class GetBookRow(pydantic.BaseModel): - book_id: int - author_id: int - isbn: str - book_type: models.BookType - title: str - year: int - available: datetime.datetime - tags: List[str] - - -UPDATE_BOOK = """-- name: update_book :exec +UPDATE_BOOK = """-- name: update_book \\:exec UPDATE books -SET title = $1, tags = $2 -WHERE book_id = $3 +SET title = :p1, tags = :p2 +WHERE book_id = :p3 """ -UPDATE_BOOK_ISBN = """-- name: update_book_isbn :exec +UPDATE_BOOK_ISBN = """-- name: update_book_isbn \\:exec UPDATE books -SET title = $1, tags = $2, isbn = $4 -WHERE book_id = $3 +SET title = :p1, tags = :p2, isbn = :p4 +WHERE book_id = :p3 """ -@overload -def books_by_tags(conn: sqlc.Connection, dollar_1: List[str]) -> Iterator[BooksByTagsRow]: - pass - - -@overload -def books_by_tags(conn: sqlc.AsyncConnection, dollar_1: List[str]) -> AsyncIterator[BooksByTagsRow]: - pass - - -def books_by_tags(conn: sqlc.GenericConnection, dollar_1: List[str]) -> sqlc.IteratorReturn[BooksByTagsRow]: - return conn.execute_many_model(BooksByTagsRow, BOOKS_BY_TAGS, dollar_1) - - -@overload -def books_by_title_year(conn: sqlc.Connection, title: str, year: int) -> Iterator[BooksByTitleYearRow]: - pass - - -@overload -def books_by_title_year(conn: sqlc.AsyncConnection, title: str, year: int) -> AsyncIterator[BooksByTitleYearRow]: - pass - - -def books_by_title_year(conn: sqlc.GenericConnection, title: str, year: int) -> sqlc.IteratorReturn[BooksByTitleYearRow]: - return conn.execute_many_model(BooksByTitleYearRow, BOOKS_BY_TITLE_YEAR, title, year) - - -@overload -def create_author(conn: sqlc.Connection, name: str) -> Optional[models.Author]: - pass - - -@overload -def create_author(conn: sqlc.AsyncConnection, name: str) -> Awaitable[Optional[models.Author]]: - pass - - -def create_author(conn: sqlc.GenericConnection, name: str) -> sqlc.ReturnType[Optional[models.Author]]: - return conn.execute_one_model(models.Author, CREATE_AUTHOR, name) - - -@overload -def create_book(conn: sqlc.Connection, arg: CreateBookParams) -> Optional[CreateBookRow]: - pass - - -@overload -def create_book(conn: sqlc.AsyncConnection, arg: CreateBookParams) -> Awaitable[Optional[CreateBookRow]]: - pass - - -def create_book(conn: sqlc.GenericConnection, arg: CreateBookParams) -> sqlc.ReturnType[Optional[CreateBookRow]]: - return conn.execute_one_model(CreateBookRow, CREATE_BOOK, arg.author_id, arg.isbn, arg.book_type, arg.title, arg.year, arg.available, arg.tags) - - -@overload -def delete_book(conn: sqlc.Connection, book_id: int) -> None: - pass - - -@overload -def delete_book(conn: sqlc.AsyncConnection, book_id: int) -> Awaitable[None]: - pass - - -def delete_book(conn: sqlc.GenericConnection, book_id: int) -> sqlc.ReturnType[None]: - return conn.execute_none(DELETE_BOOK, book_id) - - -@overload -def get_author(conn: sqlc.Connection, author_id: int) -> Optional[models.Author]: - pass - - -@overload -def get_author(conn: sqlc.AsyncConnection, author_id: int) -> Awaitable[Optional[models.Author]]: - pass - - -def get_author(conn: sqlc.GenericConnection, author_id: int) -> sqlc.ReturnType[Optional[models.Author]]: - return conn.execute_one_model(models.Author, GET_AUTHOR, author_id) - - -@overload -def get_book(conn: sqlc.Connection, book_id: int) -> Optional[GetBookRow]: - pass - - -@overload -def get_book(conn: sqlc.AsyncConnection, book_id: int) -> Awaitable[Optional[GetBookRow]]: - pass - - -def get_book(conn: sqlc.GenericConnection, book_id: int) -> sqlc.ReturnType[Optional[GetBookRow]]: - return conn.execute_one_model(GetBookRow, 
GET_BOOK, book_id) - - -@overload -def update_book(conn: sqlc.Connection, title: str, tags: List[str], book_id: int) -> None: - pass - - -@overload -def update_book(conn: sqlc.AsyncConnection, title: str, tags: List[str], book_id: int) -> Awaitable[None]: - pass - - -def update_book(conn: sqlc.GenericConnection, title: str, tags: List[str], book_id: int) -> sqlc.ReturnType[None]: - return conn.execute_none(UPDATE_BOOK, title, tags, book_id) - - -@overload -def update_book_isbn(conn: sqlc.Connection, title: str, tags: List[str], book_id: int, isbn: str) -> None: - pass - - -@overload -def update_book_isbn(conn: sqlc.AsyncConnection, title: str, tags: List[str], book_id: int, isbn: str) -> Awaitable[None]: - pass - - -def update_book_isbn(conn: sqlc.GenericConnection, title: str, tags: List[str], book_id: int, isbn: str) -> sqlc.ReturnType[None]: - return conn.execute_none(UPDATE_BOOK_ISBN, title, tags, book_id, isbn) +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + + async def books_by_tags(self, *, dollar_1: List[str]) -> AsyncIterator[BooksByTagsRow]: + result = await self._conn.stream(sqlalchemy.text(BOOKS_BY_TAGS), {"p1": dollar_1}) + async for row in result: + yield BooksByTagsRow( + book_id=row[0], + title=row[1], + name=row[2], + isbn=row[3], + tags=row[4], + ) + + async def books_by_title_year(self, *, title: str, year: int) -> AsyncIterator[models.Book]: + result = await self._conn.stream(sqlalchemy.text(BOOKS_BY_TITLE_YEAR), {"p1": title, "p2": year}) + async for row in result: + yield models.Book( + book_id=row[0], + author_id=row[1], + isbn=row[2], + book_type=row[3], + title=row[4], + year=row[5], + available=row[6], + tags=row[7], + ) + + async def create_author(self, *, name: str) -> Optional[models.Author]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name})).first() + if row is None: + return None + return models.Author( + author_id=row[0], + name=row[1], + ) + + async def create_book(self, arg: CreateBookParams) -> Optional[models.Book]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_BOOK), { + "p1": arg.author_id, + "p2": arg.isbn, + "p3": arg.book_type, + "p4": arg.title, + "p5": arg.year, + "p6": arg.available, + "p7": arg.tags, + })).first() + if row is None: + return None + return models.Book( + book_id=row[0], + author_id=row[1], + isbn=row[2], + book_type=row[3], + title=row[4], + year=row[5], + available=row[6], + tags=row[7], + ) + + async def delete_book(self, *, book_id: int) -> None: + await self._conn.execute(sqlalchemy.text(DELETE_BOOK), {"p1": book_id}) + + async def get_author(self, *, author_id: int) -> Optional[models.Author]: + row = (await self._conn.execute(sqlalchemy.text(GET_AUTHOR), {"p1": author_id})).first() + if row is None: + return None + return models.Author( + author_id=row[0], + name=row[1], + ) + + async def get_book(self, *, book_id: int) -> Optional[models.Book]: + row = (await self._conn.execute(sqlalchemy.text(GET_BOOK), {"p1": book_id})).first() + if row is None: + return None + return models.Book( + book_id=row[0], + author_id=row[1], + isbn=row[2], + book_type=row[3], + title=row[4], + year=row[5], + available=row[6], + tags=row[7], + ) + + async def update_book(self, *, title: str, tags: List[str], book_id: int) -> None: + await self._conn.execute(sqlalchemy.text(UPDATE_BOOK), {"p1": title, "p2": tags, "p3": book_id}) + + async def update_book_isbn(self, *, title: str, tags: List[str], book_id: int, isbn: str) -> None: + await 
self._conn.execute(sqlalchemy.text(UPDATE_BOOK_ISBN), { + "p1": title, + "p2": tags, + "p3": book_id, + "p4": isbn, + }) diff --git a/examples/python/src/dbtest/migrations.py b/examples/python/src/dbtest/migrations.py index 8882c63169..6ace6bcd63 100644 --- a/examples/python/src/dbtest/migrations.py +++ b/examples/python/src/dbtest/migrations.py @@ -11,7 +11,10 @@ def apply_migrations(conn: sqlalchemy.engine.Connection, paths: List[str]): for file in files: with open(file, "r") as fd: blob = fd.read() - conn.execute(blob) + stmts = blob.split(";") + for stmt in stmts: + if stmt.strip(): + conn.execute(sqlalchemy.text(stmt)) async def apply_migrations_async(conn: sqlalchemy.ext.asyncio.AsyncConnection, paths: List[str]): @@ -20,7 +23,9 @@ async def apply_migrations_async(conn: sqlalchemy.ext.asyncio.AsyncConnection, p for file in files: with open(file, "r") as fd: blob = fd.read() - await conn.execute(sqlalchemy.text(blob)) + raw_conn = await conn.get_raw_connection() + # The asyncpg sqlalchemy adapter uses a prepared statement cache which can't handle the migration statements + await raw_conn._connection.execute(blob) def _find_sql_files(paths: List[str]) -> List[str]: diff --git a/examples/python/src/jets/models.py b/examples/python/src/jets/models.py index cb543dcb22..13ff1a122c 100644 --- a/examples/python/src/jets/models.py +++ b/examples/python/src/jets/models.py @@ -1,13 +1,13 @@ # Code generated by sqlc. DO NOT EDIT. -import pydantic +import dataclasses -# Enums -# Models -class Jet(pydantic.BaseModel): + +@dataclasses.dataclass() +class Jet: id: int pilot_id: int age: int @@ -15,17 +15,23 @@ class Jet(pydantic.BaseModel): color: str -class Language(pydantic.BaseModel): + +@dataclasses.dataclass() +class Language: id: int language: str -class Pilot(pydantic.BaseModel): + +@dataclasses.dataclass() +class Pilot: id: int name: str -class PilotLanguage(pydantic.BaseModel): + +@dataclasses.dataclass() +class PilotLanguage: pilot_id: int language_id: int diff --git a/examples/python/src/jets/query-building.py b/examples/python/src/jets/query-building.py index 21056fe10c..0725a80902 100644 --- a/examples/python/src/jets/query-building.py +++ b/examples/python/src/jets/query-building.py @@ -1,65 +1,47 @@ + # Code generated by sqlc. DO NOT EDIT. 
-from typing import AsyncIterator, Awaitable, Iterator, Optional, overload +from typing import AsyncIterator, Optional -import sqlc_runtime as sqlc +import sqlalchemy +import sqlalchemy.ext.asyncio from jets import models -COUNT_PILOTS = """-- name: count_pilots :one +COUNT_PILOTS = """-- name: count_pilots \\:one SELECT COUNT(*) FROM pilots """ -DELETE_PILOT = """-- name: delete_pilot :exec -DELETE FROM pilots WHERE id = $1 +DELETE_PILOT = """-- name: delete_pilot \\:exec +DELETE FROM pilots WHERE id = :p1 """ -LIST_PILOTS = """-- name: list_pilots :many +LIST_PILOTS = """-- name: list_pilots \\:many SELECT id, name FROM pilots LIMIT 5 """ -@overload -def count_pilots(conn: sqlc.Connection) -> Optional[int]: - pass - - -@overload -def count_pilots(conn: sqlc.AsyncConnection) -> Awaitable[Optional[int]]: - pass - - -def count_pilots(conn: sqlc.GenericConnection) -> sqlc.ReturnType[Optional[int]]: - return conn.execute_one(COUNT_PILOTS) - - -@overload -def delete_pilot(conn: sqlc.Connection, id: int) -> None: - pass - - -@overload -def delete_pilot(conn: sqlc.AsyncConnection, id: int) -> Awaitable[None]: - pass - - -def delete_pilot(conn: sqlc.GenericConnection, id: int) -> sqlc.ReturnType[None]: - return conn.execute_none(DELETE_PILOT, id) - - -@overload -def list_pilots(conn: sqlc.Connection) -> Iterator[models.Pilot]: - pass - -@overload -def list_pilots(conn: sqlc.AsyncConnection) -> AsyncIterator[models.Pilot]: - pass +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + async def count_pilots(self) -> Optional[int]: + row = (await self._conn.execute(sqlalchemy.text(COUNT_PILOTS))).first() + if row is None: + return None + return row[0] -def list_pilots(conn: sqlc.GenericConnection) -> sqlc.IteratorReturn[models.Pilot]: - return conn.execute_many_model(models.Pilot, LIST_PILOTS) + async def delete_pilot(self, *, id: int) -> None: + await self._conn.execute(sqlalchemy.text(DELETE_PILOT), {"p1": id}) + async def list_pilots(self) -> AsyncIterator[models.Pilot]: + result = await self._conn.stream(sqlalchemy.text(LIST_PILOTS)) + async for row in result: + yield models.Pilot( + id=row[0], + name=row[1], + ) diff --git a/examples/python/src/ondeck/city.py b/examples/python/src/ondeck/city.py index baf0051e9c..d7a2123b97 100644 --- a/examples/python/src/ondeck/city.py +++ b/examples/python/src/ondeck/city.py @@ -1,96 +1,76 @@ + # Code generated by sqlc. DO NOT EDIT. 
-from typing import AsyncIterator, Awaitable, Iterator, Optional, overload +from typing import AsyncIterator, Optional -import sqlc_runtime as sqlc +import sqlalchemy +import sqlalchemy.ext.asyncio from ondeck import models -CREATE_CITY = """-- name: create_city :one +CREATE_CITY = """-- name: create_city \\:one INSERT INTO city ( name, slug ) VALUES ( - $1, - $2 + :p1, + :p2 ) RETURNING slug, name """ -GET_CITY = """-- name: get_city :one +GET_CITY = """-- name: get_city \\:one SELECT slug, name FROM city -WHERE slug = $1 +WHERE slug = :p1 """ -LIST_CITIES = """-- name: list_cities :many +LIST_CITIES = """-- name: list_cities \\:many SELECT slug, name FROM city ORDER BY name """ -UPDATE_CITY_NAME = """-- name: update_city_name :exec +UPDATE_CITY_NAME = """-- name: update_city_name \\:exec UPDATE city -SET name = $2 -WHERE slug = $1 +SET name = :p2 +WHERE slug = :p1 """ -@overload -def create_city(conn: sqlc.Connection, name: str, slug: str) -> Optional[models.City]: - pass - - -@overload -def create_city(conn: sqlc.AsyncConnection, name: str, slug: str) -> Awaitable[Optional[models.City]]: - pass - - -def create_city(conn: sqlc.GenericConnection, name: str, slug: str) -> sqlc.ReturnType[Optional[models.City]]: - return conn.execute_one_model(models.City, CREATE_CITY, name, slug) - - -@overload -def get_city(conn: sqlc.Connection, slug: str) -> Optional[models.City]: - pass - - -@overload -def get_city(conn: sqlc.AsyncConnection, slug: str) -> Awaitable[Optional[models.City]]: - pass - - -def get_city(conn: sqlc.GenericConnection, slug: str) -> sqlc.ReturnType[Optional[models.City]]: - return conn.execute_one_model(models.City, GET_CITY, slug) - - -@overload -def list_cities(conn: sqlc.Connection) -> Iterator[models.City]: - pass - - -@overload -def list_cities(conn: sqlc.AsyncConnection) -> AsyncIterator[models.City]: - pass - - -def list_cities(conn: sqlc.GenericConnection) -> sqlc.IteratorReturn[models.City]: - return conn.execute_many_model(models.City, LIST_CITIES) - - -@overload -def update_city_name(conn: sqlc.Connection, slug: str, name: str) -> None: - pass - - -@overload -def update_city_name(conn: sqlc.AsyncConnection, slug: str, name: str) -> Awaitable[None]: - pass - - -def update_city_name(conn: sqlc.GenericConnection, slug: str, name: str) -> sqlc.ReturnType[None]: - return conn.execute_none(UPDATE_CITY_NAME, slug, name) +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + + async def create_city(self, *, name: str, slug: str) -> Optional[models.City]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_CITY), {"p1": name, "p2": slug})).first() + if row is None: + return None + return models.City( + slug=row[0], + name=row[1], + ) + + async def get_city(self, *, slug: str) -> Optional[models.City]: + row = (await self._conn.execute(sqlalchemy.text(GET_CITY), {"p1": slug})).first() + if row is None: + return None + return models.City( + slug=row[0], + name=row[1], + ) + + async def list_cities(self) -> AsyncIterator[models.City]: + result = await self._conn.stream(sqlalchemy.text(LIST_CITIES)) + async for row in result: + yield models.City( + slug=row[0], + name=row[1], + ) + + async def update_city_name(self, *, slug: str, name: str) -> None: + await self._conn.execute(sqlalchemy.text(UPDATE_CITY_NAME), {"p1": slug, "p2": name}) diff --git a/examples/python/src/ondeck/models.py b/examples/python/src/ondeck/models.py index bc8244b2a3..6e2de33b5e 100644 --- a/examples/python/src/ondeck/models.py +++ 
b/examples/python/src/ondeck/models.py @@ -3,22 +3,24 @@ import datetime import enum -import pydantic +import dataclasses -# Enums# Venues can be either open or closed +# Venues can be either open or closed class Status(str, enum.Enum): OPEN = "op!en" CLOSED = "clo@sed" -# Models -class City(pydantic.BaseModel): +@dataclasses.dataclass() +class City: slug: str name: str + # Venues are places where muisc happens -class Venue(pydantic.BaseModel): +@dataclasses.dataclass() +class Venue: id: int status: Status statuses: Optional[List[Status]] diff --git a/examples/python/src/ondeck/venue.py b/examples/python/src/ondeck/venue.py index 289725e0d4..c0e31be474 100644 --- a/examples/python/src/ondeck/venue.py +++ b/examples/python/src/ondeck/venue.py @@ -1,14 +1,15 @@ + # Code generated by sqlc. DO NOT EDIT. -from typing import AsyncIterator, Awaitable, Iterator, List, Optional, overload -import datetime +from typing import AsyncIterator, List, Optional -import pydantic -import sqlc_runtime as sqlc +import dataclasses +import sqlalchemy +import sqlalchemy.ext.asyncio from ondeck import models -CREATE_VENUE = """-- name: create_venue :one +CREATE_VENUE = """-- name: create_venue \\:one INSERT INTO venue ( slug, name, @@ -19,19 +20,20 @@ statuses, tags ) VALUES ( - $1, - $2, - $3, + :p1, + :p2, + :p3, NOW(), - $4, - $5, - $6, - $7 + :p4, + :p5, + :p6, + :p7 ) RETURNING id """ -class CreateVenueParams(pydantic.BaseModel): +@dataclasses.dataclass() +class CreateVenueParams: slug: str name: str city: str @@ -41,62 +43,36 @@ class CreateVenueParams(pydantic.BaseModel): tags: Optional[List[str]] -DELETE_VENUE = """-- name: delete_venue :exec +DELETE_VENUE = """-- name: delete_venue \\:exec DELETE FROM venue -WHERE slug = $1 AND slug = $1 +WHERE slug = :p1 AND slug = :p1 """ -GET_VENUE = """-- name: get_venue :one +GET_VENUE = """-- name: get_venue \\:one SELECT id, status, statuses, slug, name, city, spotify_playlist, songkick_id, tags, created_at FROM venue -WHERE slug = $1 AND city = $2 +WHERE slug = :p1 AND city = :p2 """ -class GetVenueRow(pydantic.BaseModel): - id: int - status: models.Status - statuses: Optional[List[models.Status]] - slug: str - name: str - city: str - spotify_playlist: str - songkick_id: Optional[str] - tags: Optional[List[str]] - created_at: datetime.datetime - - -LIST_VENUES = """-- name: list_venues :many +LIST_VENUES = """-- name: list_venues \\:many SELECT id, status, statuses, slug, name, city, spotify_playlist, songkick_id, tags, created_at FROM venue -WHERE city = $1 +WHERE city = :p1 ORDER BY name """ -class ListVenuesRow(pydantic.BaseModel): - id: int - status: models.Status - statuses: Optional[List[models.Status]] - slug: str - name: str - city: str - spotify_playlist: str - songkick_id: Optional[str] - tags: Optional[List[str]] - created_at: datetime.datetime - - -UPDATE_VENUE_NAME = """-- name: update_venue_name :one +UPDATE_VENUE_NAME = """-- name: update_venue_name \\:one UPDATE venue -SET name = $2 -WHERE slug = $1 +SET name = :p2 +WHERE slug = :p1 RETURNING id """ -VENUE_COUNT_BY_CITY = """-- name: venue_count_by_city :many +VENUE_COUNT_BY_CITY = """-- name: venue_count_by_city \\:many SELECT city, count(*) @@ -106,92 +82,78 @@ class ListVenuesRow(pydantic.BaseModel): """ -class VenueCountByCityRow(pydantic.BaseModel): +@dataclasses.dataclass() +class VenueCountByCityRow: city: str count: int -@overload -def create_venue(conn: sqlc.Connection, arg: CreateVenueParams) -> Optional[int]: - pass - - -@overload -def create_venue(conn: sqlc.AsyncConnection, arg: 
CreateVenueParams) -> Awaitable[Optional[int]]: - pass - - -def create_venue(conn: sqlc.GenericConnection, arg: CreateVenueParams) -> sqlc.ReturnType[Optional[int]]: - return conn.execute_one(CREATE_VENUE, arg.slug, arg.name, arg.city, arg.spotify_playlist, arg.status, arg.statuses, arg.tags) - - -@overload -def delete_venue(conn: sqlc.Connection, slug: str) -> None: - pass - - -@overload -def delete_venue(conn: sqlc.AsyncConnection, slug: str) -> Awaitable[None]: - pass - - -def delete_venue(conn: sqlc.GenericConnection, slug: str) -> sqlc.ReturnType[None]: - return conn.execute_none(DELETE_VENUE, slug) - - -@overload -def get_venue(conn: sqlc.Connection, slug: str, city: str) -> Optional[GetVenueRow]: - pass - - -@overload -def get_venue(conn: sqlc.AsyncConnection, slug: str, city: str) -> Awaitable[Optional[GetVenueRow]]: - pass - - -def get_venue(conn: sqlc.GenericConnection, slug: str, city: str) -> sqlc.ReturnType[Optional[GetVenueRow]]: - return conn.execute_one_model(GetVenueRow, GET_VENUE, slug, city) - - -@overload -def list_venues(conn: sqlc.Connection, city: str) -> Iterator[ListVenuesRow]: - pass - - -@overload -def list_venues(conn: sqlc.AsyncConnection, city: str) -> AsyncIterator[ListVenuesRow]: - pass - - -def list_venues(conn: sqlc.GenericConnection, city: str) -> sqlc.IteratorReturn[ListVenuesRow]: - return conn.execute_many_model(ListVenuesRow, LIST_VENUES, city) - - -@overload -def update_venue_name(conn: sqlc.Connection, slug: str, name: str) -> Optional[int]: - pass - - -@overload -def update_venue_name(conn: sqlc.AsyncConnection, slug: str, name: str) -> Awaitable[Optional[int]]: - pass - - -def update_venue_name(conn: sqlc.GenericConnection, slug: str, name: str) -> sqlc.ReturnType[Optional[int]]: - return conn.execute_one(UPDATE_VENUE_NAME, slug, name) - - -@overload -def venue_count_by_city(conn: sqlc.Connection) -> Iterator[VenueCountByCityRow]: - pass - - -@overload -def venue_count_by_city(conn: sqlc.AsyncConnection) -> AsyncIterator[VenueCountByCityRow]: - pass - - -def venue_count_by_city(conn: sqlc.GenericConnection) -> sqlc.IteratorReturn[VenueCountByCityRow]: - return conn.execute_many_model(VenueCountByCityRow, VENUE_COUNT_BY_CITY) +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + + async def create_venue(self, arg: CreateVenueParams) -> Optional[int]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_VENUE), { + "p1": arg.slug, + "p2": arg.name, + "p3": arg.city, + "p4": arg.spotify_playlist, + "p5": arg.status, + "p6": arg.statuses, + "p7": arg.tags, + })).first() + if row is None: + return None + return row[0] + + async def delete_venue(self, *, slug: str) -> None: + await self._conn.execute(sqlalchemy.text(DELETE_VENUE), {"p1": slug}) + + async def get_venue(self, *, slug: str, city: str) -> Optional[models.Venue]: + row = (await self._conn.execute(sqlalchemy.text(GET_VENUE), {"p1": slug, "p2": city})).first() + if row is None: + return None + return models.Venue( + id=row[0], + status=row[1], + statuses=row[2], + slug=row[3], + name=row[4], + city=row[5], + spotify_playlist=row[6], + songkick_id=row[7], + tags=row[8], + created_at=row[9], + ) + + async def list_venues(self, *, city: str) -> AsyncIterator[models.Venue]: + result = await self._conn.stream(sqlalchemy.text(LIST_VENUES), {"p1": city}) + async for row in result: + yield models.Venue( + id=row[0], + status=row[1], + statuses=row[2], + slug=row[3], + name=row[4], + city=row[5], + spotify_playlist=row[6], + 
songkick_id=row[7], + tags=row[8], + created_at=row[9], + ) + + async def update_venue_name(self, *, slug: str, name: str) -> Optional[int]: + row = (await self._conn.execute(sqlalchemy.text(UPDATE_VENUE_NAME), {"p1": slug, "p2": name})).first() + if row is None: + return None + return row[0] + + async def venue_count_by_city(self) -> AsyncIterator[VenueCountByCityRow]: + result = await self._conn.stream(sqlalchemy.text(VENUE_COUNT_BY_CITY)) + async for row in result: + yield VenueCountByCityRow( + city=row[0], + count=row[1], + ) diff --git a/examples/python/src/tests/conftest.py b/examples/python/src/tests/conftest.py index b02ae01ad1..f807209229 100644 --- a/examples/python/src/tests/conftest.py +++ b/examples/python/src/tests/conftest.py @@ -2,9 +2,6 @@ import os import random -import asyncpg -import psycopg2 -import psycopg2.extensions import pytest import sqlalchemy import sqlalchemy.ext.asyncio @@ -23,53 +20,43 @@ def postgres_uri() -> str: @pytest.fixture(scope="session") def sqlalchemy_connection(postgres_uri) -> sqlalchemy.engine.Connection: - schema_name = f"sqltest_{random.randint(0, 1000)}" - engine = sqlalchemy.create_engine(postgres_uri) + engine = sqlalchemy.create_engine(postgres_uri, future=True) with engine.connect() as conn: - conn.execute(f"CREATE SCHEMA {schema_name}") - conn.execute(f"SET search_path TO {schema_name}") yield conn - conn.execute(f"DROP SCHEMA {schema_name} CASCADE") - conn.execute("SET search_path TO public") + + +@pytest.fixture(scope="function") +def db(sqlalchemy_connection: sqlalchemy.engine.Connection) -> sqlalchemy.engine.Connection: + conn = sqlalchemy_connection + schema_name = f"sqltest_{random.randint(0, 1000)}" + conn.execute(sqlalchemy.text(f"CREATE SCHEMA {schema_name}")) + conn.execute(sqlalchemy.text(f"SET search_path TO {schema_name}")) + conn.commit() + yield conn + conn.rollback() + conn.execute(sqlalchemy.text(f"DROP SCHEMA {schema_name} CASCADE")) + conn.execute(sqlalchemy.text("SET search_path TO public")) @pytest.fixture(scope="session") async def async_sqlalchemy_connection(postgres_uri) -> sqlalchemy.ext.asyncio.AsyncConnection: postgres_uri = postgres_uri.replace("postgresql", "postgresql+asyncpg") - schema_name = f"sqltest_{random.randint(0, 1000)}" engine = sqlalchemy.ext.asyncio.create_async_engine(postgres_uri) async with engine.connect() as conn: - await conn.execute(sqlalchemy.text(f"CREATE SCHEMA {schema_name}")) - await conn.execute(sqlalchemy.text(f"SET search_path TO {schema_name}")) - await conn.commit() yield conn - await conn.rollback() - await conn.execute(sqlalchemy.text(f"DROP SCHEMA {schema_name} CASCADE")) - await conn.execute(sqlalchemy.text("SET search_path TO public")) - - -@pytest.fixture(scope="session") -def postgres_connection(postgres_uri) -> psycopg2.extensions.connection: - conn = psycopg2.connect(postgres_uri) - yield conn - conn.close() -@pytest.fixture() -def postgres_db(postgres_connection) -> psycopg2.extensions.connection: +@pytest.fixture(scope="function") +async def async_db(async_sqlalchemy_connection: sqlalchemy.ext.asyncio.AsyncConnection) -> sqlalchemy.ext.asyncio.AsyncConnection: + conn = async_sqlalchemy_connection schema_name = f"sqltest_{random.randint(0, 1000)}" - cur = postgres_connection.cursor() - cur.execute(f"CREATE SCHEMA {schema_name}") - cur.execute(f"SET search_path TO {schema_name}") - cur.close() - postgres_connection.commit() - yield postgres_connection - postgres_connection.rollback() - cur = postgres_connection.cursor() - cur.execute(f"DROP SCHEMA {schema_name} 
CASCADE") - cur.execute(f"SET search_path TO public") - cur.close() - postgres_connection.commit() + await conn.execute(sqlalchemy.text(f"CREATE SCHEMA {schema_name}")) + await conn.execute(sqlalchemy.text(f"SET search_path TO {schema_name}")) + await conn.commit() + yield conn + await conn.rollback() + await conn.execute(sqlalchemy.text(f"DROP SCHEMA {schema_name} CASCADE")) + await conn.execute(sqlalchemy.text("SET search_path TO public")) @pytest.fixture(scope="session") @@ -78,21 +65,3 @@ def event_loop(): loop = asyncio.get_event_loop_policy().new_event_loop() yield loop loop.close() - - -@pytest.fixture(scope="session") -async def async_postgres_connection(postgres_uri: str) -> asyncpg.Connection: - conn = await asyncpg.connect(postgres_uri) - yield conn - await conn.close() - - -@pytest.fixture() -async def async_postgres_db(async_postgres_connection: asyncpg.Connection) -> asyncpg.Connection: - conn = async_postgres_connection - schema_name = f"sqltest_{random.randint(0, 1000)}" - await conn.execute(f"CREATE SCHEMA {schema_name}") - await conn.execute(f"SET search_path TO {schema_name}") - yield conn - await conn.execute(f"DROP SCHEMA {schema_name} CASCADE") - await conn.execute(f"SET search_path TO public") diff --git a/examples/python/src/tests/test_authors.py b/examples/python/src/tests/test_authors.py index 8cef1a5651..7b0a954276 100644 --- a/examples/python/src/tests/test_authors.py +++ b/examples/python/src/tests/test_authors.py @@ -7,50 +7,50 @@ from dbtest.migrations import apply_migrations, apply_migrations_async -def test_authors(sqlalchemy_connection: sqlalchemy.engine.Connection): - apply_migrations(sqlalchemy_connection, [os.path.dirname(__file__) + "/../../../authors/postgresql/schema.sql"]) +def test_authors(db: sqlalchemy.engine.Connection): + apply_migrations(db, [os.path.dirname(__file__) + "/../../../authors/postgresql/schema.sql"]) - db = query.Query(sqlalchemy_connection) + querier = query.Querier(db) - authors = list(db.list_authors()) + authors = list(querier.list_authors()) assert authors == [] author_name = "Brian Kernighan" author_bio = "Co-author of The C Programming Language and The Go Programming Language" - new_author = db.create_author(name=author_name, bio=author_bio) + new_author = querier.create_author(name=author_name, bio=author_bio) assert new_author.id > 0 assert new_author.name == author_name assert new_author.bio == author_bio - db_author = db.get_author(new_author.id) + db_author = querier.get_author(id=new_author.id) assert db_author == new_author - author_list = list(db.list_authors()) + author_list = list(querier.list_authors()) assert len(author_list) == 1 assert author_list[0] == new_author @pytest.mark.asyncio -async def test_authors_async(async_sqlalchemy_connection: sqlalchemy.ext.asyncio.AsyncConnection): - await apply_migrations_async(async_sqlalchemy_connection, [os.path.dirname(__file__) + "/../../../authors/postgresql/schema.sql"]) +async def test_authors_async(async_db: sqlalchemy.ext.asyncio.AsyncConnection): + await apply_migrations_async(async_db, [os.path.dirname(__file__) + "/../../../authors/postgresql/schema.sql"]) - db = query.AsyncQuery(async_sqlalchemy_connection) + querier = query.AsyncQuerier(async_db) - async for _ in db.list_authors(): + async for _ in querier.list_authors(): assert False, "No authors should exist" author_name = "Brian Kernighan" author_bio = "Co-author of The C Programming Language and The Go Programming Language" - new_author = await db.create_author(name=author_name, bio=author_bio) + 
new_author = await querier.create_author(name=author_name, bio=author_bio) assert new_author.id > 0 assert new_author.name == author_name assert new_author.bio == author_bio - db_author = await db.get_author(new_author.id) + db_author = await querier.get_author(id=new_author.id) assert db_author == new_author author_list = [] - async for author in db.list_authors(): + async for author in querier.list_authors(): author_list.append(author) assert len(author_list) == 1 assert author_list[0] == new_author diff --git a/examples/python/src/tests/test_booktest.py b/examples/python/src/tests/test_booktest.py index b0ba38891a..6106d9d3fd 100644 --- a/examples/python/src/tests/test_booktest.py +++ b/examples/python/src/tests/test_booktest.py @@ -1,87 +1,85 @@ import datetime import os -import asyncpg import pytest -from sqlc_runtime.asyncpg import build_asyncpg_connection +import sqlalchemy.ext.asyncio from booktest import query, models from dbtest.migrations import apply_migrations_async @pytest.mark.asyncio -async def test_books(async_postgres_db: asyncpg.Connection): - await apply_migrations_async(async_postgres_db, [os.path.dirname(__file__) + "/../../../booktest/postgresql/schema.sql"]) +async def test_books(async_db: sqlalchemy.ext.asyncio.AsyncConnection): + await apply_migrations_async(async_db, [os.path.dirname(__file__) + "/../../../booktest/postgresql/schema.sql"]) - db = build_asyncpg_connection(async_postgres_db) + querier = query.AsyncQuerier(async_db) - author = await query.create_author(db, "Unknown Master") + author = await querier.create_author(name="Unknown Master") assert author is not None - async with async_postgres_db.transaction(): - now = datetime.datetime.now() - await query.create_book(db, query.CreateBookParams( - author_id=author.author_id, - isbn="1", - title="my book title", - book_type=models.BookType.FICTION, - year=2016, - available=now, - tags=[], - )) - - b1 = await query.create_book(db, query.CreateBookParams( - author_id=author.author_id, - isbn="2", - title="the second book", - book_type=models.BookType.FICTION, - year=2016, - available=now, - tags=["cool", "unique"], - )) - - await query.update_book(db, book_id=b1.book_id, title="changed second title", tags=["cool", "disastor"]) - - b3 = await query.create_book(db, query.CreateBookParams( - author_id=author.author_id, - isbn="3", - title="the third book", - book_type=models.BookType.FICTION, - year=2001, - available=now, - tags=["cool"], - )) - - b4 = await query.create_book(db, query.CreateBookParams( - author_id=author.author_id, - isbn="4", - title="4th place finisher", - book_type=models.BookType.NONFICTION, - year=2011, - available=now, - tags=["other"], - )) - - await query.update_book_isbn(db, book_id=b4.book_id, isbn="NEW ISBN", title="never ever gonna finish, a quatrain", tags=["someother"]) - - books0 = query.books_by_title_year(db, title="my book title", year=2016) + now = datetime.datetime.now() + await querier.create_book(query.CreateBookParams( + author_id=author.author_id, + isbn="1", + title="my book title", + book_type=models.BookType.FICTION, + year=2016, + available=now, + tags=[], + )) + + b1 = await querier.create_book(query.CreateBookParams( + author_id=author.author_id, + isbn="2", + title="the second book", + book_type=models.BookType.FICTION, + year=2016, + available=now, + tags=["cool", "unique"], + )) + + await querier.update_book(book_id=b1.book_id, title="changed second title", tags=["cool", "disastor"]) + + b3 = await querier.create_book(query.CreateBookParams( + 
author_id=author.author_id, + isbn="3", + title="the third book", + book_type=models.BookType.FICTION, + year=2001, + available=now, + tags=["cool"], + )) + + b4 = await querier.create_book(query.CreateBookParams( + author_id=author.author_id, + isbn="4", + title="4th place finisher", + book_type=models.BookType.NONFICTION, + year=2011, + available=now, + tags=["other"], + )) + + await querier.update_book_isbn(book_id=b4.book_id, isbn="NEW ISBN", title="never ever gonna finish, a quatrain", tags=["someother"]) + + books0 = querier.books_by_title_year(title="my book title", year=2016) expected_titles = {"my book title"} async for book in books0: expected_titles.remove(book.title) # raises a key error if the title does not exist assert len(book.tags) == 0 - author = await query.get_author(db, author_id=book.author_id) + author = await querier.get_author(author_id=book.author_id) assert author.name == "Unknown Master" assert len(expected_titles) == 0 - books = query.books_by_tags(db, ["cool", "other", "someother"]) + books = querier.books_by_tags(dollar_1=["cool", "other", "someother"]) expected_titles = {"changed second title", "the third book", "never ever gonna finish, a quatrain"} async for book in books: expected_titles.remove(book.title) assert len(expected_titles) == 0 - b5 = await query.get_book(db, b3.book_id) + b5 = await querier.get_book(book_id=b3.book_id) assert b5 is not None - await query.delete_book(db, book_id=b5.book_id) - b6 = await query.get_book(db, b5.book_id) + await querier.delete_book(book_id=b5.book_id) + b6 = await querier.get_book(book_id=b5.book_id) assert b6 is None diff --git a/examples/python/src/tests/test_ondeck.py b/examples/python/src/tests/test_ondeck.py index f12fbe985c..68cfbc9bcb 100644 --- a/examples/python/src/tests/test_ondeck.py +++ b/examples/python/src/tests/test_ondeck.py @@ -1,8 +1,7 @@ import os -import asyncpg import pytest -from sqlc_runtime.asyncpg import build_asyncpg_connection +import sqlalchemy.ext.asyncio from ondeck import models from ondeck import city as city_queries @@ -11,15 +10,16 @@ @pytest.mark.asyncio -async def test_ondeck(async_postgres_db: asyncpg.Connection): - await apply_migrations_async(async_postgres_db, [os.path.dirname(__file__) + "/../../../ondeck/postgresql/schema"]) +async def test_ondeck(async_db: sqlalchemy.ext.asyncio.AsyncConnection): + await apply_migrations_async(async_db, [os.path.dirname(__file__) + "/../../../ondeck/postgresql/schema"]) - db = build_asyncpg_connection(async_postgres_db) + city_querier = city_queries.AsyncQuerier(async_db) + venue_querier = venue_queries.AsyncQuerier(async_db) - city = await city_queries.create_city(db, slug="san-francisco", name="San Francisco") + city = await city_querier.create_city(slug="san-francisco", name="San Francisco") assert city is not None - venue_id = await venue_queries.create_venue(db, venue_queries.CreateVenueParams( + venue_id = await venue_querier.create_venue(venue_queries.CreateVenueParams( slug="the-fillmore", name="The Fillmore", city=city.slug, @@ -30,20 +30,20 @@ async def test_ondeck(async_postgres_db: asyncpg.Connection): )) assert venue_id is not None - venue = await venue_queries.get_venue(db, slug="the-fillmore", city=city.slug) + venue = await venue_querier.get_venue(slug="the-fillmore", city=city.slug) assert venue is not None assert venue.id == venue_id - assert city == await city_queries.get_city(db, city.slug) - assert [venue_queries.VenueCountByCityRow(city=city.slug, count=1)] == await _to_list(venue_queries.venue_count_by_city(db)) - 
assert [city] == await _to_list(city_queries.list_cities(db)) - assert [venue] == await _to_list(venue_queries.list_venues(db, city=city.slug)) + assert city == await city_querier.get_city(slug=city.slug) + assert [venue_queries.VenueCountByCityRow(city=city.slug, count=1)] == await _to_list(venue_querier.venue_count_by_city()) + assert [city] == await _to_list(city_querier.list_cities()) + assert [venue] == await _to_list(venue_querier.list_venues(city=city.slug)) - await city_queries.update_city_name(db, slug=city.slug, name="SF") - _id = await venue_queries.update_venue_name(db, slug=venue.slug, name="Fillmore") + await city_querier.update_city_name(slug=city.slug, name="SF") + _id = await venue_querier.update_venue_name(slug=venue.slug, name="Fillmore") assert _id == venue_id - await venue_queries.delete_venue(db, slug=venue.slug) + await venue_querier.delete_venue(slug=venue.slug) async def _to_list(it): diff --git a/internal/codegen/python/gen.go b/internal/codegen/python/gen.go index fda5374524..a30dd3ef50 100644 --- a/internal/codegen/python/gen.go +++ b/internal/codegen/python/gen.go @@ -60,19 +60,6 @@ type Struct struct { Comment string } -func (s Struct) DedupFields() []Field { - seen := map[string]struct{}{} - dedupFields := make([]Field, 0) - for _, f := range s.Fields { - if _, ok := seen[f.Name]; ok { - continue - } - seen[f.Name] = struct{}{} - dedupFields = append(dedupFields, f) - } - return dedupFields -} - type QueryValue struct { Emit bool Name string @@ -113,6 +100,19 @@ func (v QueryValue) Type() string { panic("no type for QueryValue: " + v.Name) } +func (v QueryValue) StructRowParser(rowVar string, indentCount int) string { + if !v.IsStruct() { + panic("StructRowParse called on non-struct QueryValue") + } + indent := strings.Repeat(" ", indentCount+4) + params := make([]string, 0, len(v.Struct.Fields)) + for i, f := range v.Struct.Fields { + params = append(params, fmt.Sprintf("%s%s=%s[%v],", indent, f.Name, rowVar, i)) + } + indent = strings.Repeat(" ", indentCount) + return v.Type() + "(\n" + strings.Join(params, "\n") + "\n" + indent + ")" +} + // A struct used to generate methods and fields on the Queries struct type Query struct { Cmd string @@ -127,6 +127,10 @@ type Query struct { } func (q Query) ArgPairs() string { + // A single struct arg does not need to be passed as a keyword argument + if len(q.Args) == 1 && q.Args[0].IsStruct() { + return ", " + q.Args[0].Pair() + } argPairs := make([]string, 0, len(q.Args)) for _, a := range q.Args { argPairs = append(argPairs, a.Pair()) @@ -134,27 +138,33 @@ func (q Query) ArgPairs() string { if len(argPairs) == 0 { return "" } - return ", " + strings.Join(argPairs, ", ") + return ", *, " + strings.Join(argPairs, ", ") } -func (q Query) ArgParams() string { +func (q Query) ArgDict() string { params := make([]string, 0, len(q.Args)) + i := 1 for _, a := range q.Args { if a.isEmpty() { continue } if a.IsStruct() { for _, f := range a.Struct.Fields { - params = append(params, a.Name+"."+f.Name) + params = append(params, fmt.Sprintf("\"p%v\": %s", i, a.Name+"."+f.Name)) + i++ } } else { - params = append(params, a.Name) + params = append(params, fmt.Sprintf("\"p%v\": %s", i, a.Name)) + i++ } } if len(params) == 0 { return "" } - return ", " + strings.Join(params, ", ") + if len(params) < 4 { + return ", {" + strings.Join(params, ", ") + "}" + } + return ", {\n " + strings.Join(params, ",\n ") + ",\n }" } func makePyType(r *compiler.Result, col *compiler.Column, settings config.CombinedSettings) pyType { @@ -356,6 +366,18 
@@ func sameTableName(n *ast.TableName, f core.FQN, defaultSchema string) bool { return n.Catalog == f.Catalog && schema == f.Schema && n.Name == f.Rel } +var postgresPlaceholderRegexp = regexp.MustCompile(`\B\$(\d+)\b`) + +// Sqlalchemy uses ":name" for placeholders, so "$N" is converted to ":pN" +// This also means ":" has special meaning to sqlalchemy, so it must be escaped. +func sqlalchemySQL(s string, engine config.Engine) string { + s = strings.ReplaceAll(s, ":", `\\:`) + if engine == config.EnginePostgreSQL { + return postgresPlaceholderRegexp.ReplaceAllString(s, ":p$1") + } + return s +} + func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) []Query { qs := make([]Query, 0, len(r.Queries)) for _, query := range r.Queries { @@ -374,7 +396,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs MethodName: methodName, FieldName: codegen.LowerTitle(query.Name) + "Stmt", ConstantName: strings.ToUpper(methodName), - SQL: query.SQL, + SQL: sqlalchemySQL(query.SQL, settings.Package.Engine), SourceName: query.Filename, } @@ -419,8 +441,11 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs same := true for i, f := range s.Fields { c := query.Columns[i] + // HACK: models do not have "models." on their types, so trim that so we can find matches + trimmedPyType := makePyType(r, c, settings) + trimmedPyType.InnerType = strings.TrimPrefix(trimmedPyType.InnerType, "models.") sameName := f.Name == columnName(c, i) - sameType := f.Type == makePyType(r, c, settings) + sameType := f.Type == trimmedPyType sameTable := sameTableName(c.Table, s.Table, r.Catalog.DefaultSchema) if !sameName || !sameType || !sameTable { same = false @@ -462,8 +487,7 @@ var modelsTmpl = `# Code generated by sqlc. DO NOT EDIT. {{- end}} -# Enums -{{- range .Enums}} +{{range .Enums}} {{- if .Comment}}{{comment .Comment}}{{- end}} class {{.Name}}(str, enum.Enum): {{- range .Constants}} @@ -471,10 +495,10 @@ class {{.Name}}(str, enum.Enum): {{- end}} {{end}} -# Models {{- range .Models}} -{{- if .Comment}}{{comment .Comment}}{{- end}} -class {{.Name}}(pydantic.BaseModel): {{- range .DedupFields}} +{{if .Comment}}{{comment .Comment}}{{- end}} +@dataclasses.dataclass() +class {{.Name}}: {{- range .Fields}} {{- if .Comment}} {{comment .Comment}}{{else}} {{- end}} @@ -484,124 +508,142 @@ class {{.Name}}(pydantic.BaseModel): {{- range .DedupFields}} {{end}} ` -var queriesTmpl = `# Code generated by sqlc. DO NOT EDIT. +var queriesTmpl = ` +{{- define "dataclassParse"}} + +{{end}} +# Code generated by sqlc. DO NOT EDIT. 
{{- range imports .SourceName}} {{.}} {{- end}} {{range .Queries}} {{- if $.OutputQuery .SourceName}} -{{.ConstantName}} = """-- name: {{.MethodName}} {{.Cmd}} +{{.ConstantName}} = """-- name: {{.MethodName}} \\{{.Cmd}} {{.SQL}} """ {{range .Args}} {{- if .EmitStruct}} -class {{.Type}}(pydantic.BaseModel): {{- range .Struct.DedupFields}} +@dataclasses.dataclass() +class {{.Type}}: {{- range .Struct.Fields}} {{.Name}}: {{.Type}} {{- end}} {{end}}{{end}} {{- if .Ret.EmitStruct}} -class {{.Ret.Type}}(pydantic.BaseModel): {{- range .Ret.Struct.DedupFields}} +@dataclasses.dataclass() +class {{.Ret.Type}}: {{- range .Ret.Struct.Fields}} {{.Name}}: {{.Type}} {{- end}} {{end}} {{end}} {{- end}} -{{- range .Queries}} +{{- if .EmitSync}} +class Querier: + def __init__(self, conn: sqlalchemy.engine.Connection): + self._conn = conn +{{range .Queries}} {{- if $.OutputQuery .SourceName}} {{- if eq .Cmd ":one"}} -@overload -def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> Optional[{{.Ret.Type}}]: - pass - - -@overload -def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> Awaitable[Optional[{{.Ret.Type}}]]: - pass - - -def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.ReturnType[Optional[{{.Ret.Type}}]]: - {{- if .Ret.IsStruct}} - return conn.execute_one_model({{.Ret.Type}}, {{.ConstantName}}{{.ArgParams}}) - {{- else}} - return conn.execute_one({{.ConstantName}}{{.ArgParams}}) - {{- end}} + def {{.MethodName}}(self{{.ArgPairs}}) -> Optional[{{.Ret.Type}}]: + row = self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}).first() + if row is None: + return None + {{- if .Ret.IsStruct}} + return {{.Ret.StructRowParser "row" 8}} + {{- else}} + return row[0] + {{- end}} {{end}} {{- if eq .Cmd ":many"}} -@overload -def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> Iterator[{{.Ret.Type}}]: - pass - - -@overload -def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> AsyncIterator[{{.Ret.Type}}]: - pass - - -def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.IteratorReturn[{{.Ret.Type}}]: - {{- if .Ret.IsStruct}} - return conn.execute_many_model({{.Ret.Type}}, {{.ConstantName}}{{.ArgParams}}) - {{- else}} - return conn.execute_many({{.ConstantName}}{{.ArgParams}}) - {{- end}} + def {{.MethodName}}(self{{.ArgPairs}}) -> Iterator[{{.Ret.Type}}]: + result = self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) + for row in result: + {{- if .Ret.IsStruct}} + yield {{.Ret.StructRowParser "row" 12}} + {{- else}} + yield row[0] + {{- end}} {{end}} {{- if eq .Cmd ":exec"}} -@overload -def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> None: - pass - - -@overload -def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> Awaitable[None]: - pass - - -def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.ReturnType[None]: - return conn.execute_none({{.ConstantName}}{{.ArgParams}}) + def {{.MethodName}}(self{{.ArgPairs}}) -> None: + self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) {{end}} {{- if eq .Cmd ":execrows"}} -@overload -def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> int: - pass - - -@overload -def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> Awaitable[int]: - pass - - -def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.ReturnType[int]: - return conn.execute_rowcount({{.ConstantName}}{{.ArgParams}}) + def {{.MethodName}}(self{{.ArgPairs}}) -> int: + result = 
self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) + return result.rowcount {{end}} {{- if eq .Cmd ":execresult"}} -@overload -def {{.MethodName}}(conn: sqlc.Connection{{.ArgPairs}}) -> sqlc.Cursor: - pass + def {{.MethodName}}(self{{.ArgPairs}}) -> sqlalchemy.engine.Result: + return self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) +{{end}} +{{- end}} +{{- end}} +{{- end}} +{{- if .EmitAsync}} -@overload -def {{.MethodName}}(conn: sqlc.AsyncConnection{{.ArgPairs}}) -> sqlc.AsyncCursor: - pass +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn +{{range .Queries}} +{{- if $.OutputQuery .SourceName}} +{{- if eq .Cmd ":one"}} + async def {{.MethodName}}(self{{.ArgPairs}}) -> Optional[{{.Ret.Type}}]: + row = (await self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}})).first() + if row is None: + return None + {{- if .Ret.IsStruct}} + return {{.Ret.StructRowParser "row" 8}} + {{- else}} + return row[0] + {{- end}} +{{end}} +{{- if eq .Cmd ":many"}} + async def {{.MethodName}}(self{{.ArgPairs}}) -> AsyncIterator[{{.Ret.Type}}]: + result = await self._conn.stream(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) + async for row in result: + {{- if .Ret.IsStruct}} + yield {{.Ret.StructRowParser "row" 12}} + {{- else}} + yield row[0] + {{- end}} +{{end}} -def {{.MethodName}}(conn: sqlc.GenericConnection{{.ArgPairs}}) -> sqlc.GenericCursor: - return conn.execute({{.ConstantName}}{{.ArgParams}}) +{{- if eq .Cmd ":exec"}} + async def {{.MethodName}}(self{{.ArgPairs}}) -> None: + await self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) {{end}} + +{{- if eq .Cmd ":execrows"}} + async def {{.MethodName}}(self{{.ArgPairs}}) -> int: + result = await self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) + return result.rowcount +{{end}} + +{{- if eq .Cmd ":execresult"}} + async def {{.MethodName}}(self{{.ArgPairs}}) -> sqlalchemy.engine.Result: + return await self._conn.execute(sqlalchemy.text({{.ConstantName}}){{.ArgDict}}) {{end}} {{- end}} +{{- end}} +{{- end}} ` type pyTmplCtx struct { Models []Struct Queries []Query Enums []Enum + EmitSync bool + EmitAsync bool SourceName string } @@ -635,9 +677,11 @@ func Generate(r *compiler.Result, settings config.CombinedSettings) (map[string] queriesFile := template.Must(template.New("table").Funcs(funcMap).Parse(queriesTmpl)) tctx := pyTmplCtx{ - Models: models, - Queries: queries, - Enums: enums, + Models: models, + Queries: queries, + Enums: enums, + EmitSync: settings.Python.EmitSyncQuerier, + EmitAsync: settings.Python.EmitAsyncQuerier, } output := map[string]string{} diff --git a/internal/codegen/python/imports.go b/internal/codegen/python/imports.go index 493284a564..dfce83a085 100644 --- a/internal/codegen/python/imports.go +++ b/internal/codegen/python/imports.go @@ -92,7 +92,7 @@ func (i *importer) modelImports() []string { } pkg := make(map[string]importSpec) - pkg["pydantic"] = importSpec{Module: "pydantic"} + pkg["dataclasses"] = importSpec{Module: "dataclasses"} for _, o := range i.Settings.Overrides { if o.PythonType.IsSet() && o.PythonType.Module != "" { @@ -129,11 +129,12 @@ func (i *importer) queryImports(fileName string) []string { } std := stdImports(queryUses) - std["typing.overload"] = importSpec{Module: "typing", Name: "overload"} - std["typing.Awaitable"] = importSpec{Module: "typing", Name: "Awaitable"} pkg := make(map[string]importSpec) - pkg["sqlc_runtime"] = importSpec{Module: "sqlc_runtime", Alias: "sqlc"} 
+ pkg["sqlalchemy"] = importSpec{Module: "sqlalchemy"} + if i.Settings.Python.EmitAsyncQuerier { + pkg["sqlalchemy.ext.asyncio"] = importSpec{Module: "sqlalchemy.ext.asyncio"} + } for _, o := range i.Settings.Overrides { if o.PythonType.IsSet() && o.PythonType.Module != "" { @@ -145,7 +146,7 @@ func (i *importer) queryImports(fileName string) []string { queryValueModelImports := func(qv QueryValue) { if qv.IsStruct() && qv.EmitStruct() { - pkg["pydantic"] = importSpec{Module: "pydantic"} + pkg["dataclasses"] = importSpec{Module: "dataclasses"} } } @@ -157,8 +158,12 @@ func (i *importer) queryImports(fileName string) []string { std["typing.Optional"] = importSpec{Module: "typing", Name: "Optional"} } if q.Cmd == ":many" { - std["typing.Iterator"] = importSpec{Module: "typing", Name: "Iterator"} - std["typing.AsyncIterator"] = importSpec{Module: "typing", Name: "AsyncIterator"} + if i.Settings.Python.EmitSyncQuerier { + std["typing.Iterator"] = importSpec{Module: "typing", Name: "Iterator"} + } + if i.Settings.Python.EmitAsyncQuerier { + std["typing.AsyncIterator"] = importSpec{Module: "typing", Name: "AsyncIterator"} + } } queryValueModelImports(q.Ret) for _, qv := range q.Args { diff --git a/internal/config/config.go b/internal/config/config.go index 54001049e8..bdfba2d28d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -130,10 +130,12 @@ type SQLKotlin struct { } type SQLPython struct { - EmitExactTableNames bool `json:"emit_exact_table_names" yaml:"emit_exact_table_names"` - Package string `json:"package" yaml:"package"` - Out string `json:"out" yaml:"out"` - Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` + EmitExactTableNames bool `json:"emit_exact_table_names" yaml:"emit_exact_table_names"` + EmitSyncQuerier bool `json:"emit_sync_querier" yaml:"emit_sync_querier"` + EmitAsyncQuerier bool `json:"emit_async_querier" yaml:"emit_async_querier"` + Package string `json:"package" yaml:"package"` + Out string `json:"out" yaml:"out"` + Overrides []Override `json:"overrides,omitempty" yaml:"overrides"` } type Override struct { @@ -229,7 +231,8 @@ var ErrUnknownEngine = errors.New("invalid engine") var ErrNoPackages = errors.New("no packages") var ErrNoPackageName = errors.New("missing package name") var ErrNoPackagePath = errors.New("missing package path") -var ErrKotlinNoOutPath = errors.New("no output path") +var ErrNoOutPath = errors.New("no output path") +var ErrNoQuerierType = errors.New("no querier emit type enabled") func ParseConfig(rd io.Reader) (Config, error) { var buf bytes.Buffer diff --git a/internal/config/v_two.go b/internal/config/v_two.go index 73a699dc7a..3f0db3cda4 100644 --- a/internal/config/v_two.go +++ b/internal/config/v_two.go @@ -53,13 +53,22 @@ func v2ParseConfig(rd io.Reader) (Config, error) { } if conf.SQL[j].Gen.Kotlin != nil { if conf.SQL[j].Gen.Kotlin.Out == "" { - return conf, ErrKotlinNoOutPath + return conf, ErrNoOutPath } if conf.SQL[j].Gen.Kotlin.Package == "" { return conf, ErrNoPackageName } } if conf.SQL[j].Gen.Python != nil { + if conf.SQL[j].Gen.Python.Out == "" { + return conf, ErrNoOutPath + } + if conf.SQL[j].Gen.Python.Package == "" { + return conf, ErrNoPackageName + } + if !conf.SQL[j].Gen.Python.EmitSyncQuerier && !conf.SQL[j].Gen.Python.EmitAsyncQuerier { + return conf, ErrNoQuerierType + } for i := range conf.SQL[j].Gen.Python.Overrides { if err := conf.SQL[j].Gen.Python.Overrides[i].Parse(); err != nil { return conf, err