Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions Sources/GraphQLTransportWS/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ public class Server<InitPayload: Equatable & Codable> {

let onExecute: (GraphQLRequest) -> EventLoopFuture<GraphQLResult>
let onSubscribe: (GraphQLRequest) -> EventLoopFuture<SubscriptionResult>
var auth: (InitPayload) throws -> EventLoopFuture<Void>

var auth: (InitPayload) throws -> Void = { _ in }
var onExit: () -> Void = { }
var onOperationComplete: (String) -> Void = { _ in }
var onOperationError: (String) -> Void = { _ in }
Expand All @@ -32,14 +32,17 @@ public class Server<InitPayload: Equatable & Codable> {
/// - messenger: The messenger to bind the server to.
/// - onExecute: Callback run during `subscribe` resolution for non-streaming queries. Typically this is `API.execute`.
/// - onSubscribe: Callback run during `subscribe` resolution for streaming queries. Typically this is `API.subscribe`.
/// - eventLoop: EventLoop on which to perform server operations.
public init(
messenger: Messenger,
onExecute: @escaping (GraphQLRequest) -> EventLoopFuture<GraphQLResult>,
onSubscribe: @escaping (GraphQLRequest) -> EventLoopFuture<SubscriptionResult>
onSubscribe: @escaping (GraphQLRequest) -> EventLoopFuture<SubscriptionResult>,
eventLoop: EventLoop
) {
self.messenger = messenger
self.onExecute = onExecute
self.onSubscribe = onSubscribe
self.auth = { _ in eventLoop.makeSucceededVoidFuture() }

messenger.onReceive { message in
self.onMessage(message)
Expand Down Expand Up @@ -91,9 +94,9 @@ public class Server<InitPayload: Equatable & Codable> {
}

/// Define the callback run during `connection_init` resolution that allows authorization using the `payload`.
/// Throw to indicate that authorization has failed.
/// Throw or fail the future to indicate that authorization has failed.
/// - Parameter callback: The callback to assign
public func auth(_ callback: @escaping (InitPayload) throws -> Void) {
public func auth(_ callback: @escaping (InitPayload) throws -> EventLoopFuture<Void>) {
self.auth = callback
}

Expand Down Expand Up @@ -128,14 +131,20 @@ public class Server<InitPayload: Equatable & Codable> {
}

do {
try self.auth(connectionInitRequest.payload)
let authResult = try self.auth(connectionInitRequest.payload)
authResult.whenSuccess {
self.initialized = true
self.sendConnectionAck()
}
authResult.whenFailure { error in
self.error(.unauthorized())
return
}
}
catch {
self.error(.unauthorized())
return
}
initialized = true
self.sendConnectionAck()
}

private func onSubscribe(_ subscribeRequest: SubscribeRequest) {
Expand Down
40 changes: 35 additions & 5 deletions Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class GraphqlTransportWSTests: XCTestCase {
var clientMessenger: TestMessenger!
var serverMessenger: TestMessenger!
var server: Server<TokenInitPayload>!
var eventLoop: EventLoop!

override func setUp() {
// Point the client and server at each other
Expand All @@ -18,7 +19,7 @@ class GraphqlTransportWSTests: XCTestCase {
clientMessenger.other = serverMessenger
serverMessenger.other = clientMessenger

let eventLoop = MultiThreadedEventLoopGroup(numberOfThreads: 1).next()
eventLoop = MultiThreadedEventLoopGroup(numberOfThreads: 1).next()
let api = TestAPI()
let context = TestContext()

Expand All @@ -28,16 +29,17 @@ class GraphqlTransportWSTests: XCTestCase {
api.execute(
request: graphQLRequest.query,
context: context,
on: eventLoop
on: self.eventLoop
)
},
onSubscribe: { graphQLRequest in
api.subscribe(
request: graphQLRequest.query,
context: context,
on: eventLoop
on: self.eventLoop
)
}
},
eventLoop: self.eventLoop
)
}

Expand Down Expand Up @@ -71,7 +73,7 @@ class GraphqlTransportWSTests: XCTestCase {
}

/// Tests that throwing in the authorization callback forces an unauthorized error
func testAuth() throws {
func testAuthWithThrow() throws {
server.auth { payload in
throw TestError.couldBeAnything
}
Expand All @@ -98,6 +100,34 @@ class GraphqlTransportWSTests: XCTestCase {
)
}

/// Tests that failing a future in the authorization callback forces an unauthorized error
func testAuthWithFailedFuture() throws {
server.auth { payload in
self.eventLoop.makeFailedFuture(TestError.couldBeAnything)
}

var messages = [String]()
let completeExpectation = XCTestExpectation()

let client = Client<TokenInitPayload>(messenger: clientMessenger)
client.onMessage { message, _ in
messages.append(message)
completeExpectation.fulfill()
}

client.sendConnectionInit(
payload: TokenInitPayload(
authToken: ""
)
)

wait(for: [completeExpectation], timeout: 2)
XCTAssertEqual(
messages,
["\(ErrorCode.unauthorized): Unauthorized"]
)
}

/// Tests a single-op conversation
func testSingleOp() throws {
let id = UUID().description
Expand Down