From 4b9ea97872a241fad38c221afad7a701fc987868 Mon Sep 17 00:00:00 2001 From: Matthew Jee Date: Fri, 23 Dec 2022 00:37:03 -0800 Subject: [PATCH 1/2] Fix incorrect column names for SELECT * preceded by a WITH In #1139 I introduced support for the `WITH` clause. My implementation contains a bug: the statement preparer doesn't produce the correct result column names for queries containing a `SELECT *` preceded by a `WITH`. For example, consider the following statement: ``` WITH temp AS ( SELECT id, email from users) SELECT * from temp ``` An error would be thrown when preparing this statement because the glob expansion procedure would try to look up the column names for the result by looking up the column names for the query `SELECT * from temp`. This does not work because `temp` is a temporary view defined in the `WITH` clause. To fix this, I modified the glob expansion procedure to include the `WITH` clause in the query used to look up the result column names. --- Sources/SQLite/Typed/Query.swift | 19 +++++++++++ .../Typed/QueryIntegrationTests.swift | 33 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/Sources/SQLite/Typed/Query.swift b/Sources/SQLite/Typed/Query.swift index 04665f39..feff2f69 100644 --- a/Sources/SQLite/Typed/Query.swift +++ b/Sources/SQLite/Typed/Query.swift @@ -1036,10 +1036,29 @@ extension Connection { let column = names.removeLast() let namespace = names.joined(separator: ".") + // Return a copy of the input "with" clause stripping all subclauses besides "select", "join", and "with". + func strip(_ with: WithClauses) -> WithClauses { + var stripped = WithClauses() + stripped.recursive = with.recursive + for subclause in with.clauses { + let query = subclause.query + var strippedQuery = type(of: query).init(query.clauses.from.name, database: query.clauses.from.database) + strippedQuery.clauses.select = query.clauses.select + strippedQuery.clauses.join = query.clauses.join + strippedQuery.clauses.with = strip(query.clauses.with) + + var strippedSubclause = WithClauses.Clause(alias: subclause.alias, query: strippedQuery) + strippedSubclause.columns = subclause.columns + stripped.clauses.append(strippedSubclause) + } + return stripped + } + func expandGlob(_ namespace: Bool) -> (QueryType) throws -> Void { { (queryType: QueryType) throws -> Void in var query = type(of: queryType).init(queryType.clauses.from.name, database: queryType.clauses.from.database) query.clauses.select = queryType.clauses.select + query.clauses.with = strip(queryType.clauses.with) let expression = query.expression var names = try self.prepare(expression.template, expression.bindings).columnNames.map { $0.quote() } if namespace { names = names.map { "\(queryType.tableName().expression.template).\($0)" } } diff --git a/Tests/SQLiteTests/Typed/QueryIntegrationTests.swift b/Tests/SQLiteTests/Typed/QueryIntegrationTests.swift index 3fd388e9..d99b8457 100644 --- a/Tests/SQLiteTests/Typed/QueryIntegrationTests.swift +++ b/Tests/SQLiteTests/Typed/QueryIntegrationTests.swift @@ -276,6 +276,39 @@ class QueryIntegrationTests: SQLiteTestCase { XCTAssertEqual(21, sum) } + + /// Verify that `*` is properly expanded in a SELECT statement following a WITH clause. + func test_with_glob_expansion() throws { + let names = Table("names") + let name = Expression("name") + try db.run(names.create { builder in + builder.column(email) + builder.column(name) + }) + + try db.run(users.insert(email <- "alice@example.com")) + try db.run(names.insert(email <- "alice@example.com", name <- "Alice")) + + // WITH intermediate AS ( SELECT ... ) SELECT * FROM intermediate + let intermediate = Table("intermediate") + let rows = try db.prepare( + intermediate + .with(intermediate, + as: users + .select([id, users[email], name]) + .join(names, on: names[email] == users[email]) + .where(users[email] == "alice@example.com") + )) + + // There should be at least one row in the result. + let row = try XCTUnwrap(rows.makeIterator().next()) + + // Verify the column names + XCTAssertEqual(row.columnNames.count, 3) + XCTAssertNotNil(row[id]) + XCTAssertNotNil(row[name]) + XCTAssertNotNil(row[email]) + } } extension Connection { From bdc3be7fdf9b6289ee0e0ea5c28002839d8cf54d Mon Sep 17 00:00:00 2001 From: Matthew Jee Date: Fri, 23 Dec 2022 00:45:34 -0800 Subject: [PATCH 2/2] linter error --- Tests/SQLiteTests/Typed/QueryIntegrationTests.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tests/SQLiteTests/Typed/QueryIntegrationTests.swift b/Tests/SQLiteTests/Typed/QueryIntegrationTests.swift index d99b8457..f5105a09 100644 --- a/Tests/SQLiteTests/Typed/QueryIntegrationTests.swift +++ b/Tests/SQLiteTests/Typed/QueryIntegrationTests.swift @@ -302,7 +302,7 @@ class QueryIntegrationTests: SQLiteTestCase { // There should be at least one row in the result. let row = try XCTUnwrap(rows.makeIterator().next()) - + // Verify the column names XCTAssertEqual(row.columnNames.count, 3) XCTAssertNotNil(row[id])