diff --git a/Sources/Lifecycle/Lifecycle.swift b/Sources/Lifecycle/Lifecycle.swift index 584e884..e6824b8 100644 --- a/Sources/Lifecycle/Lifecycle.swift +++ b/Sources/Lifecycle/Lifecycle.swift @@ -333,6 +333,9 @@ public struct ServiceLifecycle { } extension ServiceLifecycle { + private static var trapped: Set = [] + private static let trappedLock = Lock() + /// Setup a signal trap. /// /// - parameters: @@ -342,14 +345,20 @@ extension ServiceLifecycle { /// - cancelAfterTrap: Defaults to false, which means the signal handler can be run multiple times. If true, the DispatchSignalSource will be cancelled after being trapped once. /// - returns: a `DispatchSourceSignal` for the given trap. The source must be cancelled by the caller. public static func trap(signal sig: Signal, handler: @escaping (Signal) -> Void, on queue: DispatchQueue = .global(), cancelAfterTrap: Bool = false) -> DispatchSourceSignal { + // on linux, we can call singal() once per process + self.trappedLock.withLockVoid { + if !trapped.contains(sig.rawValue) { + signal(sig.rawValue, SIG_IGN) + trapped.insert(sig.rawValue) + } + } let signalSource = DispatchSource.makeSignalSource(signal: sig.rawValue, queue: queue) - signal(sig.rawValue, SIG_IGN) - signalSource.setEventHandler(handler: { + signalSource.setEventHandler { if cancelAfterTrap { signalSource.cancel() } handler(sig) - }) + } signalSource.resume() return signalSource } diff --git a/Tests/LifecycleTests/ServiceLifecycleTests+XCTest.swift b/Tests/LifecycleTests/ServiceLifecycleTests+XCTest.swift index a5c87c3..3200ce1 100644 --- a/Tests/LifecycleTests/ServiceLifecycleTests+XCTest.swift +++ b/Tests/LifecycleTests/ServiceLifecycleTests+XCTest.swift @@ -34,6 +34,7 @@ extension ServiceLifecycleTests { ("testNesting2", testNesting2), ("testSignalDescription", testSignalDescription), ("testBacktracesInstalledOnce", testBacktracesInstalledOnce), + ("testRepeatShutdown", testRepeatShutdown), ] } } diff --git a/Tests/LifecycleTests/ServiceLifecycleTests.swift b/Tests/LifecycleTests/ServiceLifecycleTests.swift index 9062656..cb950cf 100644 --- a/Tests/LifecycleTests/ServiceLifecycleTests.swift +++ b/Tests/LifecycleTests/ServiceLifecycleTests.swift @@ -233,4 +233,42 @@ final class ServiceLifecycleTests: XCTestCase { _ = ServiceLifecycle(configuration: config) _ = ServiceLifecycle(configuration: config) } + + func testRepeatShutdown() { + if ProcessInfo.processInfo.environment["SKIP_SIGNAL_TEST"].flatMap(Bool.init) ?? false { + print("skipping testRepeatShutdown") + return + } + + var count = 0 + + struct Service { + static let signal = ServiceLifecycle.Signal.ALRM + + let lifecycle: ServiceLifecycle + + init() { + self.lifecycle = ServiceLifecycle(configuration: .init(shutdownSignal: [Service.signal])) + self.lifecycle.register(GoodItem()) + } + } + + func gracefulShutdown() { + let service = Service() + service.lifecycle.start { error in + XCTAssertNil(error, "not expecting error") + kill(getpid(), Service.signal.rawValue) + } + + service.lifecycle.wait() + count = count + 1 // not thread safe but fine for this purpose + } + + let attempts = Int.random(in: 2 ..< 5) + for _ in 0 ..< attempts { + gracefulShutdown() + } + + XCTAssertEqual(attempts, count) + } }