Skip to content

Commit ac4aea1

Browse files
authored
Allow PartsRepresentable to throw errors (#88)
* almost nonbreaking change * style * make partsvalue accessible only when error is never * fix tests * fix macos * put errors into api methods * style * remove generic error * add non-erroring protocol so force unwraps arent required * api review feedback: use more specific error case and add failure tests * specialize error * style * code feedback changes * use partsValue * use consistent closure name
1 parent e2cebcd commit ac4aea1

File tree

9 files changed

+340
-155
lines changed

9 files changed

+340
-155
lines changed

Examples/GenerativeAISample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class PhotoReasoningViewModel: ObservableObject {
6262

6363
let prompt = "Look at the image(s), and then answer the following question: \(userInput)"
6464

65-
var images = [PartsRepresentable]()
65+
var images = [any ThrowingPartsRepresentable]()
6666
for item in selectedItems {
6767
if let data = try? await item.loadTransferable(type: Data.self) {
6868
guard let image = UIImage(data: data) else {

Sources/GoogleAI/Chat.swift

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ public class Chat {
3131
public var history: [ModelContent]
3232

3333
/// See ``sendMessage(_:)-3ify5``.
34-
public func sendMessage(_ parts: PartsRepresentable...) async throws -> GenerateContentResponse {
34+
public func sendMessage(_ parts: any ThrowingPartsRepresentable...) async throws
35+
-> GenerateContentResponse {
3536
return try await sendMessage([ModelContent(parts: parts)])
3637
}
3738

@@ -40,9 +41,19 @@ public class Chat {
4041
/// - Parameter content: The new content to send as a single chat message.
4142
/// - Returns: The model's response if no error occurred.
4243
/// - Throws: A ``GenerateContentError`` if an error occurred.
43-
public func sendMessage(_ content: [ModelContent]) async throws -> GenerateContentResponse {
44+
public func sendMessage(_ content: @autoclosure () throws -> [ModelContent]) async throws
45+
-> GenerateContentResponse {
4446
// Ensure that the new content has the role set.
45-
let newContent: [ModelContent] = content.map(populateContentRole(_:))
47+
let newContent: [ModelContent]
48+
do {
49+
newContent = try content().map(populateContentRole(_:))
50+
} catch let underlying {
51+
if let contentError = underlying as? ImageConversionError {
52+
throw GenerateContentError.promptImageContentError(underlying: contentError)
53+
} else {
54+
throw GenerateContentError.internalError(underlying: underlying)
55+
}
56+
}
4657

4758
// Send the history alongside the new message as context.
4859
let request = history + newContent
@@ -67,24 +78,39 @@ public class Chat {
6778

6879
/// See ``sendMessageStream(_:)-4abs3``.
6980
@available(macOS 12.0, *)
70-
public func sendMessageStream(_ parts: PartsRepresentable...)
81+
public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...)
7182
-> AsyncThrowingStream<GenerateContentResponse, Error> {
72-
return sendMessageStream([ModelContent(parts: parts)])
83+
return try sendMessageStream([ModelContent(parts: parts)])
7384
}
7485

7586
/// Sends a message using the existing history of this chat as context. If successful, the message
7687
/// and response will be added to the history. If unsuccessful, history will remain unchanged.
7788
/// - Parameter content: The new content to send as a single chat message.
7889
/// - Returns: A stream containing the model's response or an error if an error occurred.
7990
@available(macOS 12.0, *)
80-
public func sendMessageStream(_ content: [ModelContent])
91+
public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent])
8192
-> AsyncThrowingStream<GenerateContentResponse, Error> {
93+
let resolvedContent: [ModelContent]
94+
do {
95+
resolvedContent = try content()
96+
} catch let underlying {
97+
return AsyncThrowingStream { continuation in
98+
let error: Error
99+
if let contentError = underlying as? ImageConversionError {
100+
error = GenerateContentError.promptImageContentError(underlying: contentError)
101+
} else {
102+
error = GenerateContentError.internalError(underlying: underlying)
103+
}
104+
continuation.finish(throwing: error)
105+
}
106+
}
107+
82108
return AsyncThrowingStream { continuation in
83109
Task {
84110
var aggregatedContent: [ModelContent] = []
85111

86112
// Ensure that the new content has the role set.
87-
let newContent: [ModelContent] = content.map(populateContentRole(_:))
113+
let newContent: [ModelContent] = resolvedContent.map(populateContentRole(_:))
88114

89115
// Send the history alongside the new message as context.
90116
let request = history + newContent

Sources/GoogleAI/GenerateContentError.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ import Foundation
1717
/// Errors that occur when generating content from a model.
1818
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
1919
public enum GenerateContentError: Error {
20+
/// An error occurred when constructing the prompt. Examine the related error for details.
21+
case promptImageContentError(underlying: ImageConversionError)
22+
2023
/// An internal error occurred. See the underlying error for more context.
2124
case internalError(underlying: Error)
2225

Sources/GoogleAI/GenerativeModel.swift

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,12 @@ public final class GenerativeModel {
9696
/// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
9797
/// prompts, see ``generateContent(_:)-58rm0``.
9898
///
99-
/// - Parameter content: The input(s) given to the model as a prompt (see ``PartsRepresentable``
99+
/// - Parameter content: The input(s) given to the model as a prompt (see
100+
/// ``ThrowingPartsRepresentable``
100101
/// for conforming types).
101102
/// - Returns: The content generated by the model.
102103
/// - Throws: A ``GenerateContentError`` if the request failed.
103-
public func generateContent(_ parts: PartsRepresentable...)
104+
public func generateContent(_ parts: any ThrowingPartsRepresentable...)
104105
async throws -> GenerateContentResponse {
105106
return try await generateContent([ModelContent(parts: parts)])
106107
}
@@ -110,18 +111,21 @@ public final class GenerativeModel {
110111
/// - Parameter content: The input(s) given to the model as a prompt.
111112
/// - Returns: The generated content response from the model.
112113
/// - Throws: A ``GenerateContentError`` if the request failed.
113-
public func generateContent(_ content: [ModelContent]) async throws
114+
public func generateContent(_ content: @autoclosure () throws -> [ModelContent]) async throws
114115
-> GenerateContentResponse {
115-
let generateContentRequest = GenerateContentRequest(model: modelResourceName,
116-
contents: content,
117-
generationConfig: generationConfig,
118-
safetySettings: safetySettings,
119-
isStreaming: false,
120-
options: requestOptions)
121116
let response: GenerateContentResponse
122117
do {
118+
let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
119+
contents: content(),
120+
generationConfig: generationConfig,
121+
safetySettings: safetySettings,
122+
isStreaming: false,
123+
options: requestOptions)
123124
response = try await generativeAIService.loadRequest(request: generateContentRequest)
124125
} catch {
126+
if let imageError = error as? ImageConversionError {
127+
throw GenerateContentError.promptImageContentError(underlying: imageError)
128+
}
125129
throw GenerativeModel.generateContentError(from: error)
126130
}
127131

@@ -148,14 +152,15 @@ public final class GenerativeModel {
148152
/// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
149153
/// prompts, see ``generateContent(_:)-58rm0``.
150154
///
151-
/// - Parameter content: The input(s) given to the model as a prompt (see ``PartsRepresentable``
155+
/// - Parameter content: The input(s) given to the model as a prompt (see
156+
/// ``ThrowingPartsRepresentable``
152157
/// for conforming types).
153158
/// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
154159
/// error if an error occurred.
155160
@available(macOS 12.0, *)
156-
public func generateContentStream(_ parts: PartsRepresentable...)
161+
public func generateContentStream(_ parts: any ThrowingPartsRepresentable...)
157162
-> AsyncThrowingStream<GenerateContentResponse, Error> {
158-
return generateContentStream([ModelContent(parts: parts)])
163+
return try generateContentStream([ModelContent(parts: parts)])
159164
}
160165

161166
/// Generates new content from input content given to the model as a prompt.
@@ -164,10 +169,25 @@ public final class GenerativeModel {
164169
/// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
165170
/// error if an error occurred.
166171
@available(macOS 12.0, *)
167-
public func generateContentStream(_ content: [ModelContent])
172+
public func generateContentStream(_ content: @autoclosure () throws -> [ModelContent])
168173
-> AsyncThrowingStream<GenerateContentResponse, Error> {
174+
let evaluatedContent: [ModelContent]
175+
do {
176+
evaluatedContent = try content()
177+
} catch let underlying {
178+
return AsyncThrowingStream { continuation in
179+
let error: Error
180+
if let contentError = underlying as? ImageConversionError {
181+
error = GenerateContentError.promptImageContentError(underlying: contentError)
182+
} else {
183+
error = GenerateContentError.internalError(underlying: underlying)
184+
}
185+
continuation.finish(throwing: error)
186+
}
187+
}
188+
169189
let generateContentRequest = GenerateContentRequest(model: modelResourceName,
170-
contents: content,
190+
contents: evaluatedContent,
171191
generationConfig: generationConfig,
172192
safetySettings: safetySettings,
173193
isStreaming: true,
@@ -218,12 +238,14 @@ public final class GenerativeModel {
218238
/// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
219239
/// input, see ``countTokens(_:)-9spwl``.
220240
///
221-
/// - Parameter content: The input(s) given to the model as a prompt (see ``PartsRepresentable``
241+
/// - Parameter content: The input(s) given to the model as a prompt (see
242+
/// ``ThrowingPartsRepresentable``
222243
/// for conforming types).
223244
/// - Returns: The results of running the model's tokenizer on the input; contains
224245
/// ``CountTokensResponse/totalTokens``.
225246
/// - Throws: A ``CountTokensError`` if the tokenization request failed.
226-
public func countTokens(_ parts: PartsRepresentable...) async throws -> CountTokensResponse {
247+
public func countTokens(_ parts: any ThrowingPartsRepresentable...) async throws
248+
-> CountTokensResponse {
227249
return try await countTokens([ModelContent(parts: parts)])
228250
}
229251

@@ -232,16 +254,16 @@ public final class GenerativeModel {
232254
/// - Parameter content: The input given to the model as a prompt.
233255
/// - Returns: The results of running the model's tokenizer on the input; contains
234256
/// ``CountTokensResponse/totalTokens``.
235-
/// - Throws: A ``CountTokensError`` if the tokenization request failed.
236-
public func countTokens(_ content: [ModelContent]) async throws
257+
/// - Throws: A ``CountTokensError`` if the tokenization request failed or the input content was
258+
/// invalid.
259+
public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws
237260
-> CountTokensResponse {
238-
let countTokensRequest = CountTokensRequest(
239-
model: modelResourceName,
240-
contents: content,
241-
options: requestOptions
242-
)
243-
244261
do {
262+
let countTokensRequest = try CountTokensRequest(
263+
model: modelResourceName,
264+
contents: content(),
265+
options: requestOptions
266+
)
245267
return try await generativeAIService.loadRequest(request: countTokensRequest)
246268
} catch {
247269
throw CountTokensError.internalError(underlying: error)

Sources/GoogleAI/ModelContent.swift

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,14 @@ public struct ModelContent: Codable, Equatable {
104104
public let parts: [Part]
105105

106106
/// Creates a new value from any data or `Array` of data interpretable as a
107-
/// ``Part``. See ``PartsRepresentable`` for types that can be interpreted as `Part`s.
107+
/// ``Part``. See ``ThrowingPartsRepresentable`` for types that can be interpreted as `Part`s.
108+
public init(role: String? = "user", parts: some ThrowingPartsRepresentable) throws {
109+
self.role = role
110+
try self.parts = parts.tryPartsValue()
111+
}
112+
113+
/// Creates a new value from any data or `Array` of data interpretable as a
114+
/// ``Part``. See ``ThrowingPartsRepresentable`` for types that can be interpreted as `Part`s.
108115
public init(role: String? = "user", parts: some PartsRepresentable) {
109116
self.role = role
110117
self.parts = parts.partsValue
@@ -116,9 +123,19 @@ public struct ModelContent: Codable, Equatable {
116123
self.parts = parts
117124
}
118125

119-
/// Creates a new value from any data interpretable as a ``Part``. See ``PartsRepresentable``
126+
/// Creates a new value from any data interpretable as a ``Part``. See
127+
/// ``ThrowingPartsRepresentable``
128+
/// for types that can be interpreted as `Part`s.
129+
public init(role: String? = "user", _ parts: any ThrowingPartsRepresentable...) throws {
130+
let content = try parts.flatMap { try $0.tryPartsValue() }
131+
self.init(role: role, parts: content)
132+
}
133+
134+
/// Creates a new value from any data interpretable as a ``Part``. See
135+
/// ``ThrowingPartsRepresentable``
120136
/// for types that can be interpreted as `Part`s.
121-
public init(role: String? = "user", _ parts: PartsRepresentable...) {
122-
self.init(role: role, parts: parts)
137+
public init(role: String? = "user", _ parts: [PartsRepresentable]) {
138+
let content = parts.flatMap { $0.partsValue }
139+
self.init(role: role, parts: content)
123140
}
124141
}
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Copyright 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import UniformTypeIdentifiers
16+
#if canImport(UIKit)
17+
import UIKit // For UIImage extensions.
18+
#elseif canImport(AppKit)
19+
import AppKit // For NSImage extensions.
20+
#endif
21+
22+
private let imageCompressionQuality: CGFloat = 0.8
23+
24+
/// An enum describing failures that can occur when converting image types to model content data.
25+
/// For some image types like `CIImage`, creating valid model content requires creating a JPEG
26+
/// representation of the image that may not yet exist, which may be computationally expensive.
27+
public enum ImageConversionError: Error {
28+
/// The image (the receiver of the call `toModelContentParts()`) was invalid.
29+
case invalidUnderlyingImage
30+
31+
/// A valid image destination could not be allocated.
32+
case couldNotAllocateDestination
33+
34+
/// JPEG image data conversion failed, accompanied by the original image, which may be an
35+
/// instance of `NSImageRep`, `UIImage`, `CGImage`, or `CIImage`.
36+
case couldNotConvertToJPEG(Any)
37+
}
38+
39+
#if canImport(UIKit)
40+
/// Enables images to be representable as ``ThrowingPartsRepresentable``.
41+
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
42+
extension UIImage: ThrowingPartsRepresentable {
43+
public func tryPartsValue() throws -> [ModelContent.Part] {
44+
guard let data = jpegData(compressionQuality: imageCompressionQuality) else {
45+
throw ImageConversionError.couldNotConvertToJPEG(self)
46+
}
47+
return [ModelContent.Part.data(mimetype: "image/jpeg", data)]
48+
}
49+
}
50+
51+
#elseif canImport(AppKit)
52+
/// Enables images to be representable as ``ThrowingPartsRepresentable``.
53+
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
54+
extension NSImage: ThrowingPartsRepresentable {
55+
public func tryPartsValue() throws -> [ModelContent.Part] {
56+
guard let cgImage = cgImage(forProposedRect: nil, context: nil, hints: nil) else {
57+
throw ImageConversionError.invalidUnderlyingImage
58+
}
59+
let bmp = NSBitmapImageRep(cgImage: cgImage)
60+
guard let data = bmp.representation(using: .jpeg, properties: [.compressionFactor: 0.8])
61+
else {
62+
throw ImageConversionError.couldNotConvertToJPEG(bmp)
63+
}
64+
return [ModelContent.Part.data(mimetype: "image/jpeg", data)]
65+
}
66+
}
67+
#endif
68+
69+
/// Enables `CGImages` to be representable as model content.
70+
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
71+
extension CGImage: ThrowingPartsRepresentable {
72+
public func tryPartsValue() throws -> [ModelContent.Part] {
73+
let output = NSMutableData()
74+
guard let imageDestination = CGImageDestinationCreateWithData(
75+
output, UTType.jpeg.identifier as CFString, 1, nil
76+
) else {
77+
throw ImageConversionError.couldNotAllocateDestination
78+
}
79+
CGImageDestinationAddImage(imageDestination, self, nil)
80+
CGImageDestinationSetProperties(imageDestination, [
81+
kCGImageDestinationLossyCompressionQuality: imageCompressionQuality,
82+
] as CFDictionary)
83+
if CGImageDestinationFinalize(imageDestination) {
84+
return [.data(mimetype: "image/jpeg", output as Data)]
85+
}
86+
throw ImageConversionError.couldNotConvertToJPEG(self)
87+
}
88+
}
89+
90+
/// Enables `CIImages` to be representable as model content.
91+
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
92+
extension CIImage: ThrowingPartsRepresentable {
93+
public func tryPartsValue() throws -> [ModelContent.Part] {
94+
let context = CIContext()
95+
let jpegData = (colorSpace ?? CGColorSpace(name: CGColorSpace.sRGB))
96+
.flatMap {
97+
// The docs specify kCGImageDestinationLossyCompressionQuality as a supported option, but
98+
// Swift's type system does not allow this.
99+
// [kCGImageDestinationLossyCompressionQuality: imageCompressionQuality]
100+
context.jpegRepresentation(of: self, colorSpace: $0, options: [:])
101+
}
102+
if let jpegData = jpegData {
103+
return [.data(mimetype: "image/jpeg", jpegData)]
104+
}
105+
throw ImageConversionError.couldNotConvertToJPEG(self)
106+
}
107+
}

0 commit comments

Comments
 (0)