diff --git a/Sources/SQLite/Typed/Coding.swift b/Sources/SQLite/Typed/Coding.swift index ec2e0d6c..b96bc64e 100644 --- a/Sources/SQLite/Typed/Coding.swift +++ b/Sources/SQLite/Typed/Coding.swift @@ -86,12 +86,22 @@ extension QueryType { /// - Returns: An `INSERT` statement for the encodable objects public func insertMany(_ encodables: [Encodable], userInfo: [CodingUserInfoKey: Any] = [:], otherSetters: [Setter] = []) throws -> Insert { - let combinedSetters = try encodables.map { encodable -> [Setter] in - let encoder = SQLiteEncoder(userInfo: userInfo) + let combinedSettersWithoutNils = try encodables.map { encodable -> [Setter] in + let encoder = SQLiteEncoder(userInfo: userInfo, forcingNilValueSetters: false) try encodable.encode(to: encoder) return encoder.setters + otherSetters } - return self.insertMany(combinedSetters) + // requires the same number of setters per encodable + guard Set(combinedSettersWithoutNils.map(\.count)).count == 1 else { + // asymmetric sets of value insertions (some nil, some not), requires NULL value to satisfy INSERT query + let combinedSymmetricSetters = try encodables.map { encodable -> [Setter] in + let encoder = SQLiteEncoder(userInfo: userInfo, forcingNilValueSetters: true) + try encodable.encode(to: encoder) + return encoder.setters + otherSetters + } + return self.insertMany(combinedSymmetricSetters) + } + return self.insertMany(combinedSettersWithoutNils) } /// Creates an `INSERT ON CONFLICT DO UPDATE` statement, aka upsert, by encoding the given object @@ -165,9 +175,11 @@ private class SQLiteEncoder: Encoder { let encoder: SQLiteEncoder let codingPath: [CodingKey] = [] + let forcingNilValueSetters: Bool - init(encoder: SQLiteEncoder) { + init(encoder: SQLiteEncoder, forcingNilValueSetters: Bool = false) { self.encoder = encoder + self.forcingNilValueSetters = forcingNilValueSetters } func superEncoder() -> Swift.Encoder { @@ -202,6 +214,46 @@ private class SQLiteEncoder: Encoder { encoder.setters.append(Expression(key.stringValue) <- value) } + func encodeIfPresent(_ value: Int?, forKey key: SQLiteEncoder.SQLiteKeyedEncodingContainer.Key) throws { + if let value = value { + try encode(value, forKey: key) + } else if forcingNilValueSetters { + encoder.setters.append(Expression(key.stringValue) <- nil) + } + } + + func encodeIfPresent(_ value: Bool?, forKey key: Key) throws { + if let value = value { + try encode(value, forKey: key) + } else if forcingNilValueSetters { + encoder.setters.append(Expression(key.stringValue) <- nil) + } + } + + func encodeIfPresent(_ value: Float?, forKey key: Key) throws { + if let value = value { + try encode(value, forKey: key) + } else if forcingNilValueSetters { + encoder.setters.append(Expression(key.stringValue) <- nil) + } + } + + func encodeIfPresent(_ value: Double?, forKey key: Key) throws { + if let value = value { + try encode(value, forKey: key) + } else if forcingNilValueSetters { + encoder.setters.append(Expression(key.stringValue) <- nil) + } + } + + func encodeIfPresent(_ value: String?, forKey key: MyKey) throws { + if let value = value { + try encode(value, forKey: key) + } else if forcingNilValueSetters { + encoder.setters.append(Expression(key.stringValue) <- nil) + } + } + func encode(_ value: T, forKey key: Key) throws where T: Swift.Encodable { switch value { case let data as Data: @@ -217,6 +269,17 @@ private class SQLiteEncoder: Encoder { } } + func encodeIfPresent(_ value: T?, forKey key: Key) throws where T: Swift.Encodable { + guard let value = value else { + guard forcingNilValueSetters else { + return + } + encoder.setters.append(Expression(key.stringValue) <- nil) + return + } + try encode(value, forKey: key) + } + func encode(_ value: Int8, forKey key: Key) throws { throw EncodingError.invalidValue(value, EncodingError.Context(codingPath: codingPath, debugDescription: "encoding an Int8 is not supported")) @@ -274,9 +337,11 @@ private class SQLiteEncoder: Encoder { fileprivate var setters: [Setter] = [] let codingPath: [CodingKey] = [] let userInfo: [CodingUserInfoKey: Any] + let forcingNilValueSetters: Bool - init(userInfo: [CodingUserInfoKey: Any]) { + init(userInfo: [CodingUserInfoKey: Any], forcingNilValueSetters: Bool = false) { self.userInfo = userInfo + self.forcingNilValueSetters = forcingNilValueSetters } func singleValueContainer() -> SingleValueEncodingContainer { @@ -288,7 +353,7 @@ private class SQLiteEncoder: Encoder { } func container(keyedBy type: Key.Type) -> KeyedEncodingContainer where Key: CodingKey { - KeyedEncodingContainer(SQLiteKeyedEncodingContainer(encoder: self)) + KeyedEncodingContainer(SQLiteKeyedEncodingContainer(encoder: self, forcingNilValueSetters: forcingNilValueSetters)) } } diff --git a/Tests/SQLiteTests/QueryIntegrationTests.swift b/Tests/SQLiteTests/QueryIntegrationTests.swift index b3d9cf96..f3b4bcd3 100644 --- a/Tests/SQLiteTests/QueryIntegrationTests.swift +++ b/Tests/SQLiteTests/QueryIntegrationTests.swift @@ -130,6 +130,28 @@ class QueryIntegrationTests: SQLiteTestCase { XCTAssertEqual(2, id) } + func test_insert_many_encodables() throws { + let table = Table("codable") + try db.run(table.create { builder in + builder.column(Expression("int")) + builder.column(Expression("string")) + builder.column(Expression("bool")) + builder.column(Expression("float")) + builder.column(Expression("double")) + builder.column(Expression("date")) + builder.column(Expression("uuid")) + }) + + let value1 = TestOptionalCodable(int: 5, string: "6", bool: true, float: 7, double: 8, + date: Date(timeIntervalSince1970: 5000), uuid: testUUIDValue) + let valueWithNils = TestOptionalCodable(int: nil, string: nil, bool: nil, float: nil, double: nil, date: nil, uuid: nil) + try db.run(table.insertMany([value1, valueWithNils])) + + let rows = try db.prepare(table) + let values: [TestOptionalCodable] = try rows.map({ try $0.decode() }) + XCTAssertEqual(values.count, 2) + } + func test_upsert() throws { try XCTSkipUnless(db.satisfiesMinimumVersion(minor: 24)) let fetchAge = { () throws -> Int? in diff --git a/Tests/SQLiteTests/QueryTests.swift b/Tests/SQLiteTests/QueryTests.swift index 2201caef..b2f679e0 100644 --- a/Tests/SQLiteTests/QueryTests.swift +++ b/Tests/SQLiteTests/QueryTests.swift @@ -365,21 +365,21 @@ class QueryTests: XCTestCase { ) } - func test_insert_many_encodable() throws { + func test_insert_many_encodables() throws { let emails = Table("emails") let value1 = TestCodable(int: 1, string: "2", bool: true, float: 3, double: 4, date: Date(timeIntervalSince1970: 0), uuid: testUUIDValue, optional: nil, sub: nil) let value2 = TestCodable(int: 2, string: "3", bool: true, float: 3, double: 5, - date: Date(timeIntervalSince1970: 0), uuid: testUUIDValue, optional: nil, sub: nil) + date: Date(timeIntervalSince1970: 0), uuid: testUUIDValue, optional: "optional", sub: nil) let value3 = TestCodable(int: 3, string: "4", bool: true, float: 3, double: 6, date: Date(timeIntervalSince1970: 0), uuid: testUUIDValue, optional: nil, sub: nil) let insert = try emails.insertMany([value1, value2, value3]) assertSQL( """ - INSERT INTO \"emails\" (\"int\", \"string\", \"bool\", \"float\", \"double\", \"date\", \"uuid\") - VALUES (1, '2', 1, 3.0, 4.0, '1970-01-01T00:00:00.000', 'E621E1F8-C36C-495A-93FC-0C247A3E6E5F'), - (2, '3', 1, 3.0, 5.0, '1970-01-01T00:00:00.000', 'E621E1F8-C36C-495A-93FC-0C247A3E6E5F'), - (3, '4', 1, 3.0, 6.0, '1970-01-01T00:00:00.000', 'E621E1F8-C36C-495A-93FC-0C247A3E6E5F') + INSERT INTO \"emails\" (\"int\", \"string\", \"bool\", \"float\", \"double\", \"date\", \"uuid\", \"optional\", \"sub\") + VALUES (1, '2', 1, 3.0, 4.0, '1970-01-01T00:00:00.000', 'E621E1F8-C36C-495A-93FC-0C247A3E6E5F', NULL, NULL), + (2, '3', 1, 3.0, 5.0, '1970-01-01T00:00:00.000', 'E621E1F8-C36C-495A-93FC-0C247A3E6E5F', 'optional', NULL), + (3, '4', 1, 3.0, 6.0, '1970-01-01T00:00:00.000', 'E621E1F8-C36C-495A-93FC-0C247A3E6E5F', NULL, NULL) """.replacingOccurrences(of: "\n", with: ""), insert ) diff --git a/Tests/SQLiteTests/TestHelpers.swift b/Tests/SQLiteTests/TestHelpers.swift index 4a71fd40..f43310c9 100644 --- a/Tests/SQLiteTests/TestHelpers.swift +++ b/Tests/SQLiteTests/TestHelpers.swift @@ -145,3 +145,23 @@ class TestCodable: Codable, Equatable { lhs.sub == rhs.sub } } + +struct TestOptionalCodable: Codable, Equatable { + let int: Int? + let string: String? + let bool: Bool? + let float: Float? + let double: Double? + let date: Date? + let uuid: UUID? + + init(int: Int?, string: String?, bool: Bool?, float: Float?, double: Double?, date: Date?, uuid: UUID?) { + self.int = int + self.string = string + self.bool = bool + self.float = float + self.double = double + self.date = date + self.uuid = uuid + } +}