Skip to content
17 changes: 14 additions & 3 deletions Sources/Lifecycle/Lifecycle.swift
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ public struct ServiceLifecycle {
let signalSource = ServiceLifecycle.trap(signal: signal, handler: { signal in
self.log("intercepted signal: \(signal)")
self.shutdown()
})
}, cancelAfterTrap: true)
self.underlying.shutdownGroup.notify(queue: .global()) {
signalSource.cancel()
}
Expand All @@ -176,12 +176,16 @@ extension ServiceLifecycle {
/// - parameters:
/// - signal: The signal to trap.
/// - handler: closure to invoke when the signal is captured.
/// - on: DispatchQueue to run the signal handler on (default global dispatch queue)
/// - cancelAfterTrap: Defaults to false, which means the signal handler can be run multiple times. If true, the DispatchSignalSource will be cancelled after being trapped once.
/// - returns: a `DispatchSourceSignal` for the given trap. The source must be cancelled by the caller.
public static func trap(signal sig: Signal, handler: @escaping (Signal) -> Void, on queue: DispatchQueue = .global()) -> DispatchSourceSignal {
public static func trap(signal sig: Signal, handler: @escaping (Signal) -> Void, on queue: DispatchQueue = .global(), cancelAfterTrap: Bool = false) -> DispatchSourceSignal {
let signalSource = DispatchSource.makeSignalSource(signal: sig.rawValue, queue: queue)
signal(sig.rawValue, SIG_IGN)
signalSource.setEventHandler(handler: {
signalSource.cancel()
if cancelAfterTrap {
signalSource.cancel()
}
handler(sig)
})
signalSource.resume()
Expand All @@ -194,6 +198,10 @@ extension ServiceLifecycle {

public static let TERM = Signal(rawValue: SIGTERM)
public static let INT = Signal(rawValue: SIGINT)
public static let USR1 = Signal(rawValue: SIGUSR1)
public static let USR2 = Signal(rawValue: SIGUSR2)
public static let HUP = Signal(rawValue: SIGHUP)

// for testing
internal static let ALRM = Signal(rawValue: SIGALRM)

Expand All @@ -203,6 +211,9 @@ extension ServiceLifecycle {
case Signal.TERM: result += "TERM, "
case Signal.INT: result += "INT, "
case Signal.ALRM: result += "ALRM, "
case Signal.USR1: result += "USR1, "
case Signal.USR2: result += "USR2, "
case Signal.HUP: result += "HUP, "
default: () // ok to ignore
}
result += "rawValue: \(self.rawValue))"
Expand Down