Skip to content

Pydantic models config field #1538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
/.idea/
/.idea/
__pycache__
1 change: 1 addition & 0 deletions internal/cmd/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func pluginPythonCode(s config.SQLPython) *plugin.PythonCode {
EmitExactTableNames: s.EmitExactTableNames,
EmitSyncQuerier: s.EmitSyncQuerier,
EmitAsyncQuerier: s.EmitAsyncQuerier,
EmitPydanticModels: s.EmitPydanticModels,
}
}

Expand Down
69 changes: 53 additions & 16 deletions internal/codegen/python/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,26 @@ func dataclassNode(name string) *pyast.ClassDef {
}
}

func pydanticNode(name string) *pyast.ClassDef {
return &pyast.ClassDef{
Name: name,
Bases: []*pyast.Node{
{
Node: &pyast.Node_Attribute{
Attribute: &pyast.Attribute{
Value: &pyast.Node{
Node: &pyast.Node_Name{
Name: &pyast.Name{Id: "pydantic"},
},
},
Attr: "BaseModel",
},
},
},
},
}
}

func fieldNode(f Field) *pyast.Node {
return &pyast.Node{
Node: &pyast.Node_AnnAssign{
Expand Down Expand Up @@ -692,7 +712,12 @@ func buildModelsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
}

for _, m := range ctx.Models {
def := dataclassNode(m.Name)
var def *pyast.ClassDef
if ctx.EmitPydanticModels {
def = pydanticNode(m.Name)
} else {
def = dataclassNode(m.Name)
}
if m.Comment != "" {
def.Body = append(def.Body, &pyast.Node{
Node: &pyast.Node_Expr{
Expand Down Expand Up @@ -822,15 +847,25 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
mod.Body = append(mod.Body, assignNode(q.ConstantName, poet.Constant(queryText)))
for _, arg := range q.Args {
if arg.EmitStruct() {
def := dataclassNode(arg.Struct.Name)
var def *pyast.ClassDef
if ctx.EmitPydanticModels {
def = pydanticNode(arg.Struct.Name)
} else {
def = dataclassNode(arg.Struct.Name)
}
for _, f := range arg.Struct.Fields {
def.Body = append(def.Body, fieldNode(f))
}
mod.Body = append(mod.Body, poet.Node(def))
}
}
if q.Ret.EmitStruct() {
def := dataclassNode(q.Ret.Struct.Name)
var def *pyast.ClassDef
if ctx.EmitPydanticModels {
def = pydanticNode(q.Ret.Struct.Name)
} else {
def = dataclassNode(q.Ret.Struct.Name)
}
for _, f := range q.Ret.Struct.Fields {
def.Body = append(def.Body, fieldNode(f))
}
Expand Down Expand Up @@ -1027,13 +1062,14 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
}

type pyTmplCtx struct {
Models []Struct
Queries []Query
Enums []Enum
EmitSync bool
EmitAsync bool
SourceName string
SqlcVersion string
Models []Struct
Queries []Query
Enums []Enum
EmitSync bool
EmitAsync bool
SourceName string
SqlcVersion string
EmitPydanticModels bool
}

func (t *pyTmplCtx) OutputQuery(sourceName string) bool {
Expand All @@ -1060,12 +1096,13 @@ func Generate(req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
}

tctx := pyTmplCtx{
Models: models,
Queries: queries,
Enums: enums,
EmitSync: req.Settings.Python.EmitSyncQuerier,
EmitAsync: req.Settings.Python.EmitAsyncQuerier,
SqlcVersion: req.SqlcVersion,
Models: models,
Queries: queries,
Enums: enums,
EmitSync: req.Settings.Python.EmitSyncQuerier,
EmitAsync: req.Settings.Python.EmitAsyncQuerier,
SqlcVersion: req.SqlcVersion,
EmitPydanticModels: req.Settings.Python.EmitPydanticModels,
}

output := map[string]string{}
Expand Down
12 changes: 10 additions & 2 deletions internal/codegen/python/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ func (i *importer) modelImportSpecs() (map[string]importSpec, map[string]importS
}

std := stdImports(modelUses)
std["dataclasses"] = importSpec{Module: "dataclasses"}
if i.Settings.Python.EmitPydanticModels {
std["pydantic"] = importSpec{Module: "pydantic"}
} else {
std["dataclasses"] = importSpec{Module: "dataclasses"}
}
if len(i.Enums) > 0 {
std["enum"] = importSpec{Module: "enum"}
}
Expand Down Expand Up @@ -162,7 +166,11 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map

queryValueModelImports := func(qv QueryValue) {
if qv.IsStruct() && qv.EmitStruct() {
std["dataclasses"] = importSpec{Module: "dataclasses"}
if i.Settings.Python.EmitPydanticModels {
std["pydantic"] = importSpec{Module: "pydantic"}
} else {
std["dataclasses"] = importSpec{Module: "dataclasses"}
}
}
}

Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ type SQLPython struct {
Package string `json:"package" yaml:"package"`
Out string `json:"out" yaml:"out"`
Overrides []Override `json:"overrides,omitempty" yaml:"overrides"`
EmitPydanticModels bool `json:"emit_pydantic_models,omitempty" yaml:"emit_pydantic_models"`
}

type Override struct {
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Code generated by sqlc. DO NOT EDIT.
# versions:
# sqlc v1.13.0
import pydantic
from typing import Optional


class Author(pydantic.BaseModel):
id: int
name: str
bio: Optional[str]
112 changes: 112 additions & 0 deletions internal/endtoend/testdata/emit_pydantic_models/postgresql/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Code generated by sqlc. DO NOT EDIT.
# versions:
# sqlc v1.13.0
# source: query.sql
from typing import AsyncIterator, Iterator, Optional

import sqlalchemy
import sqlalchemy.ext.asyncio

from postgresql import models


CREATE_AUTHOR = """-- name: create_author \\:one
INSERT INTO authors (
name, bio
) VALUES (
:p1, :p2
)
RETURNING id, name, bio
"""


DELETE_AUTHOR = """-- name: delete_author \\:exec
DELETE FROM authors
WHERE id = :p1
"""


GET_AUTHOR = """-- name: get_author \\:one
SELECT id, name, bio FROM authors
WHERE id = :p1 LIMIT 1
"""


LIST_AUTHORS = """-- name: list_authors \\:many
SELECT id, name, bio FROM authors
ORDER BY name
"""


class Querier:
def __init__(self, conn: sqlalchemy.engine.Connection):
self._conn = conn

def create_author(self, *, name: str, bio: Optional[str]) -> Optional[models.Author]:
row = self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio}).first()
if row is None:
return None
return models.Author(
id=row[0],
name=row[1],
bio=row[2],
)

def delete_author(self, *, id: int) -> None:
self._conn.execute(sqlalchemy.text(DELETE_AUTHOR), {"p1": id})

def get_author(self, *, id: int) -> Optional[models.Author]:
row = self._conn.execute(sqlalchemy.text(GET_AUTHOR), {"p1": id}).first()
if row is None:
return None
return models.Author(
id=row[0],
name=row[1],
bio=row[2],
)

def list_authors(self) -> Iterator[models.Author]:
result = self._conn.execute(sqlalchemy.text(LIST_AUTHORS))
for row in result:
yield models.Author(
id=row[0],
name=row[1],
bio=row[2],
)


class AsyncQuerier:
def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
self._conn = conn

async def create_author(self, *, name: str, bio: Optional[str]) -> Optional[models.Author]:
row = (await self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio})).first()
if row is None:
return None
return models.Author(
id=row[0],
name=row[1],
bio=row[2],
)

async def delete_author(self, *, id: int) -> None:
await self._conn.execute(sqlalchemy.text(DELETE_AUTHOR), {"p1": id})

async def get_author(self, *, id: int) -> Optional[models.Author]:
row = (await self._conn.execute(sqlalchemy.text(GET_AUTHOR), {"p1": id})).first()
if row is None:
return None
return models.Author(
id=row[0],
name=row[1],
bio=row[2],
)

async def list_authors(self) -> AsyncIterator[models.Author]:
result = await self._conn.stream(sqlalchemy.text(LIST_AUTHORS))
async for row in result:
yield models.Author(
id=row[0],
name=row[1],
bio=row[2],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
-- name: GetAuthor :one
SELECT * FROM authors
WHERE id = $1 LIMIT 1;

-- name: ListAuthors :many
SELECT * FROM authors
ORDER BY name;

-- name: CreateAuthor :one
INSERT INTO authors (
name, bio
) VALUES (
$1, $2
)
RETURNING *;

-- name: DeleteAuthor :exec
DELETE FROM authors
WHERE id = $1;
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE TABLE authors (
id BIGSERIAL PRIMARY KEY,
name text NOT NULL,
bio text
);
19 changes: 19 additions & 0 deletions internal/endtoend/testdata/emit_pydantic_models/sqlc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"version": "2",
"sql": [
{
"schema": "postgresql/schema.sql",
"queries": "postgresql/query.sql",
"engine": "postgresql",
"gen": {
"python": {
"out": "postgresql",
"package": "postgresql",
"emit_sync_querier": true,
"emit_async_querier": true,
"emit_pydantic_models": true
}
}
}
]
}
Loading