Skip to content

Commit 4b9ea97

Browse files
committed
Fix incorrect column names for SELECT * preceded by a WITH
In stephencelis#1139 I introduced support for the `WITH` clause. My implementation contains a bug: the statement preparer doesn't produce the correct result column names for queries containing a `SELECT *` preceded by a `WITH`. For example, consider the following statement: ``` WITH temp AS ( SELECT id, email from users) SELECT * from temp ``` An error would be thrown when preparing this statement because the glob expansion procedure would try to look up the column names for the result by looking up the column names for the query `SELECT * from temp`. This does not work because `temp` is a temporary view defined in the `WITH` clause. To fix this, I modified the glob expansion procedure to include the `WITH` clause in the query used to look up the result column names.
1 parent f06b8df commit 4b9ea97

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

Sources/SQLite/Typed/Query.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,10 +1036,29 @@ extension Connection {
10361036
let column = names.removeLast()
10371037
let namespace = names.joined(separator: ".")
10381038

1039+
// Return a copy of the input "with" clause stripping all subclauses besides "select", "join", and "with".
1040+
func strip(_ with: WithClauses) -> WithClauses {
1041+
var stripped = WithClauses()
1042+
stripped.recursive = with.recursive
1043+
for subclause in with.clauses {
1044+
let query = subclause.query
1045+
var strippedQuery = type(of: query).init(query.clauses.from.name, database: query.clauses.from.database)
1046+
strippedQuery.clauses.select = query.clauses.select
1047+
strippedQuery.clauses.join = query.clauses.join
1048+
strippedQuery.clauses.with = strip(query.clauses.with)
1049+
1050+
var strippedSubclause = WithClauses.Clause(alias: subclause.alias, query: strippedQuery)
1051+
strippedSubclause.columns = subclause.columns
1052+
stripped.clauses.append(strippedSubclause)
1053+
}
1054+
return stripped
1055+
}
1056+
10391057
func expandGlob(_ namespace: Bool) -> (QueryType) throws -> Void {
10401058
{ (queryType: QueryType) throws -> Void in
10411059
var query = type(of: queryType).init(queryType.clauses.from.name, database: queryType.clauses.from.database)
10421060
query.clauses.select = queryType.clauses.select
1061+
query.clauses.with = strip(queryType.clauses.with)
10431062
let expression = query.expression
10441063
var names = try self.prepare(expression.template, expression.bindings).columnNames.map { $0.quote() }
10451064
if namespace { names = names.map { "\(queryType.tableName().expression.template).\($0)" } }

Tests/SQLiteTests/Typed/QueryIntegrationTests.swift

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,39 @@ class QueryIntegrationTests: SQLiteTestCase {
276276

277277
XCTAssertEqual(21, sum)
278278
}
279+
280+
/// Verify that `*` is properly expanded in a SELECT statement following a WITH clause.
281+
func test_with_glob_expansion() throws {
282+
let names = Table("names")
283+
let name = Expression<String>("name")
284+
try db.run(names.create { builder in
285+
builder.column(email)
286+
builder.column(name)
287+
})
288+
289+
try db.run(users.insert(email <- "[email protected]"))
290+
try db.run(names.insert(email <- "[email protected]", name <- "Alice"))
291+
292+
// WITH intermediate AS ( SELECT ... ) SELECT * FROM intermediate
293+
let intermediate = Table("intermediate")
294+
let rows = try db.prepare(
295+
intermediate
296+
.with(intermediate,
297+
as: users
298+
.select([id, users[email], name])
299+
.join(names, on: names[email] == users[email])
300+
.where(users[email] == "[email protected]")
301+
))
302+
303+
// There should be at least one row in the result.
304+
let row = try XCTUnwrap(rows.makeIterator().next())
305+
306+
// Verify the column names
307+
XCTAssertEqual(row.columnNames.count, 3)
308+
XCTAssertNotNil(row[id])
309+
XCTAssertNotNil(row[name])
310+
XCTAssertNotNil(row[email])
311+
}
279312
}
280313

281314
extension Connection {

0 commit comments

Comments
 (0)