From 34c1a5498a8f34f86d9650416ec39560c7e3b7c0 Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Sun, 20 Aug 2023 07:06:10 -0700 Subject: [PATCH 1/3] Separate the scope of tables to allow removal of aliases --- internal/compiler/find_params.go | 41 ++++++++-- internal/compiler/parse.go | 6 +- internal/compiler/resolve.go | 81 +++++++++++++++++-- .../testdata/subquery_with_where/go/db.go | 31 +++++++ .../testdata/subquery_with_where/go/models.go | 19 +++++ .../subquery_with_where/go/query.sql.go | 53 ++++++++++++ .../testdata/subquery_with_where/query.sql | 9 +++ .../testdata/subquery_with_where/sqlc.json | 12 +++ internal/sql/astutils/walk.go | 6 +- 9 files changed, 244 insertions(+), 14 deletions(-) create mode 100644 internal/endtoend/testdata/subquery_with_where/go/db.go create mode 100644 internal/endtoend/testdata/subquery_with_where/go/models.go create mode 100644 internal/endtoend/testdata/subquery_with_where/go/query.sql.go create mode 100644 internal/endtoend/testdata/subquery_with_where/query.sql create mode 100644 internal/endtoend/testdata/subquery_with_where/sqlc.json diff --git a/internal/compiler/find_params.go b/internal/compiler/find_params.go index 41ffaf8ad7..656481fb85 100644 --- a/internal/compiler/find_params.go +++ b/internal/compiler/find_params.go @@ -10,7 +10,7 @@ import ( func findParameters(root ast.Node) ([]paramRef, error) { refs := make([]paramRef, 0) errors := make([]error, 0) - v := paramSearch{seen: make(map[int]struct{}), refs: &refs, errs: &errors} + v := paramSearch{seen: make(map[int]struct{}), refs: &refs, errs: &errors, rvs: &[]*ast.RangeVar{}} astutils.Walk(v, root) if len(*v.errs) > 0 { problems := *v.errs @@ -22,6 +22,7 @@ func findParameters(root ast.Node) ([]paramRef, error) { type paramRef struct { parent ast.Node + rvs []*ast.RangeVar rv *ast.RangeVar ref *ast.ParamRef name string // Named parameter support @@ -31,6 +32,7 @@ type paramSearch struct { parent ast.Node rangeVar *ast.RangeVar refs *[]paramRef + rvs *[]*ast.RangeVar seen map[int]struct{} errs *[]error @@ -58,6 +60,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { return p } + var reset bool switch n := node.(type) { case *ast.A_Expr: @@ -70,6 +73,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { p.parent = n.FuncCall case *ast.DeleteStmt: + reset = true if n.LimitCount != nil { p.limitCount = n.LimitCount } @@ -78,7 +82,12 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { p.parent = node case *ast.InsertStmt: + reset = true + if n.Relation != nil { + *p.rvs = append(*p.rvs, n.Relation) + } if s, ok := n.SelectStmt.(*ast.SelectStmt); ok { + *p.rvs = append(*p.rvs, toTables(s.FromClause)...) for i, item := range s.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok { @@ -92,7 +101,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns")) return p } - *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation}) + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: *p.rvs}) p.seen[ref.Location] = struct{}{} } for _, item := range s.ValuesLists.Items { @@ -109,13 +118,16 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns")) return p } - *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation}) + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: *p.rvs}) p.seen[ref.Location] = struct{}{} } } } case *ast.UpdateStmt: + reset = true + *p.rvs = append(*p.rvs, toTables(n.FromClause)...) + *p.rvs = append(*p.rvs, toTables(n.Relations)...) for _, item := range n.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok { @@ -130,7 +142,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { if !ok { continue } - *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv}) + *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv, rvs: *p.rvs}) } p.seen[ref.Location] = struct{}{} } @@ -139,12 +151,16 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { } case *ast.RangeVar: + if n != nil { + *p.rvs = append(*p.rvs, n) + } p.rangeVar = n case *ast.ResTarget: p.parent = node case *ast.SelectStmt: + reset = true if n.LimitCount != nil { p.limitCount = n.LimitCount } @@ -191,7 +207,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { } if set { - *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar}) + *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar, rvs: *p.rvs}) p.seen[n.Location] = struct{}{} } return nil @@ -215,5 +231,20 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { p.Visit(n.Expr) } } + if reset { + rvs := *p.rvs + return paramSearch{seen: p.seen, refs: p.refs, errs: p.errs, rvs: &rvs, parent: p.parent, rangeVar: p.rangeVar, limitCount: p.limitCount, limitOffset: p.limitOffset} + } return p } + +func toTables(tbl *ast.List) []*ast.RangeVar { + tables := make([]*ast.RangeVar, len(tbl.Items)) + for _, t := range tbl.Items { + item, ok := t.(*ast.RangeVar) + if ok && item != nil { + tables = append(tables, item) + } + } + return tables +} diff --git a/internal/compiler/parse.go b/internal/compiler/parse.go index 8354bd340a..7f9d3badf5 100644 --- a/internal/compiler/parse.go +++ b/internal/compiler/parse.go @@ -95,7 +95,11 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, return nil, err } - params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds) + err = c.resolveCatalogEmbeds(qc, rvs, embeds) + if err != nil { + return nil, err + } + params, err := c.resolveCatalogRefs(qc, refs, namedParams) if err != nil { return nil, err } diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 0a91b45f25..b5f8acf9dc 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -20,7 +20,7 @@ func dataType(n *ast.TypeName) string { } } -func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) { +func (comp *Compiler) resolveCatalogEmbeds(qc *QueryCatalog, rvs []*ast.RangeVar, embeds rewrite.EmbedSet) error { c := comp.catalog aliasMap := map[string]*ast.TableName{} @@ -55,7 +55,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } fqn, err := ParseTableName(rv) if err != nil { - return nil, err + return err } if _, found := aliasMap[fqn.Name]; found { continue @@ -64,13 +64,13 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, if err != nil { // If the table name doesn't exist, fisrt check if it's a CTE if _, qcerr := qc.GetTable(fqn); qcerr != nil { - return nil, err + return err } continue } err = indexTable(table) if err != nil { - return nil, err + return err } if rv.Alias != nil { aliasMap[*rv.Alias.Aliasname] = fqn @@ -90,11 +90,71 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, continue } - return nil, fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err) + return fmt.Errorf("unable to resolve table with %q: %w", embed.Orig(), err) } + return nil +} +func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, params *named.ParamSet) ([]Parameter, error) { + c := comp.catalog + + // resolve a table for an embed var a []Parameter for _, ref := range args { + aliasMap := map[string]*ast.TableName{} + // TODO: Deprecate defaultTable + var defaultTable *ast.TableName + var tables []*ast.TableName + + typeMap := map[string]map[string]map[string]*catalog.Column{} + indexTable := func(table catalog.Table) error { + tables = append(tables, table.Rel) + if defaultTable == nil { + defaultTable = table.Rel + } + schema := table.Rel.Schema + if schema == "" { + schema = c.DefaultSchema + } + if _, exists := typeMap[schema]; !exists { + typeMap[schema] = map[string]map[string]*catalog.Column{} + } + typeMap[schema][table.Rel.Name] = map[string]*catalog.Column{} + for _, c := range table.Columns { + cc := c + typeMap[schema][table.Rel.Name][c.Name] = cc + } + return nil + } + + for _, rv := range ref.rvs { + if rv == nil || rv.Relname == nil { + continue + } + fqn, err := ParseTableName(rv) + if err != nil { + return nil, err + } + if _, found := aliasMap[fqn.Name]; found { + continue + } + table, err := c.GetTable(fqn) + if err != nil { + // If the table name doesn't exist, fisrt check if it's a CTE + if _, qcerr := qc.GetTable(fqn); qcerr != nil { + return nil, err + } + continue + } + err = indexTable(table) + if err != nil { + return nil, err + } + if rv.Alias != nil { + aliasMap[*rv.Alias.Aliasname] = fqn + } + } + switch n := ref.parent.(type) { case *limitOffset: @@ -196,7 +256,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } var found int + seenTable := make(map[string]bool, len(search)) for _, table := range search { + if seenTable[table.Name] { + continue + } + seenTable[table.Name] = true schema := table.Schema if schema == "" { schema = c.DefaultSchema @@ -236,6 +301,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } } if found > 1 { + fmt.Println("ambiguous 3") return nil, &sqlerr.Error{ Code: "42703", Message: fmt.Sprintf("column reference %q is ambiguous", key), @@ -551,7 +617,12 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, } } + seenTables := make(map[string]bool, len(search)) for _, table := range search { + if seenTables[table.Name] { + continue + } + seenTables[table.Name] = true schema := table.Schema if schema == "" { schema = c.DefaultSchema diff --git a/internal/endtoend/testdata/subquery_with_where/go/db.go b/internal/endtoend/testdata/subquery_with_where/go/db.go new file mode 100644 index 0000000000..57406b68e8 --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/subquery_with_where/go/models.go b/internal/endtoend/testdata/subquery_with_where/go/models.go new file mode 100644 index 0000000000..3fa48ca789 --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/go/models.go @@ -0,0 +1,19 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 + +package querytest + +import ( + "database/sql" +) + +type Bar struct { + A int32 + Alias sql.NullString +} + +type Foo struct { + A int32 + Name sql.NullString +} diff --git a/internal/endtoend/testdata/subquery_with_where/go/query.sql.go b/internal/endtoend/testdata/subquery_with_where/go/query.sql.go new file mode 100644 index 0000000000..d6db500c95 --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/go/query.sql.go @@ -0,0 +1,53 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.20.0 +// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const subquery = `-- name: Subquery :many +SELECT + a, + name, + (SELECT alias FROM bar WHERE bar.a=foo.a AND alias = $1 ORDER BY bar.a DESC limit 1) as alias +FROM FOO WHERE a = $2 +` + +type SubqueryParams struct { + Alias sql.NullString + A int32 +} + +type SubqueryRow struct { + A int32 + Name sql.NullString + Alias sql.NullString +} + +func (q *Queries) Subquery(ctx context.Context, arg SubqueryParams) ([]SubqueryRow, error) { + rows, err := q.db.QueryContext(ctx, subquery, arg.Alias, arg.A) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SubqueryRow + for rows.Next() { + var i SubqueryRow + if err := rows.Scan(&i.A, &i.Name, &i.Alias); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/endtoend/testdata/subquery_with_where/query.sql b/internal/endtoend/testdata/subquery_with_where/query.sql new file mode 100644 index 0000000000..12e6dfaf3f --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/query.sql @@ -0,0 +1,9 @@ +CREATE TABLE foo (a int not null, name text); +CREATE TABLE bar (a int not null, alias text); + +-- name: Subquery :many +SELECT + a, + name, + (SELECT alias FROM bar WHERE bar.a=foo.a AND alias = $1 ORDER BY bar.a DESC limit 1) as alias +FROM FOO WHERE a = $2; diff --git a/internal/endtoend/testdata/subquery_with_where/sqlc.json b/internal/endtoend/testdata/subquery_with_where/sqlc.json new file mode 100644 index 0000000000..c72b6132d5 --- /dev/null +++ b/internal/endtoend/testdata/subquery_with_where/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index 9f26617ad3..149403601c 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -1818,6 +1818,9 @@ func Walk(f Visitor, node ast.Node) { } case *ast.SelectStmt: + if n.FromClause != nil { + Walk(f, n.FromClause) + } if n.DistinctClause != nil { Walk(f, n.DistinctClause) } @@ -1827,9 +1830,6 @@ func Walk(f Visitor, node ast.Node) { if n.TargetList != nil { Walk(f, n.TargetList) } - if n.FromClause != nil { - Walk(f, n.FromClause) - } if n.WhereClause != nil { Walk(f, n.WhereClause) } From 6361797d598ba954801d58c6f6e75f710044159f Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Sun, 20 Aug 2023 07:32:12 -0700 Subject: [PATCH 2/3] Remove duplicate additions of tables --- internal/compiler/find_params.go | 15 ++++++++------- internal/compiler/resolve.go | 10 ---------- internal/sql/astutils/walk.go | 6 +++--- 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/internal/compiler/find_params.go b/internal/compiler/find_params.go index 656481fb85..05fa1af187 100644 --- a/internal/compiler/find_params.go +++ b/internal/compiler/find_params.go @@ -83,11 +83,12 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { case *ast.InsertStmt: reset = true + rvs := *p.rvs if n.Relation != nil { - *p.rvs = append(*p.rvs, n.Relation) + rvs = append(rvs, n.Relation) } if s, ok := n.SelectStmt.(*ast.SelectStmt); ok { - *p.rvs = append(*p.rvs, toTables(s.FromClause)...) + rvs = append(rvs, toTables(s.FromClause)...) for i, item := range s.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok { @@ -101,7 +102,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns")) return p } - *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: *p.rvs}) + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: rvs}) p.seen[ref.Location] = struct{}{} } for _, item := range s.ValuesLists.Items { @@ -118,7 +119,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns")) return p } - *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: *p.rvs}) + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation, rvs: rvs}) p.seen[ref.Location] = struct{}{} } } @@ -126,8 +127,8 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { case *ast.UpdateStmt: reset = true - *p.rvs = append(*p.rvs, toTables(n.FromClause)...) - *p.rvs = append(*p.rvs, toTables(n.Relations)...) + rvs := append(*p.rvs, toTables(n.FromClause)...) + rvs = append(rvs, toTables(n.Relations)...) for _, item := range n.TargetList.Items { target, ok := item.(*ast.ResTarget) if !ok { @@ -142,7 +143,7 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor { if !ok { continue } - *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv, rvs: *p.rvs}) + *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv, rvs: rvs}) } p.seen[ref.Location] = struct{}{} } diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index b5f8acf9dc..7f6a9f8675 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -256,12 +256,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para } var found int - seenTable := make(map[string]bool, len(search)) for _, table := range search { - if seenTable[table.Name] { - continue - } - seenTable[table.Name] = true schema := table.Schema if schema == "" { schema = c.DefaultSchema @@ -617,12 +612,7 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para } } - seenTables := make(map[string]bool, len(search)) for _, table := range search { - if seenTables[table.Name] { - continue - } - seenTables[table.Name] = true schema := table.Schema if schema == "" { schema = c.DefaultSchema diff --git a/internal/sql/astutils/walk.go b/internal/sql/astutils/walk.go index 149403601c..f21e916975 100644 --- a/internal/sql/astutils/walk.go +++ b/internal/sql/astutils/walk.go @@ -2032,15 +2032,15 @@ func Walk(f Visitor, node ast.Node) { if n.Relations != nil { Walk(f, n.Relations) } + if n.FromClause != nil { + Walk(f, n.FromClause) + } if n.TargetList != nil { Walk(f, n.TargetList) } if n.WhereClause != nil { Walk(f, n.WhereClause) } - if n.FromClause != nil { - Walk(f, n.FromClause) - } if n.LimitCount != nil { Walk(f, n.LimitCount) } From 5d171a018fe7605811034e5ef9e4b89d174ccb1b Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Sun, 20 Aug 2023 10:47:07 -0700 Subject: [PATCH 3/3] remove comment --- internal/compiler/resolve.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/compiler/resolve.go b/internal/compiler/resolve.go index 7f6a9f8675..d706aef31b 100644 --- a/internal/compiler/resolve.go +++ b/internal/compiler/resolve.go @@ -296,7 +296,6 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, args []paramRef, para } } if found > 1 { - fmt.Println("ambiguous 3") return nil, &sqlerr.Error{ Code: "42703", Message: fmt.Sprintf("column reference %q is ambiguous", key),