Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve VPN rekeying reliability #664

Merged
merged 12 commits into from
Feb 18, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ final class NetworkProtectionConnectionTester {
func stop() async {
os_log("🔴 Stopping connection tester", log: log)
await stopScheduledTimer()
isRunning = false
}

// MARK: - Obtaining the interface
Expand Down Expand Up @@ -201,7 +202,6 @@ final class NetworkProtectionConnectionTester {
}

private func stopScheduledTimer() async {
isRunning = false
cancelTimerImmediately()
}

Expand Down Expand Up @@ -242,7 +242,7 @@ final class NetworkProtectionConnectionTester {
let onlyVPNIsDown = simulateFailure || (!vpnIsConnected && localIsConnected)
simulateFailure = false

// After completing the conection tests we check if the tester is still supposed to be running
// After completing the connection tests we check if the tester is still supposed to be running
// to avoid giving results when it should not be running.
guard isRunning else {
os_log("Tester skipped returning results as it was stopped while running the tests", log: log, type: .info)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,19 @@ public protocol NetworkProtectionKeyStore {

/// Obtain the current `KeyPair`.
///
func currentKeyPair() -> KeyPair
func currentKeyPair() -> KeyPair?

/// Sets the validity interval for keys
/// Create a new `KeyPair`.
///
func newKeyPair() -> KeyPair

/// Sets the validity interval for keys
///
func setValidityInterval(_ validityInterval: TimeInterval?)

/// Updates the current `KeyPair` to have the specified expiration date
///
/// - Parameters:
/// - newExpirationDate: the new expiration date for the keypair
/// Updates the existing KeyPair.
///
/// - Returns: a new keypair with the specified updates
///
func updateCurrentKeyPair(newExpirationDate: Date) -> KeyPair
func updateKeyPair(_ newKeyPair: KeyPair)

/// Resets the current `KeyPair` so a new one will be generated when requested.
///
Expand Down Expand Up @@ -75,34 +73,38 @@ public final class NetworkProtectionKeychainKeyStore: NetworkProtectionKeyStore

// MARK: - NetworkProtectionKeyStore

public func currentKeyPair() -> KeyPair {
public func currentKeyPair() -> KeyPair? {
os_log("Querying the current key pair (publicKey: %{public}@, expirationDate: %{public}@)",
log: .networkProtectionKeyManagement,
String(describing: currentPublicKey),
String(describing: currentExpirationDate))

guard let currentPrivateKey = currentPrivateKey else {
let keyPair = newCurrentKeyPair()
os_log("Returning a new key pair as there's no current private key (newPublicKey: %{public}@)",
log: .networkProtectionKeyManagement,
String(describing: keyPair.publicKey.base64Key))
return keyPair
os_log("There's no current private key.",
log: .networkProtectionKeyManagement)
return nil
}

guard let currentExpirationDate = currentExpirationDate,
Date().addingTimeInterval(validityInterval) >= currentExpirationDate else {

let keyPair = newCurrentKeyPair()
os_log("Returning a new key pair as the expirationDate date is missing, or we're past it (now: %{public}@, expirationDate: %{public}@)",
os_log("The expirationDate date is missing, or we're past it (now: %{public}@, expirationDate: %{public}@)",
log: .networkProtectionKeyManagement,
String(describing: Date()),
String(describing: currentExpirationDate))
return keyPair
return nil
}

return KeyPair(privateKey: currentPrivateKey, expirationDate: currentExpirationDate)
}

public func newKeyPair() -> KeyPair {
let newPrivateKey = PrivateKey()
let newExpirationDate = Date().addingTimeInterval(validityInterval)

return KeyPair(privateKey: newPrivateKey, expirationDate: newExpirationDate)
}

private var validityInterval = Defaults.validityInterval

public func setValidityInterval(_ validityInterval: TimeInterval?) {
Expand All @@ -123,9 +125,14 @@ public final class NetworkProtectionKeychainKeyStore: NetworkProtectionKeyStore
return KeyPair(privateKey: currentPrivateKey, expirationDate: currentExpirationDate)
}

public func updateCurrentKeyPair(newExpirationDate: Date) -> KeyPair {
currentExpirationDate = newExpirationDate
return currentKeyPair()
public func updateKeyPair(_ newKeyPair: KeyPair) {
if currentPrivateKey != newKeyPair.privateKey {
self.currentPrivateKey = newKeyPair.privateKey
}

if currentExpirationDate != newKeyPair.expirationDate {
self.currentExpirationDate = newKeyPair.expirationDate
}
}

public func resetCurrentKeyPair() {
Expand Down
48 changes: 26 additions & 22 deletions Sources/NetworkProtection/NetworkProtectionDeviceManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ public protocol NetworkProtectionDeviceManagement {
func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod,
includedRoutes: [IPAddressRange],
excludedRoutes: [IPAddressRange],
isKillSwitchEnabled: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo)
isKillSwitchEnabled: Bool,
regenerateKey: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo)

}

Expand Down Expand Up @@ -124,9 +125,25 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement {
public func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod,
includedRoutes: [IPAddressRange],
excludedRoutes: [IPAddressRange],
isKillSwitchEnabled: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo) {
isKillSwitchEnabled: Bool,
regenerateKey: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo) {
var keyPair: KeyPair

if regenerateKey {
keyPair = keyStore.newKeyPair()
} else {
keyPair = keyStore.currentKeyPair() ?? keyStore.newKeyPair()
}

let (selectedServer, newExpiration) = try await register(keyPair: keyPair, selectionMethod: selectionMethod)
os_log("Server registration successul", log: .networkProtection)

let (selectedServer, keyPair) = try await register(selectionMethod: selectionMethod)
keyStore.updateKeyPair(keyPair)

if let newExpiration {
keyPair = KeyPair(privateKey: keyPair.privateKey, expirationDate: newExpiration)
keyStore.updateKeyPair(keyPair)
}

do {
let configuration = try tunnelConfiguration(interfacePrivateKey: keyPair.privateKey,
Expand All @@ -152,13 +169,11 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement {
// - keyPair: the key pair that was used to register with the server, and that should be used to configure the tunnel
//
// - Throws:`NetworkProtectionError`
// This cannot be a doc comment because of the swiftlint command below
// swiftlint:disable cyclomatic_complexity
private func register(selectionMethod: NetworkProtectionServerSelectionMethod) async throws -> (server: NetworkProtectionServer,
keyPair: KeyPair) {
private func register(keyPair: KeyPair,
selectionMethod: NetworkProtectionServerSelectionMethod) async throws -> (server: NetworkProtectionServer,
newExpiration: Date?) {

guard let token = try? tokenStore.fetchToken() else { throw NetworkProtectionError.noAuthTokenFound }
var keyPair = keyStore.currentKeyPair()

let serverSelection: RegisterServerSelection
let excludedServerName: String?
Expand Down Expand Up @@ -203,29 +218,18 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement {
errorEvents?.fire(NetworkProtectionError.serverListInconsistency)

let cachedServer = try cachedServer(registeredWith: keyPair)
return (cachedServer, keyPair)
return (cachedServer, nil)
}

selectedServer = registeredServer

// We should not need this IF condition here, because we know registered servers will give us an expiration date,
// but since the structure we're currently using makes the expiration date optional we need to have it.
// We should consider changing our server structure to not allow a missing expiration date here.
if let serverExpirationDate = selectedServer.expirationDate,
keyPair.expirationDate > serverExpirationDate {

keyPair = keyStore.updateCurrentKeyPair(newExpirationDate: serverExpirationDate)
}

return (selectedServer, keyPair)
return (selectedServer, selectedServer.expirationDate)
samsymons marked this conversation as resolved.
Show resolved Hide resolved
case .failure(let error):
handle(clientError: error)

let cachedServer = try cachedServer(registeredWith: keyPair)
return (cachedServer, keyPair)
return (cachedServer, nil)
}
}
// swiftlint:enable cyclomatic_complexity

/// Retrieves the first cached server that's registered with the specified key pair.
///
Expand Down
68 changes: 48 additions & 20 deletions Sources/NetworkProtection/PacketTunnelProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
// SPDX-License-Identifier: MIT
// Copyright © 2018-2021 WireGuard LLC. All Rights Reserved.

// swiftlint:disable file_length

import Combine
import Common
import Foundation
Expand Down Expand Up @@ -146,11 +148,18 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
}

private var isKeyExpired: Bool {
keyStore.currentKeyPair().expirationDate <= Date()
guard let currentExpirationDate = keyStore.currentExpirationDate else {
return true
}

return currentExpirationDate <= Date()
}

private func rekeyIfExpired() async {
os_log("Checking if rekey is necessary...", log: .networkProtectionKeyManagement)

guard isKeyExpired else {
os_log("The key is not expired", log: .networkProtectionKeyManagement)
return
}

Expand All @@ -162,16 +171,15 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {

// Experimental option to disable rekeying.
guard !settings.disableRekeying else {
os_log("Rekeying disabled", log: .networkProtectionKeyManagement)
return
}

os_log("Rekeying...", log: .networkProtectionKeyManagement)

providerEvents.fire(.rekeyCompleted)
samsymons marked this conversation as resolved.
Show resolved Hide resolved
self.resetRegistrationKey()

do {
try await updateTunnelConfiguration(reassert: false)
try await updateTunnelConfiguration(reassert: false, regenerateKey: true)
providerEvents.fire(.rekeyCompleted)
} catch {
os_log("Rekey attempt failed. This is not an error if you're using debug Key Management options: %{public}@", log: .networkProtectionKeyManagement, type: .error, String(describing: error))
}
Expand Down Expand Up @@ -203,13 +211,13 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
Task {
await updateBandwidthAnalyzer()

// This provides a more frequent active user pixel check
providerEvents.fire(.userBecameActive)
samsymons marked this conversation as resolved.
Show resolved Hide resolved

guard self.bandwidthAnalyzer.isConnectionIdle() else {
return
}

// This provides a more frequent active user pixel check
providerEvents.fire(.userBecameActive)

await rekeyIfExpired()
}
}
Expand Down Expand Up @@ -553,7 +561,8 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
let tunnelConfiguration = try await generateTunnelConfiguration(environment: environment,
serverSelectionMethod: currentServerSelectionMethod,
includedRoutes: includedRoutes ?? [],
excludedRoutes: settings.excludedRanges)
excludedRoutes: settings.excludedRanges,
regenerateKey: false)
startTunnel(with: tunnelConfiguration, onDemand: onDemand, completionHandler: completionHandler)
os_log("🔵 Done generating tunnel config", log: .networkProtection, type: .info)
} catch {
Expand Down Expand Up @@ -658,17 +667,26 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
// MARK: - Tunnel Configuration

@MainActor
public func updateTunnelConfiguration(reassert: Bool = true) async throws {
try await updateTunnelConfiguration(environment: settings.selectedEnvironment, serverSelectionMethod: currentServerSelectionMethod, reassert: reassert)
public func updateTunnelConfiguration(reassert: Bool = true, regenerateKey: Bool = false) async throws {
try await updateTunnelConfiguration(
environment: settings.selectedEnvironment,
serverSelectionMethod: currentServerSelectionMethod,
reassert: reassert,
regenerateKey: regenerateKey
)
}

@MainActor
public func updateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default, serverSelectionMethod: NetworkProtectionServerSelectionMethod, reassert: Bool = true) async throws {
public func updateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default,
serverSelectionMethod: NetworkProtectionServerSelectionMethod,
reassert: Bool = true,
regenerateKey: Bool = false) async throws {

let tunnelConfiguration = try await generateTunnelConfiguration(environment: environment,
serverSelectionMethod: serverSelectionMethod,
includedRoutes: includedRoutes ?? [],
excludedRoutes: settings.excludedRanges)
excludedRoutes: settings.excludedRanges,
regenerateKey: regenerateKey)

try await withCheckedThrowingContinuation { [weak self] (continuation: CheckedContinuation<Void, Error>) in
guard let self = self else {
Expand Down Expand Up @@ -699,7 +717,11 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
}

@MainActor
private func generateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default, serverSelectionMethod: NetworkProtectionServerSelectionMethod, includedRoutes: [IPAddressRange], excludedRoutes: [IPAddressRange]) async throws -> TunnelConfiguration {
private func generateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default,
serverSelectionMethod: NetworkProtectionServerSelectionMethod,
includedRoutes: [IPAddressRange],
excludedRoutes: [IPAddressRange],
regenerateKey: Bool) async throws -> TunnelConfiguration {

let configurationResult: (TunnelConfiguration, NetworkProtectionServerInfo)

Expand All @@ -710,7 +732,13 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
keyStore: keyStore,
errorEvents: debugEvents)

configurationResult = try await deviceManager.generateTunnelConfiguration(selectionMethod: serverSelectionMethod, includedRoutes: includedRoutes, excludedRoutes: excludedRoutes, isKillSwitchEnabled: isKillSwitchEnabled)
configurationResult = try await deviceManager.generateTunnelConfiguration(
selectionMethod: serverSelectionMethod,
includedRoutes: includedRoutes,
excludedRoutes: excludedRoutes,
isKillSwitchEnabled: isKillSwitchEnabled,
regenerateKey: regenerateKey
)
} catch {
throw TunnelError.couldNotGenerateTunnelConfiguration(internalError: error)
}
Expand All @@ -724,9 +752,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
selectedServerInfo.name)
os_log("🔵 Excluded routes: %{public}@", log: .networkProtection, type: .info, String(describing: excludedRoutes))

let tunnelConfiguration = configurationResult.0

return tunnelConfiguration
return configurationResult.0
}

// MARK: - App Messages
Expand All @@ -740,7 +766,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {

switch message {
case .request(let request):
handleRequest(request)
handleRequest(request, completionHandler: completionHandler)
case .expireRegistrationKey:
handleExpireRegistrationKey(completionHandler: completionHandler)
case .getLastErrorMessage:
Expand Down Expand Up @@ -846,7 +872,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
.setNetworkPathChange,
.setDisableRekeying:
// Intentional no-op, as some setting changes don't require any further operation
break
completionHandler?(nil)
}
}

Expand Down Expand Up @@ -1178,3 +1204,5 @@ extension WireGuardAdapterError: LocalizedError, CustomDebugStringConvertible {
}

}

// swiftlint:enable file_length
Loading
Loading