Skip to content

Commit 04c6c90

Browse files
authored
Pydantic models config field (#1538)
* add pydantic config field * pydantic import plus test * update based on comments
1 parent f822644 commit 04c6c90

File tree

14 files changed

+443
-184
lines changed

14 files changed

+443
-184
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
/.idea/
1+
/.idea/
2+
__pycache__

internal/cmd/shim.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ func pluginPythonCode(s config.SQLPython) *plugin.PythonCode {
6969
EmitExactTableNames: s.EmitExactTableNames,
7070
EmitSyncQuerier: s.EmitSyncQuerier,
7171
EmitAsyncQuerier: s.EmitAsyncQuerier,
72+
EmitPydanticModels: s.EmitPydanticModels,
7273
}
7374
}
7475

internal/codegen/python/gen.go

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,26 @@ func dataclassNode(name string) *pyast.ClassDef {
585585
}
586586
}
587587

588+
func pydanticNode(name string) *pyast.ClassDef {
589+
return &pyast.ClassDef{
590+
Name: name,
591+
Bases: []*pyast.Node{
592+
{
593+
Node: &pyast.Node_Attribute{
594+
Attribute: &pyast.Attribute{
595+
Value: &pyast.Node{
596+
Node: &pyast.Node_Name{
597+
Name: &pyast.Name{Id: "pydantic"},
598+
},
599+
},
600+
Attr: "BaseModel",
601+
},
602+
},
603+
},
604+
},
605+
}
606+
}
607+
588608
func fieldNode(f Field) *pyast.Node {
589609
return &pyast.Node{
590610
Node: &pyast.Node_AnnAssign{
@@ -692,7 +712,12 @@ func buildModelsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
692712
}
693713

694714
for _, m := range ctx.Models {
695-
def := dataclassNode(m.Name)
715+
var def *pyast.ClassDef
716+
if ctx.EmitPydanticModels {
717+
def = pydanticNode(m.Name)
718+
} else {
719+
def = dataclassNode(m.Name)
720+
}
696721
if m.Comment != "" {
697722
def.Body = append(def.Body, &pyast.Node{
698723
Node: &pyast.Node_Expr{
@@ -822,15 +847,25 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
822847
mod.Body = append(mod.Body, assignNode(q.ConstantName, poet.Constant(queryText)))
823848
for _, arg := range q.Args {
824849
if arg.EmitStruct() {
825-
def := dataclassNode(arg.Struct.Name)
850+
var def *pyast.ClassDef
851+
if ctx.EmitPydanticModels {
852+
def = pydanticNode(arg.Struct.Name)
853+
} else {
854+
def = dataclassNode(arg.Struct.Name)
855+
}
826856
for _, f := range arg.Struct.Fields {
827857
def.Body = append(def.Body, fieldNode(f))
828858
}
829859
mod.Body = append(mod.Body, poet.Node(def))
830860
}
831861
}
832862
if q.Ret.EmitStruct() {
833-
def := dataclassNode(q.Ret.Struct.Name)
863+
var def *pyast.ClassDef
864+
if ctx.EmitPydanticModels {
865+
def = pydanticNode(q.Ret.Struct.Name)
866+
} else {
867+
def = dataclassNode(q.Ret.Struct.Name)
868+
}
834869
for _, f := range q.Ret.Struct.Fields {
835870
def.Body = append(def.Body, fieldNode(f))
836871
}
@@ -1027,13 +1062,14 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
10271062
}
10281063

10291064
type pyTmplCtx struct {
1030-
Models []Struct
1031-
Queries []Query
1032-
Enums []Enum
1033-
EmitSync bool
1034-
EmitAsync bool
1035-
SourceName string
1036-
SqlcVersion string
1065+
Models []Struct
1066+
Queries []Query
1067+
Enums []Enum
1068+
EmitSync bool
1069+
EmitAsync bool
1070+
SourceName string
1071+
SqlcVersion string
1072+
EmitPydanticModels bool
10371073
}
10381074

10391075
func (t *pyTmplCtx) OutputQuery(sourceName string) bool {
@@ -1060,12 +1096,13 @@ func Generate(req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
10601096
}
10611097

10621098
tctx := pyTmplCtx{
1063-
Models: models,
1064-
Queries: queries,
1065-
Enums: enums,
1066-
EmitSync: req.Settings.Python.EmitSyncQuerier,
1067-
EmitAsync: req.Settings.Python.EmitAsyncQuerier,
1068-
SqlcVersion: req.SqlcVersion,
1099+
Models: models,
1100+
Queries: queries,
1101+
Enums: enums,
1102+
EmitSync: req.Settings.Python.EmitSyncQuerier,
1103+
EmitAsync: req.Settings.Python.EmitAsyncQuerier,
1104+
SqlcVersion: req.SqlcVersion,
1105+
EmitPydanticModels: req.Settings.Python.EmitPydanticModels,
10691106
}
10701107

10711108
output := map[string]string{}

internal/codegen/python/imports.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,11 @@ func (i *importer) modelImportSpecs() (map[string]importSpec, map[string]importS
9999
}
100100

101101
std := stdImports(modelUses)
102-
std["dataclasses"] = importSpec{Module: "dataclasses"}
102+
if i.Settings.Python.EmitPydanticModels {
103+
std["pydantic"] = importSpec{Module: "pydantic"}
104+
} else {
105+
std["dataclasses"] = importSpec{Module: "dataclasses"}
106+
}
103107
if len(i.Enums) > 0 {
104108
std["enum"] = importSpec{Module: "enum"}
105109
}
@@ -162,7 +166,11 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map
162166

163167
queryValueModelImports := func(qv QueryValue) {
164168
if qv.IsStruct() && qv.EmitStruct() {
165-
std["dataclasses"] = importSpec{Module: "dataclasses"}
169+
if i.Settings.Python.EmitPydanticModels {
170+
std["pydantic"] = importSpec{Module: "pydantic"}
171+
} else {
172+
std["dataclasses"] = importSpec{Module: "dataclasses"}
173+
}
166174
}
167175
}
168176

internal/config/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ type SQLPython struct {
152152
Package string `json:"package" yaml:"package"`
153153
Out string `json:"out" yaml:"out"`
154154
Overrides []Override `json:"overrides,omitempty" yaml:"overrides"`
155+
EmitPydanticModels bool `json:"emit_pydantic_models,omitempty" yaml:"emit_pydantic_models"`
155156
}
156157

157158
type Override struct {

internal/endtoend/testdata/emit_pydantic_models/postgresql/__init__.py

Whitespace-only changes.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Code generated by sqlc. DO NOT EDIT.
2+
# versions:
3+
# sqlc v1.13.0
4+
import pydantic
5+
from typing import Optional
6+
7+
8+
class Author(pydantic.BaseModel):
9+
id: int
10+
name: str
11+
bio: Optional[str]
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Code generated by sqlc. DO NOT EDIT.
2+
# versions:
3+
# sqlc v1.13.0
4+
# source: query.sql
5+
from typing import AsyncIterator, Iterator, Optional
6+
7+
import sqlalchemy
8+
import sqlalchemy.ext.asyncio
9+
10+
from postgresql import models
11+
12+
13+
CREATE_AUTHOR = """-- name: create_author \\:one
14+
INSERT INTO authors (
15+
name, bio
16+
) VALUES (
17+
:p1, :p2
18+
)
19+
RETURNING id, name, bio
20+
"""
21+
22+
23+
DELETE_AUTHOR = """-- name: delete_author \\:exec
24+
DELETE FROM authors
25+
WHERE id = :p1
26+
"""
27+
28+
29+
GET_AUTHOR = """-- name: get_author \\:one
30+
SELECT id, name, bio FROM authors
31+
WHERE id = :p1 LIMIT 1
32+
"""
33+
34+
35+
LIST_AUTHORS = """-- name: list_authors \\:many
36+
SELECT id, name, bio FROM authors
37+
ORDER BY name
38+
"""
39+
40+
41+
class Querier:
42+
def __init__(self, conn: sqlalchemy.engine.Connection):
43+
self._conn = conn
44+
45+
def create_author(self, *, name: str, bio: Optional[str]) -> Optional[models.Author]:
46+
row = self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio}).first()
47+
if row is None:
48+
return None
49+
return models.Author(
50+
id=row[0],
51+
name=row[1],
52+
bio=row[2],
53+
)
54+
55+
def delete_author(self, *, id: int) -> None:
56+
self._conn.execute(sqlalchemy.text(DELETE_AUTHOR), {"p1": id})
57+
58+
def get_author(self, *, id: int) -> Optional[models.Author]:
59+
row = self._conn.execute(sqlalchemy.text(GET_AUTHOR), {"p1": id}).first()
60+
if row is None:
61+
return None
62+
return models.Author(
63+
id=row[0],
64+
name=row[1],
65+
bio=row[2],
66+
)
67+
68+
def list_authors(self) -> Iterator[models.Author]:
69+
result = self._conn.execute(sqlalchemy.text(LIST_AUTHORS))
70+
for row in result:
71+
yield models.Author(
72+
id=row[0],
73+
name=row[1],
74+
bio=row[2],
75+
)
76+
77+
78+
class AsyncQuerier:
79+
def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection):
80+
self._conn = conn
81+
82+
async def create_author(self, *, name: str, bio: Optional[str]) -> Optional[models.Author]:
83+
row = (await self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio})).first()
84+
if row is None:
85+
return None
86+
return models.Author(
87+
id=row[0],
88+
name=row[1],
89+
bio=row[2],
90+
)
91+
92+
async def delete_author(self, *, id: int) -> None:
93+
await self._conn.execute(sqlalchemy.text(DELETE_AUTHOR), {"p1": id})
94+
95+
async def get_author(self, *, id: int) -> Optional[models.Author]:
96+
row = (await self._conn.execute(sqlalchemy.text(GET_AUTHOR), {"p1": id})).first()
97+
if row is None:
98+
return None
99+
return models.Author(
100+
id=row[0],
101+
name=row[1],
102+
bio=row[2],
103+
)
104+
105+
async def list_authors(self) -> AsyncIterator[models.Author]:
106+
result = await self._conn.stream(sqlalchemy.text(LIST_AUTHORS))
107+
async for row in result:
108+
yield models.Author(
109+
id=row[0],
110+
name=row[1],
111+
bio=row[2],
112+
)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
-- name: GetAuthor :one
2+
SELECT * FROM authors
3+
WHERE id = $1 LIMIT 1;
4+
5+
-- name: ListAuthors :many
6+
SELECT * FROM authors
7+
ORDER BY name;
8+
9+
-- name: CreateAuthor :one
10+
INSERT INTO authors (
11+
name, bio
12+
) VALUES (
13+
$1, $2
14+
)
15+
RETURNING *;
16+
17+
-- name: DeleteAuthor :exec
18+
DELETE FROM authors
19+
WHERE id = $1;
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
CREATE TABLE authors (
2+
id BIGSERIAL PRIMARY KEY,
3+
name text NOT NULL,
4+
bio text
5+
);
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
{
2+
"version": "2",
3+
"sql": [
4+
{
5+
"schema": "postgresql/schema.sql",
6+
"queries": "postgresql/query.sql",
7+
"engine": "postgresql",
8+
"gen": {
9+
"python": {
10+
"out": "postgresql",
11+
"package": "postgresql",
12+
"emit_sync_querier": true,
13+
"emit_async_querier": true,
14+
"emit_pydantic_models": true
15+
}
16+
}
17+
}
18+
]
19+
}

0 commit comments

Comments
 (0)