diff --git a/Sources/MQTTConnectionManager/Live.swift b/Sources/MQTTConnectionManager/Live.swift index 0053ed8..723c17b 100644 --- a/Sources/MQTTConnectionManager/Live.swift +++ b/Sources/MQTTConnectionManager/Live.swift @@ -57,7 +57,7 @@ public struct MQTTConnectionManager: Sendable { } shutdown: { manager.shutdown() } stream: { - MQTTConnectionStream(client: client) + MQTTConnectionStream(client: client, logger: logger) .start() .removeDuplicates() .eraseToStream() @@ -80,17 +80,17 @@ final class MQTTConnectionStream: AsyncSequence, Sendable { private let client: MQTTClient private let continuation: AsyncStream.Continuation - private var logger: Logger { client.logger } + private let logger: Logger? private let name: String private let stream: AsyncStream - init(client: MQTTClient) { + init(client: MQTTClient, logger: Logger?) { let (stream, continuation) = AsyncStream.makeStream() self.client = client self.continuation = continuation + self.logger = logger self.name = UUID().uuidString self.stream = stream - continuation.yield(client.isActive() ? .connected : .disconnected) } deinit { stop() } @@ -98,26 +98,24 @@ final class MQTTConnectionStream: AsyncSequence, Sendable { func start( isolation: isolated (any Actor)? = #isolation ) -> AsyncStream { + // Check if the client is active and yield the result. + continuation.yield(client.isActive() ? .connected : .disconnected) + + // Register listener on the client for when the connection + // closes. client.addCloseListener(named: name) { _ in - self.logger.trace("Client has disconnected.") + self.logger?.trace("Client has disconnected.") self.continuation.yield(.disconnected) } + + // Register listener on the client for when the client + // is shutdown. client.addShutdownListener(named: name) { _ in - self.logger.trace("Client is shutting down.") + self.logger?.trace("Client is shutting down.") self.continuation.yield(.shuttingDown) self.stop() } - let task = Task { - while !Task.isCancelled { - try? await Task.sleep(for: .milliseconds(100)) - continuation.yield( - self.client.isActive() ? .connected : .disconnected - ) - } - } - continuation.onTermination = { _ in - task.cancel() - } + return stream } @@ -133,11 +131,12 @@ final class MQTTConnectionStream: AsyncSequence, Sendable { } -final class ConnectionManager: Sendable { +actor ConnectionManager { private let client: MQTTClient private let logger: Logger? private let name: String private let shouldReconnect: Bool + private var hasConnected: Bool = false init( client: MQTTClient, @@ -156,12 +155,18 @@ final class ConnectionManager: Sendable { self.shutdown(withLogging: false) } + private func setHasConnected() { + hasConnected = true + } + func connect( isolation: isolated (any Actor)? = #isolation, cleanSession: Bool ) async throws { + guard !(await hasConnected) else { return } do { try await client.connect(cleanSession: cleanSession) + await setHasConnected() client.addCloseListener(named: name) { [weak self] _ in guard let `self` else { return } @@ -174,8 +179,8 @@ final class ConnectionManager: Sendable { } } - client.addShutdownListener(named: name) { [weak self] _ in - self?.shutdown() + client.addShutdownListener(named: name) { _ in + self.shutdown() } } catch { @@ -184,7 +189,7 @@ final class ConnectionManager: Sendable { } } - func shutdown(withLogging: Bool = true) { + nonisolated func shutdown(withLogging: Bool = true) { if withLogging { logger?.trace("Shutting down connection.") } diff --git a/Sources/dewPoint-controller/Application.swift b/Sources/dewPoint-controller/Application.swift index 6c218bc..a1853fb 100644 --- a/Sources/dewPoint-controller/Application.swift +++ b/Sources/dewPoint-controller/Application.swift @@ -62,6 +62,7 @@ struct Application { } try await mqtt.shutdown() + try await eventloopGroup.shutdownGracefully() } catch { try await eventloopGroup.shutdownGracefully() } diff --git a/Tests/MQTTConnectionServiceTests/MQTTConnectionServiceTests.swift b/Tests/MQTTConnectionServiceTests/MQTTConnectionServiceTests.swift index bce5ec5..060c6e3 100644 --- a/Tests/MQTTConnectionServiceTests/MQTTConnectionServiceTests.swift +++ b/Tests/MQTTConnectionServiceTests/MQTTConnectionServiceTests.swift @@ -28,6 +28,14 @@ final class MQTTConnectionServiceTests: XCTestCase { // XCTAssertFalse(client.isActive()) // } + func testWhatHappensIfConnectIsCalledMultipleTimes() async throws { + let client = createClient(identifier: "testWhatHappensIfConnectIsCalledMultipleTimes") + let manager = MQTTConnectionManager.live(client: client) + try await manager.connect() + try await manager.connect() + } + + // TODO: Move to integration tests. func testMQTTConnectionStream() async throws { let client = createClient(identifier: "testNonManagedStream") let manager = MQTTConnectionManager.live( @@ -35,7 +43,7 @@ final class MQTTConnectionServiceTests: XCTestCase { logger: Self.logger, alwaysReconnect: false ) - let stream = MQTTConnectionStream(client: client) + let stream = MQTTConnectionStream(client: client, logger: Self.logger) var events = [MQTTConnectionManager.Event]() _ = try await manager.connect() diff --git a/Tests/SensorsServiceTests/SensorsClientTests.swift b/Tests/SensorsServiceTests/SensorsClientTests.swift index 99a84c2..2b557be 100755 --- a/Tests/SensorsServiceTests/SensorsClientTests.swift +++ b/Tests/SensorsServiceTests/SensorsClientTests.swift @@ -5,6 +5,7 @@ import MQTTNIO import NIO import PsychrometricClientLive @testable import SensorsService +import TopicDependencies import XCTest final class SensorsClientTests: XCTestCase { @@ -25,42 +26,28 @@ final class SensorsClientTests: XCTestCase { } } -// func createClient(identifier: String) -> SensorsClient { -// let envVars = EnvVars( -// appEnv: .testing, -// host: Self.hostname, -// port: "1883", -// identifier: identifier, -// userName: nil, -// password: nil -// ) -// return .init(envVars: envVars, logger: Self.logger) -// } - func createClient(identifier: String) -> MQTTClient { - let envVars = EnvVars( - appEnv: .testing, - host: Self.hostname, - port: "1883", - identifier: identifier, - userName: nil, - password: nil - ) - let config = MQTTClient.Configuration( - version: .v3_1_1, - userName: envVars.userName, - password: envVars.password, - useSSL: false, - useWebSockets: false, - tlsConfiguration: nil, - webSocketURLPath: nil - ) - return .init( - host: Self.hostname, - identifier: identifier, - eventLoopGroupProvider: .shared(MultiThreadedEventLoopGroup(numberOfThreads: 1)), - logger: Self.logger, - configuration: config + func testWhatHappensIfClientDisconnectsWhileListening() async throws { + let client = createClient(identifier: "testWhatHappensIfClientDisconnectsWhileListening") + let listener = TopicListener.live(client: client) + try await client.connect() + + let stream = try await listener.listen("/some/topic") + +// try await Task.sleep(for: .seconds(1)) +// try await client.disconnect() +// +// try await client.connect() +// try await Task.sleep(for: .seconds(1)) + try await client.publish( + to: "/some/topic", + payload: ByteBufferAllocator().buffer(string: "Foo"), + qos: .atLeastOnce, + retain: true ) + try await Task.sleep(for: .seconds(1)) + + listener.shutdown() + try await client.shutdown() } // func testConnectAndShutdown() async throws { @@ -234,10 +221,47 @@ final class SensorsClientTests: XCTestCase { // // await client.shutdown() // } + + func createClient(identifier: String) -> MQTTClient { + let envVars = EnvVars( + appEnv: .testing, + host: Self.hostname, + port: "1883", + identifier: identifier, + userName: nil, + password: nil + ) + let config = MQTTClient.Configuration( + version: .v3_1_1, + userName: envVars.userName, + password: envVars.password, + useSSL: false, + useWebSockets: false, + tlsConfiguration: nil, + webSocketURLPath: nil + ) + return .init( + host: Self.hostname, + identifier: identifier, + eventLoopGroupProvider: .shared(MultiThreadedEventLoopGroup(numberOfThreads: 1)), + logger: Self.logger, + configuration: config + ) + } } // MARK: Helpers for tests. +extension AsyncStream { + func first() async -> Element { + var first: Element + for await value in self { + first = value + } + return first + } +} + class PublishInfoContainer { private(set) var info: [MQTTPublishInfo] private var topicFilters: [String]?