Skip to content

Commit ba5165a

Browse files
committed
fix insertMany() failing with encodables which have assymmetric optional values
1 parent 55f4565 commit ba5165a

File tree

4 files changed

+129
-12
lines changed

4 files changed

+129
-12
lines changed

Sources/SQLite/Typed/Coding.swift

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,22 @@ extension QueryType {
8686
/// - Returns: An `INSERT` statement for the encodable objects
8787
public func insertMany(_ encodables: [Encodable], userInfo: [CodingUserInfoKey: Any] = [:],
8888
otherSetters: [Setter] = []) throws -> Insert {
89-
let combinedSetters = try encodables.map { encodable -> [Setter] in
90-
let encoder = SQLiteEncoder(userInfo: userInfo)
89+
let combinedSettersWithoutNils = try encodables.map { encodable -> [Setter] in
90+
let encoder = SQLiteEncoder(userInfo: userInfo, forcingNilValueSetters: false)
9191
try encodable.encode(to: encoder)
9292
return encoder.setters + otherSetters
9393
}
94-
return self.insertMany(combinedSetters)
94+
// requires the same number of setters per encodable
95+
guard Set(combinedSettersWithoutNils.map(\.count)).count == 1 else {
96+
// asymmetric sets of value insertions (some nil, some not), requires NULL value to satisfy INSERT query
97+
let combinedSymmetricSetters = try encodables.map { encodable -> [Setter] in
98+
let encoder = SQLiteEncoder(userInfo: userInfo, forcingNilValueSetters: true)
99+
try encodable.encode(to: encoder)
100+
return encoder.setters + otherSetters
101+
}
102+
return self.insertMany(combinedSymmetricSetters)
103+
}
104+
return self.insertMany(combinedSettersWithoutNils)
95105
}
96106

97107
/// Creates an `INSERT ON CONFLICT DO UPDATE` statement, aka upsert, by encoding the given object
@@ -165,9 +175,11 @@ private class SQLiteEncoder: Encoder {
165175

166176
let encoder: SQLiteEncoder
167177
let codingPath: [CodingKey] = []
178+
let forcingNilValueSetters: Bool
168179

169-
init(encoder: SQLiteEncoder) {
180+
init(encoder: SQLiteEncoder, forcingNilValueSetters: Bool = false) {
170181
self.encoder = encoder
182+
self.forcingNilValueSetters = forcingNilValueSetters
171183
}
172184

173185
func superEncoder() -> Swift.Encoder {
@@ -202,6 +214,46 @@ private class SQLiteEncoder: Encoder {
202214
encoder.setters.append(Expression(key.stringValue) <- value)
203215
}
204216

217+
func encodeIfPresent(_ value: Int?, forKey key: SQLiteEncoder.SQLiteKeyedEncodingContainer<Key>.Key) throws {
218+
if let value = value {
219+
encoder.setters.append(Expression(key.stringValue) <- value)
220+
} else if forcingNilValueSetters {
221+
encoder.setters.append(Expression<Int?>(key.stringValue) <- nil)
222+
}
223+
}
224+
225+
func encodeIfPresent(_ value: Bool?, forKey key: Key) throws {
226+
if let value = value {
227+
encoder.setters.append(Expression(key.stringValue) <- value)
228+
} else if forcingNilValueSetters {
229+
encoder.setters.append(Expression<Bool?>(key.stringValue) <- nil)
230+
}
231+
}
232+
233+
func encodeIfPresent(_ value: Float?, forKey key: Key) throws {
234+
if let value = value {
235+
encoder.setters.append(Expression(key.stringValue) <- Double(value))
236+
} else if forcingNilValueSetters{
237+
encoder.setters.append(Expression<Double?>(key.stringValue) <- nil)
238+
}
239+
}
240+
241+
func encodeIfPresent(_ value: Double?, forKey key: Key) throws {
242+
if let value = value {
243+
encoder.setters.append(Expression(key.stringValue) <- value)
244+
} else if forcingNilValueSetters {
245+
encoder.setters.append(Expression<Double?>(key.stringValue) <- nil)
246+
}
247+
}
248+
249+
func encodeIfPresent(_ value: String?, forKey key: MyKey) throws {
250+
if let value = value {
251+
encoder.setters.append(Expression(key.stringValue) <- value)
252+
} else if forcingNilValueSetters {
253+
encoder.setters.append(Expression<String?>(key.stringValue) <- nil)
254+
}
255+
}
256+
205257
func encode<T>(_ value: T, forKey key: Key) throws where T: Swift.Encodable {
206258
switch value {
207259
case let data as Data:
@@ -217,6 +269,28 @@ private class SQLiteEncoder: Encoder {
217269
}
218270
}
219271

272+
func encodeIfPresent<T>(_ value: T?, forKey key: Key) throws where T: Swift.Encodable {
273+
guard let value = value else {
274+
guard forcingNilValueSetters else {
275+
return
276+
}
277+
encoder.setters.append(Expression<String?>(key.stringValue) <- nil)
278+
return
279+
}
280+
switch value {
281+
case let data as Data:
282+
encoder.setters.append(Expression(key.stringValue) <- data)
283+
case let date as Date:
284+
encoder.setters.append(Expression(key.stringValue) <- date.datatypeValue)
285+
case let uuid as UUID:
286+
encoder.setters.append(Expression(key.stringValue) <- uuid.datatypeValue)
287+
default:
288+
let encoded = try JSONEncoder().encode(value)
289+
let string = String(data: encoded, encoding: .utf8)
290+
encoder.setters.append(Expression(key.stringValue) <- string)
291+
}
292+
}
293+
220294
func encode(_ value: Int8, forKey key: Key) throws {
221295
throw EncodingError.invalidValue(value, EncodingError.Context(codingPath: codingPath,
222296
debugDescription: "encoding an Int8 is not supported"))
@@ -274,9 +348,11 @@ private class SQLiteEncoder: Encoder {
274348
fileprivate var setters: [Setter] = []
275349
let codingPath: [CodingKey] = []
276350
let userInfo: [CodingUserInfoKey: Any]
351+
let forcingNilValueSetters: Bool
277352

278-
init(userInfo: [CodingUserInfoKey: Any]) {
353+
init(userInfo: [CodingUserInfoKey: Any], forcingNilValueSetters: Bool = false) {
279354
self.userInfo = userInfo
355+
self.forcingNilValueSetters = forcingNilValueSetters
280356
}
281357

282358
func singleValueContainer() -> SingleValueEncodingContainer {
@@ -288,7 +364,7 @@ private class SQLiteEncoder: Encoder {
288364
}
289365

290366
func container<Key>(keyedBy type: Key.Type) -> KeyedEncodingContainer<Key> where Key: CodingKey {
291-
KeyedEncodingContainer(SQLiteKeyedEncodingContainer(encoder: self))
367+
KeyedEncodingContainer(SQLiteKeyedEncodingContainer(encoder: self, forcingNilValueSetters: forcingNilValueSetters))
292368
}
293369
}
294370

Tests/SQLiteTests/QueryIntegrationTests.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,27 @@ class QueryIntegrationTests: SQLiteTestCase {
130130
XCTAssertEqual(2, id)
131131
}
132132

133+
func test_insert_many_encodables() throws {
134+
let table = Table("codable")
135+
try db.run(table.create { builder in
136+
builder.column(Expression<Int?>("int"))
137+
builder.column(Expression<String?>("string"))
138+
builder.column(Expression<Bool?>("bool"))
139+
builder.column(Expression<Double?>("float"))
140+
builder.column(Expression<Double?>("double"))
141+
builder.column(Expression<Date?>("date"))
142+
builder.column(Expression<UUID?>("uuid"))
143+
})
144+
145+
let value1 = TestOptionalCodable(int: 5, string: "6", bool: true, float: 7, double: 8, date: Date(timeIntervalSince1970: 5000), uuid: testUUIDValue)
146+
let valueWithNils = TestOptionalCodable(int: nil, string: nil, bool: nil, float: nil, double: nil, date: nil, uuid: nil)
147+
try db.run(table.insertMany([value1, valueWithNils]))
148+
149+
let rows = try db.prepare(table)
150+
let values: [TestOptionalCodable] = try rows.map({ try $0.decode() })
151+
XCTAssertEqual(values.count, 2)
152+
}
153+
133154
func test_upsert() throws {
134155
try XCTSkipUnless(db.satisfiesMinimumVersion(minor: 24))
135156
let fetchAge = { () throws -> Int? in

Tests/SQLiteTests/QueryTests.swift

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -365,21 +365,21 @@ class QueryTests: XCTestCase {
365365
)
366366
}
367367

368-
func test_insert_many_encodable() throws {
368+
func test_insert_many_encodables() throws {
369369
let emails = Table("emails")
370370
let value1 = TestCodable(int: 1, string: "2", bool: true, float: 3, double: 4,
371371
date: Date(timeIntervalSince1970: 0), uuid: testUUIDValue, optional: nil, sub: nil)
372372
let value2 = TestCodable(int: 2, string: "3", bool: true, float: 3, double: 5,
373-
date: Date(timeIntervalSince1970: 0), uuid: testUUIDValue, optional: nil, sub: nil)
373+
date: Date(timeIntervalSince1970: 0), uuid: testUUIDValue, optional: "optional", sub: nil)
374374
let value3 = TestCodable(int: 3, string: "4", bool: true, float: 3, double: 6,
375375
date: Date(timeIntervalSince1970: 0), uuid: testUUIDValue, optional: nil, sub: nil)
376376
let insert = try emails.insertMany([value1, value2, value3])
377377
assertSQL(
378378
"""
379-
INSERT INTO \"emails\" (\"int\", \"string\", \"bool\", \"float\", \"double\", \"date\", \"uuid\")
380-
VALUES (1, '2', 1, 3.0, 4.0, '1970-01-01T00:00:00.000', 'E621E1F8-C36C-495A-93FC-0C247A3E6E5F'),
381-
(2, '3', 1, 3.0, 5.0, '1970-01-01T00:00:00.000', 'E621E1F8-C36C-495A-93FC-0C247A3E6E5F'),
382-
(3, '4', 1, 3.0, 6.0, '1970-01-01T00:00:00.000', 'E621E1F8-C36C-495A-93FC-0C247A3E6E5F')
379+
INSERT INTO \"emails\" (\"int\", \"string\", \"bool\", \"float\", \"double\", \"date\", \"uuid\", \"optional\", \"sub\")
380+
VALUES (1, '2', 1, 3.0, 4.0, '1970-01-01T00:00:00.000', 'E621E1F8-C36C-495A-93FC-0C247A3E6E5F', NULL, NULL),
381+
(2, '3', 1, 3.0, 5.0, '1970-01-01T00:00:00.000', 'E621E1F8-C36C-495A-93FC-0C247A3E6E5F', 'optional', NULL),
382+
(3, '4', 1, 3.0, 6.0, '1970-01-01T00:00:00.000', 'E621E1F8-C36C-495A-93FC-0C247A3E6E5F', NULL, NULL)
383383
""".replacingOccurrences(of: "\n", with: ""),
384384
insert
385385
)

Tests/SQLiteTests/TestHelpers.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,23 @@ class TestCodable: Codable, Equatable {
145145
lhs.sub == rhs.sub
146146
}
147147
}
148+
149+
struct TestOptionalCodable: Codable, Equatable {
150+
let int: Int?
151+
let string: String?
152+
let bool: Bool?
153+
let float: Float?
154+
let double: Double?
155+
let date: Date?
156+
let uuid: UUID?
157+
158+
init(int: Int?, string: String?, bool: Bool?, float: Float?, double: Double?, date: Date?, uuid: UUID?) {
159+
self.int = int
160+
self.string = string
161+
self.bool = bool
162+
self.float = float
163+
self.double = double
164+
self.date = date
165+
self.uuid = uuid
166+
}
167+
}

0 commit comments

Comments
 (0)