diff --git a/Sources/NetworkProtection/Diagnostics/NetworkProtectionConnectionTester.swift b/Sources/NetworkProtection/Diagnostics/NetworkProtectionConnectionTester.swift index 23e89e4fa..162c3ef74 100644 --- a/Sources/NetworkProtection/Diagnostics/NetworkProtectionConnectionTester.swift +++ b/Sources/NetworkProtection/Diagnostics/NetworkProtectionConnectionTester.swift @@ -134,6 +134,7 @@ final class NetworkProtectionConnectionTester { func stop() async { os_log("🔴 Stopping connection tester", log: log) await stopScheduledTimer() + isRunning = false } // MARK: - Obtaining the interface @@ -201,7 +202,6 @@ final class NetworkProtectionConnectionTester { } private func stopScheduledTimer() async { - isRunning = false cancelTimerImmediately() } @@ -242,7 +242,7 @@ final class NetworkProtectionConnectionTester { let onlyVPNIsDown = simulateFailure || (!vpnIsConnected && localIsConnected) simulateFailure = false - // After completing the conection tests we check if the tester is still supposed to be running + // After completing the connection tests we check if the tester is still supposed to be running // to avoid giving results when it should not be running. guard isRunning else { os_log("Tester skipped returning results as it was stopped while running the tests", log: log, type: .info) diff --git a/Sources/NetworkProtection/KeyManagement/NetworkProtectionKeyStore.swift b/Sources/NetworkProtection/KeyManagement/NetworkProtectionKeyStore.swift index c63e9e937..93c77c09c 100644 --- a/Sources/NetworkProtection/KeyManagement/NetworkProtectionKeyStore.swift +++ b/Sources/NetworkProtection/KeyManagement/NetworkProtectionKeyStore.swift @@ -23,21 +23,19 @@ public protocol NetworkProtectionKeyStore { /// Obtain the current `KeyPair`. /// - func currentKeyPair() -> KeyPair + func currentKeyPair() -> KeyPair? - /// Sets the validity interval for keys + /// Create a new `KeyPair`. /// + func newKeyPair() -> KeyPair + /// Sets the validity interval for keys + /// func setValidityInterval(_ validityInterval: TimeInterval?) - /// Updates the current `KeyPair` to have the specified expiration date - /// - /// - Parameters: - /// - newExpirationDate: the new expiration date for the keypair + /// Updates the existing KeyPair. /// - /// - Returns: a new keypair with the specified updates - /// - func updateCurrentKeyPair(newExpirationDate: Date) -> KeyPair + func updateKeyPair(_ newKeyPair: KeyPair) /// Resets the current `KeyPair` so a new one will be generated when requested. /// @@ -75,34 +73,38 @@ public final class NetworkProtectionKeychainKeyStore: NetworkProtectionKeyStore // MARK: - NetworkProtectionKeyStore - public func currentKeyPair() -> KeyPair { + public func currentKeyPair() -> KeyPair? { os_log("Querying the current key pair (publicKey: %{public}@, expirationDate: %{public}@)", log: .networkProtectionKeyManagement, String(describing: currentPublicKey), String(describing: currentExpirationDate)) guard let currentPrivateKey = currentPrivateKey else { - let keyPair = newCurrentKeyPair() - os_log("Returning a new key pair as there's no current private key (newPublicKey: %{public}@)", - log: .networkProtectionKeyManagement, - String(describing: keyPair.publicKey.base64Key)) - return keyPair + os_log("There's no current private key.", + log: .networkProtectionKeyManagement) + return nil } guard let currentExpirationDate = currentExpirationDate, Date().addingTimeInterval(validityInterval) >= currentExpirationDate else { - let keyPair = newCurrentKeyPair() - os_log("Returning a new key pair as the expirationDate date is missing, or we're past it (now: %{public}@, expirationDate: %{public}@)", + os_log("The expirationDate date is missing, or we're past it (now: %{public}@, expirationDate: %{public}@)", log: .networkProtectionKeyManagement, String(describing: Date()), String(describing: currentExpirationDate)) - return keyPair + return nil } return KeyPair(privateKey: currentPrivateKey, expirationDate: currentExpirationDate) } + public func newKeyPair() -> KeyPair { + let newPrivateKey = PrivateKey() + let newExpirationDate = Date().addingTimeInterval(validityInterval) + + return KeyPair(privateKey: newPrivateKey, expirationDate: newExpirationDate) + } + private var validityInterval = Defaults.validityInterval public func setValidityInterval(_ validityInterval: TimeInterval?) { @@ -123,9 +125,14 @@ public final class NetworkProtectionKeychainKeyStore: NetworkProtectionKeyStore return KeyPair(privateKey: currentPrivateKey, expirationDate: currentExpirationDate) } - public func updateCurrentKeyPair(newExpirationDate: Date) -> KeyPair { - currentExpirationDate = newExpirationDate - return currentKeyPair() + public func updateKeyPair(_ newKeyPair: KeyPair) { + if currentPrivateKey != newKeyPair.privateKey { + self.currentPrivateKey = newKeyPair.privateKey + } + + if currentExpirationDate != newKeyPair.expirationDate { + self.currentExpirationDate = newKeyPair.expirationDate + } } public func resetCurrentKeyPair() { diff --git a/Sources/NetworkProtection/NetworkProtectionDeviceManager.swift b/Sources/NetworkProtection/NetworkProtectionDeviceManager.swift index 4ce374a9c..c36c78b59 100644 --- a/Sources/NetworkProtection/NetworkProtectionDeviceManager.swift +++ b/Sources/NetworkProtection/NetworkProtectionDeviceManager.swift @@ -45,7 +45,8 @@ public protocol NetworkProtectionDeviceManagement { func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod, includedRoutes: [IPAddressRange], excludedRoutes: [IPAddressRange], - isKillSwitchEnabled: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo) + isKillSwitchEnabled: Bool, + regenerateKey: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo) } @@ -124,9 +125,25 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement { public func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod, includedRoutes: [IPAddressRange], excludedRoutes: [IPAddressRange], - isKillSwitchEnabled: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo) { + isKillSwitchEnabled: Bool, + regenerateKey: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo) { + var keyPair: KeyPair + + if regenerateKey { + keyPair = keyStore.newKeyPair() + } else { + keyPair = keyStore.currentKeyPair() ?? keyStore.newKeyPair() + } + + let (selectedServer, newExpiration) = try await register(keyPair: keyPair, selectionMethod: selectionMethod) + os_log("Server registration successul", log: .networkProtection) - let (selectedServer, keyPair) = try await register(selectionMethod: selectionMethod) + keyStore.updateKeyPair(keyPair) + + if let newExpiration { + keyPair = KeyPair(privateKey: keyPair.privateKey, expirationDate: newExpiration) + keyStore.updateKeyPair(keyPair) + } do { let configuration = try tunnelConfiguration(interfacePrivateKey: keyPair.privateKey, @@ -152,13 +169,11 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement { // - keyPair: the key pair that was used to register with the server, and that should be used to configure the tunnel // // - Throws:`NetworkProtectionError` - // This cannot be a doc comment because of the swiftlint command below - // swiftlint:disable cyclomatic_complexity - private func register(selectionMethod: NetworkProtectionServerSelectionMethod) async throws -> (server: NetworkProtectionServer, - keyPair: KeyPair) { + private func register(keyPair: KeyPair, + selectionMethod: NetworkProtectionServerSelectionMethod) async throws -> (server: NetworkProtectionServer, + newExpiration: Date?) { guard let token = try? tokenStore.fetchToken() else { throw NetworkProtectionError.noAuthTokenFound } - var keyPair = keyStore.currentKeyPair() let serverSelection: RegisterServerSelection let excludedServerName: String? @@ -203,29 +218,18 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement { errorEvents?.fire(NetworkProtectionError.serverListInconsistency) let cachedServer = try cachedServer(registeredWith: keyPair) - return (cachedServer, keyPair) + return (cachedServer, nil) } selectedServer = registeredServer - - // We should not need this IF condition here, because we know registered servers will give us an expiration date, - // but since the structure we're currently using makes the expiration date optional we need to have it. - // We should consider changing our server structure to not allow a missing expiration date here. - if let serverExpirationDate = selectedServer.expirationDate, - keyPair.expirationDate > serverExpirationDate { - - keyPair = keyStore.updateCurrentKeyPair(newExpirationDate: serverExpirationDate) - } - - return (selectedServer, keyPair) + return (selectedServer, selectedServer.expirationDate) case .failure(let error): handle(clientError: error) let cachedServer = try cachedServer(registeredWith: keyPair) - return (cachedServer, keyPair) + return (cachedServer, nil) } } - // swiftlint:enable cyclomatic_complexity /// Retrieves the first cached server that's registered with the specified key pair. /// diff --git a/Sources/NetworkProtection/PacketTunnelProvider.swift b/Sources/NetworkProtection/PacketTunnelProvider.swift index b4a325e87..67bc888e0 100644 --- a/Sources/NetworkProtection/PacketTunnelProvider.swift +++ b/Sources/NetworkProtection/PacketTunnelProvider.swift @@ -19,6 +19,8 @@ // SPDX-License-Identifier: MIT // Copyright © 2018-2021 WireGuard LLC. All Rights Reserved. +// swiftlint:disable file_length + import Combine import Common import Foundation @@ -146,11 +148,18 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { } private var isKeyExpired: Bool { - keyStore.currentKeyPair().expirationDate <= Date() + guard let currentExpirationDate = keyStore.currentExpirationDate else { + return true + } + + return currentExpirationDate <= Date() } private func rekeyIfExpired() async { + os_log("Checking if rekey is necessary...", log: .networkProtectionKeyManagement) + guard isKeyExpired else { + os_log("The key is not expired", log: .networkProtectionKeyManagement) return } @@ -162,16 +171,15 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { // Experimental option to disable rekeying. guard !settings.disableRekeying else { + os_log("Rekeying disabled", log: .networkProtectionKeyManagement) return } os_log("Rekeying...", log: .networkProtectionKeyManagement) - providerEvents.fire(.rekeyCompleted) - self.resetRegistrationKey() - do { - try await updateTunnelConfiguration(reassert: false) + try await updateTunnelConfiguration(reassert: false, regenerateKey: true) + providerEvents.fire(.rekeyCompleted) } catch { os_log("Rekey attempt failed. This is not an error if you're using debug Key Management options: %{public}@", log: .networkProtectionKeyManagement, type: .error, String(describing: error)) } @@ -203,13 +211,13 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { Task { await updateBandwidthAnalyzer() + // This provides a more frequent active user pixel check + providerEvents.fire(.userBecameActive) + guard self.bandwidthAnalyzer.isConnectionIdle() else { return } - // This provides a more frequent active user pixel check - providerEvents.fire(.userBecameActive) - await rekeyIfExpired() } } @@ -553,7 +561,8 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { let tunnelConfiguration = try await generateTunnelConfiguration(environment: environment, serverSelectionMethod: currentServerSelectionMethod, includedRoutes: includedRoutes ?? [], - excludedRoutes: settings.excludedRanges) + excludedRoutes: settings.excludedRanges, + regenerateKey: false) startTunnel(with: tunnelConfiguration, onDemand: onDemand, completionHandler: completionHandler) os_log("🔵 Done generating tunnel config", log: .networkProtection, type: .info) } catch { @@ -658,17 +667,26 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { // MARK: - Tunnel Configuration @MainActor - public func updateTunnelConfiguration(reassert: Bool = true) async throws { - try await updateTunnelConfiguration(environment: settings.selectedEnvironment, serverSelectionMethod: currentServerSelectionMethod, reassert: reassert) + public func updateTunnelConfiguration(reassert: Bool = true, regenerateKey: Bool = false) async throws { + try await updateTunnelConfiguration( + environment: settings.selectedEnvironment, + serverSelectionMethod: currentServerSelectionMethod, + reassert: reassert, + regenerateKey: regenerateKey + ) } @MainActor - public func updateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default, serverSelectionMethod: NetworkProtectionServerSelectionMethod, reassert: Bool = true) async throws { + public func updateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default, + serverSelectionMethod: NetworkProtectionServerSelectionMethod, + reassert: Bool = true, + regenerateKey: Bool = false) async throws { let tunnelConfiguration = try await generateTunnelConfiguration(environment: environment, serverSelectionMethod: serverSelectionMethod, includedRoutes: includedRoutes ?? [], - excludedRoutes: settings.excludedRanges) + excludedRoutes: settings.excludedRanges, + regenerateKey: regenerateKey) try await withCheckedThrowingContinuation { [weak self] (continuation: CheckedContinuation) in guard let self = self else { @@ -699,7 +717,11 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { } @MainActor - private func generateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default, serverSelectionMethod: NetworkProtectionServerSelectionMethod, includedRoutes: [IPAddressRange], excludedRoutes: [IPAddressRange]) async throws -> TunnelConfiguration { + private func generateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default, + serverSelectionMethod: NetworkProtectionServerSelectionMethod, + includedRoutes: [IPAddressRange], + excludedRoutes: [IPAddressRange], + regenerateKey: Bool) async throws -> TunnelConfiguration { let configurationResult: (TunnelConfiguration, NetworkProtectionServerInfo) @@ -710,7 +732,13 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { keyStore: keyStore, errorEvents: debugEvents) - configurationResult = try await deviceManager.generateTunnelConfiguration(selectionMethod: serverSelectionMethod, includedRoutes: includedRoutes, excludedRoutes: excludedRoutes, isKillSwitchEnabled: isKillSwitchEnabled) + configurationResult = try await deviceManager.generateTunnelConfiguration( + selectionMethod: serverSelectionMethod, + includedRoutes: includedRoutes, + excludedRoutes: excludedRoutes, + isKillSwitchEnabled: isKillSwitchEnabled, + regenerateKey: regenerateKey + ) } catch { throw TunnelError.couldNotGenerateTunnelConfiguration(internalError: error) } @@ -724,9 +752,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { selectedServerInfo.name) os_log("🔵 Excluded routes: %{public}@", log: .networkProtection, type: .info, String(describing: excludedRoutes)) - let tunnelConfiguration = configurationResult.0 - - return tunnelConfiguration + return configurationResult.0 } // MARK: - App Messages @@ -740,7 +766,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { switch message { case .request(let request): - handleRequest(request) + handleRequest(request, completionHandler: completionHandler) case .expireRegistrationKey: handleExpireRegistrationKey(completionHandler: completionHandler) case .getLastErrorMessage: @@ -846,7 +872,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider { .setNetworkPathChange, .setDisableRekeying: // Intentional no-op, as some setting changes don't require any further operation - break + completionHandler?(nil) } } @@ -1178,3 +1204,5 @@ extension WireGuardAdapterError: LocalizedError, CustomDebugStringConvertible { } } + +// swiftlint:enable file_length diff --git a/Tests/NetworkProtectionTests/Mocks/NetworkProtectionKeyStoreMocks.swift b/Tests/NetworkProtectionTests/Mocks/NetworkProtectionKeyStoreMocks.swift index fb0d9d540..a1ba18466 100644 --- a/Tests/NetworkProtectionTests/Mocks/NetworkProtectionKeyStoreMocks.swift +++ b/Tests/NetworkProtectionTests/Mocks/NetworkProtectionKeyStoreMocks.swift @@ -26,17 +26,19 @@ final class NetworkProtectionKeyStoreMock: NetworkProtectionKeyStore { // MARK: - NetworkProtectionKeyStore - func currentKeyPair() -> NetworkProtection.KeyPair { - if let keyPair = self.keyPair { - return keyPair - } else { - let keyPair = KeyPair(privateKey: PrivateKey(), expirationDate: Date().addingTimeInterval(.day)) - self.keyPair = keyPair - return keyPair - } + func currentKeyPair() -> NetworkProtection.KeyPair? { + keyPair } - func updateCurrentKeyPair(newExpirationDate: Date) -> NetworkProtection.KeyPair { + func newKeyPair() -> NetworkProtection.KeyPair { + return KeyPair(privateKey: PrivateKey(), expirationDate: Date().addingTimeInterval(.day)) + } + + public func updateKeyPair(_ newKeyPair: KeyPair) { + self.keyPair = newKeyPair + } + + func updateKeyPairExpirationDate(_ newExpirationDate: Date) -> NetworkProtection.KeyPair { let keyPair = KeyPair(privateKey: keyPair?.privateKey ?? PrivateKey(), expirationDate: newExpirationDate) self.keyPair = keyPair return keyPair diff --git a/Tests/NetworkProtectionTests/Mocks/NetworkProtectionServerMocks.swift b/Tests/NetworkProtectionTests/Mocks/NetworkProtectionServerMocks.swift index a01043439..3fdd9056c 100644 --- a/Tests/NetworkProtectionTests/Mocks/NetworkProtectionServerMocks.swift +++ b/Tests/NetworkProtectionTests/Mocks/NetworkProtectionServerMocks.swift @@ -67,7 +67,7 @@ extension NetworkProtectionServerInfo { extension NetworkProtectionServer { - static let mockBaseServer = NetworkProtectionServer(registeredPublicKey: nil, allowedIPs: nil, serverInfo: .mock, expirationDate: nil) + static let mockBaseServer = NetworkProtectionServer(registeredPublicKey: nil, allowedIPs: nil, serverInfo: .mock, expirationDate: Date()) static let mockRegisteredServer = NetworkProtectionServer(registeredPublicKey: "ovn9RpzUuvQ4XLQt6B3RKuEXGIxa5QpTnehjduZlcSE=", allowedIPs: ["0.0.0.0/0", "::/0"], serverInfo: .mock, diff --git a/Tests/NetworkProtectionTests/NetworkProtectionDeviceManagerTests.swift b/Tests/NetworkProtectionTests/NetworkProtectionDeviceManagerTests.swift index d3e86f713..0340e284a 100644 --- a/Tests/NetworkProtectionTests/NetworkProtectionDeviceManagerTests.swift +++ b/Tests/NetworkProtectionTests/NetworkProtectionDeviceManagerTests.swift @@ -63,7 +63,7 @@ final class NetworkProtectionDeviceManagerTests: XCTestCase { let configuration: (TunnelConfiguration, NetworkProtectionServerInfo) do { - configuration = try await manager.generateTunnelConfiguration(selectionMethod: .automatic) + configuration = try await manager.generateTunnelConfiguration(selectionMethod: .automatic, regenerateKey: false) } catch { XCTFail("Unexpected error \(error.localizedDescription)") return @@ -89,7 +89,7 @@ final class NetworkProtectionDeviceManagerTests: XCTestCase { XCTAssertEqual(try? serverListStore.storedNetworkProtectionServerList(), []) XCTAssertNil(networkClient.spyRegister) - _ = try? await manager.generateTunnelConfiguration(selectionMethod: .automatic) + _ = try? await manager.generateTunnelConfiguration(selectionMethod: .automatic, regenerateKey: false) XCTAssertNotNil(try? keyStore.storedPrivateKey()) XCTAssertEqual(try? serverListStore.storedNetworkProtectionServerList(), [registeredServer]) @@ -101,7 +101,7 @@ final class NetworkProtectionDeviceManagerTests: XCTestCase { networkClient.stubRegister = .success([server]) let preferredLocation = NetworkProtectionSelectedLocation(country: "Some country", city: "Some city") - _ = try? await manager.generateTunnelConfiguration(selectionMethod: .preferredLocation(preferredLocation)) + _ = try? await manager.generateTunnelConfiguration(selectionMethod: .preferredLocation(preferredLocation), regenerateKey: false) XCTAssertEqual(networkClient.spyRegister?.requestBody.city, preferredLocation.city) XCTAssertEqual(networkClient.spyRegister?.requestBody.country, preferredLocation.country) @@ -111,7 +111,7 @@ final class NetworkProtectionDeviceManagerTests: XCTestCase { let server = NetworkProtectionServer.mockBaseServer networkClient.stubRegister = .success([server]) - _ = try? await manager.generateTunnelConfiguration(selectionMethod: .preferredServer(serverName: server.serverName)) + _ = try? await manager.generateTunnelConfiguration(selectionMethod: .preferredServer(serverName: server.serverName), regenerateKey: false) XCTAssertEqual(networkClient.spyRegister?.requestBody.server, server.serverName) } @@ -122,7 +122,7 @@ final class NetworkProtectionDeviceManagerTests: XCTestCase { XCTAssertNotNil(tokenStore.token) - _ = try? await manager.generateTunnelConfiguration(selectionMethod: .automatic) + _ = try? await manager.generateTunnelConfiguration(selectionMethod: .automatic, regenerateKey: false) XCTAssertNil(tokenStore.token) } @@ -132,7 +132,7 @@ final class NetworkProtectionDeviceManagerTests: XCTestCase { XCTAssertNotNil(tokenStore.token) - _ = try? await manager.generateTunnelConfiguration(selectionMethod: .automatic) + _ = try? await manager.generateTunnelConfiguration(selectionMethod: .automatic, regenerateKey: false) XCTAssertNil(tokenStore.token) } @@ -145,12 +145,96 @@ final class NetworkProtectionDeviceManagerTests: XCTestCase { XCTAssertEqual(servers2.count, 6) } + func testWhenGeneratingTunnelConfiguration_AndKeyIsStillValid_AndKeyIsNotRegenerated_ThenKeyDoesNotChange() async { + let server = NetworkProtectionServer.mockBaseServer + let registeredServer = NetworkProtectionServer.mockRegisteredServer + + networkClient.stubGetServers = .success([server]) + networkClient.stubRegister = .success([registeredServer]) + + XCTAssertNil(try? keyStore.storedPrivateKey()) + XCTAssertEqual(try? serverListStore.storedNetworkProtectionServerList(), []) + XCTAssertNil(networkClient.spyRegister) + _ = try? await manager.generateTunnelConfiguration(selectionMethod: .automatic, regenerateKey: false) + + let firstKey = try? keyStore.storedPrivateKey() + XCTAssertNotNil(firstKey) + XCTAssertEqual(try? serverListStore.storedNetworkProtectionServerList(), [registeredServer]) + XCTAssertNotNil(networkClient.spyRegister) + _ = try? await manager.generateTunnelConfiguration(selectionMethod: .automatic, regenerateKey: false) + + let secondKey = try? keyStore.storedPrivateKey() + XCTAssertNotNil(secondKey) + XCTAssertEqual(firstKey, secondKey) // Check that the key did NOT change + XCTAssertEqual(try? serverListStore.storedNetworkProtectionServerList(), [registeredServer]) + XCTAssertNotNil(networkClient.spyRegister) + } + + func testWhenGeneratingTunnelConfiguration_AndKeyIsStillValid_AndKeyIsRegenerated_ThenKeyChanges() async { + let server = NetworkProtectionServer.mockBaseServer + let registeredServer = NetworkProtectionServer.mockRegisteredServer + + networkClient.stubGetServers = .success([server]) + networkClient.stubRegister = .success([registeredServer]) + + XCTAssertNil(try? keyStore.storedPrivateKey()) + XCTAssertEqual(try? serverListStore.storedNetworkProtectionServerList(), []) + XCTAssertNil(networkClient.spyRegister) + _ = try? await manager.generateTunnelConfiguration(selectionMethod: .automatic, regenerateKey: false) + + let firstKey = try? keyStore.storedPrivateKey() + XCTAssertNotNil(firstKey) + XCTAssertEqual(try? serverListStore.storedNetworkProtectionServerList(), [registeredServer]) + XCTAssertNotNil(networkClient.spyRegister) + _ = try? await manager.generateTunnelConfiguration(selectionMethod: .automatic, regenerateKey: true) + + let secondKey = try? keyStore.storedPrivateKey() + XCTAssertNotNil(secondKey) + XCTAssertNotEqual(firstKey, secondKey) // Check that the key changed + XCTAssertEqual(try? serverListStore.storedNetworkProtectionServerList(), [registeredServer]) + XCTAssertNotNil(networkClient.spyRegister) + } + + func testWhenGeneratingTunnelConfiguration_AndKeyIsStillValid_AndKeyIsRegenerated_AndRegistrationFails_ThenKeyDoesNotChange() async { + let server = NetworkProtectionServer.mockBaseServer + let registeredServer = NetworkProtectionServer.mockRegisteredServer + + networkClient.stubGetServers = .success([server]) + networkClient.stubRegister = .success([registeredServer]) + + XCTAssertNil(try? keyStore.storedPrivateKey()) + XCTAssertEqual(try? serverListStore.storedNetworkProtectionServerList(), []) + XCTAssertNil(networkClient.spyRegister) + _ = try? await manager.generateTunnelConfiguration(selectionMethod: .automatic, regenerateKey: false) + + let firstKey = try? keyStore.storedPrivateKey() + XCTAssertNotNil(firstKey) + XCTAssertEqual(try? serverListStore.storedNetworkProtectionServerList(), [registeredServer]) + XCTAssertNotNil(networkClient.spyRegister) + + networkClient.stubRegister = .failure(.failedToEncodeRegisterKeyRequest) + _ = try? await manager.generateTunnelConfiguration(selectionMethod: .automatic, regenerateKey: true) + + let secondKey = try? keyStore.storedPrivateKey() + XCTAssertNotNil(secondKey) + XCTAssertEqual(firstKey, secondKey) // Check that the key did NOT change, even though we tried to regenerate it + XCTAssertEqual(try? serverListStore.storedNetworkProtectionServerList(), [registeredServer]) + XCTAssertNotNil(networkClient.spyRegister) + } + } extension NetworkProtectionDeviceManager { - func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo) { - try await generateTunnelConfiguration(selectionMethod: selectionMethod, includedRoutes: [], excludedRoutes: [], isKillSwitchEnabled: false) + func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod, + regenerateKey: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo) { + try await generateTunnelConfiguration( + selectionMethod: selectionMethod, + includedRoutes: [], + excludedRoutes: [], + isKillSwitchEnabled: false, + regenerateKey: regenerateKey + ) } }