Skip to content

Commit

Permalink
Improve VPN rekeying reliability (#664)
Browse files Browse the repository at this point in the history
Required:

Task/Issue URL: https://app.asana.com/0/414235014887631/1206607513978260/f
iOS PR: duckduckgo/iOS#2478
macOS PR: duckduckgo/macos-browser#2207
What kind of version bump will this require?: Major

Description:

This PR fixes three issues:

The rekeying implementation was deleting the old key, and then registering a new one - however, if the registration call failed, then it got left in a confused state with a connection that didn't work
The connection tester is what is responsible for triggering rekeying and the DAU pixel, but it was accidentally setting isRunning to false even when it was running just fine
The DAU event was sent only when the connection was idle, which isn't necessary
  • Loading branch information
samsymons authored Feb 18, 2024
1 parent 93a9a41 commit 39a0ed6
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ final class NetworkProtectionConnectionTester {
func stop() async {
os_log("🔴 Stopping connection tester", log: log)
await stopScheduledTimer()
isRunning = false
}

// MARK: - Obtaining the interface
Expand Down Expand Up @@ -201,7 +202,6 @@ final class NetworkProtectionConnectionTester {
}

private func stopScheduledTimer() async {
isRunning = false
cancelTimerImmediately()
}

Expand Down Expand Up @@ -242,7 +242,7 @@ final class NetworkProtectionConnectionTester {
let onlyVPNIsDown = simulateFailure || (!vpnIsConnected && localIsConnected)
simulateFailure = false

// After completing the conection tests we check if the tester is still supposed to be running
// After completing the connection tests we check if the tester is still supposed to be running
// to avoid giving results when it should not be running.
guard isRunning else {
os_log("Tester skipped returning results as it was stopped while running the tests", log: log, type: .info)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,19 @@ public protocol NetworkProtectionKeyStore {

/// Obtain the current `KeyPair`.
///
func currentKeyPair() -> KeyPair
func currentKeyPair() -> KeyPair?

/// Sets the validity interval for keys
/// Create a new `KeyPair`.
///
func newKeyPair() -> KeyPair

/// Sets the validity interval for keys
///
func setValidityInterval(_ validityInterval: TimeInterval?)

/// Updates the current `KeyPair` to have the specified expiration date
///
/// - Parameters:
/// - newExpirationDate: the new expiration date for the keypair
/// Updates the existing KeyPair.
///
/// - Returns: a new keypair with the specified updates
///
func updateCurrentKeyPair(newExpirationDate: Date) -> KeyPair
func updateKeyPair(_ newKeyPair: KeyPair)

/// Resets the current `KeyPair` so a new one will be generated when requested.
///
Expand Down Expand Up @@ -75,34 +73,38 @@ public final class NetworkProtectionKeychainKeyStore: NetworkProtectionKeyStore

// MARK: - NetworkProtectionKeyStore

public func currentKeyPair() -> KeyPair {
public func currentKeyPair() -> KeyPair? {
os_log("Querying the current key pair (publicKey: %{public}@, expirationDate: %{public}@)",
log: .networkProtectionKeyManagement,
String(describing: currentPublicKey),
String(describing: currentExpirationDate))

guard let currentPrivateKey = currentPrivateKey else {
let keyPair = newCurrentKeyPair()
os_log("Returning a new key pair as there's no current private key (newPublicKey: %{public}@)",
log: .networkProtectionKeyManagement,
String(describing: keyPair.publicKey.base64Key))
return keyPair
os_log("There's no current private key.",
log: .networkProtectionKeyManagement)
return nil
}

guard let currentExpirationDate = currentExpirationDate,
Date().addingTimeInterval(validityInterval) >= currentExpirationDate else {

let keyPair = newCurrentKeyPair()
os_log("Returning a new key pair as the expirationDate date is missing, or we're past it (now: %{public}@, expirationDate: %{public}@)",
os_log("The expirationDate date is missing, or we're past it (now: %{public}@, expirationDate: %{public}@)",
log: .networkProtectionKeyManagement,
String(describing: Date()),
String(describing: currentExpirationDate))
return keyPair
return nil
}

return KeyPair(privateKey: currentPrivateKey, expirationDate: currentExpirationDate)
}

public func newKeyPair() -> KeyPair {
let newPrivateKey = PrivateKey()
let newExpirationDate = Date().addingTimeInterval(validityInterval)

return KeyPair(privateKey: newPrivateKey, expirationDate: newExpirationDate)
}

private var validityInterval = Defaults.validityInterval

public func setValidityInterval(_ validityInterval: TimeInterval?) {
Expand All @@ -123,9 +125,14 @@ public final class NetworkProtectionKeychainKeyStore: NetworkProtectionKeyStore
return KeyPair(privateKey: currentPrivateKey, expirationDate: currentExpirationDate)
}

public func updateCurrentKeyPair(newExpirationDate: Date) -> KeyPair {
currentExpirationDate = newExpirationDate
return currentKeyPair()
public func updateKeyPair(_ newKeyPair: KeyPair) {
if currentPrivateKey != newKeyPair.privateKey {
self.currentPrivateKey = newKeyPair.privateKey
}

if currentExpirationDate != newKeyPair.expirationDate {
self.currentExpirationDate = newKeyPair.expirationDate
}
}

public func resetCurrentKeyPair() {
Expand Down
48 changes: 26 additions & 22 deletions Sources/NetworkProtection/NetworkProtectionDeviceManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ public protocol NetworkProtectionDeviceManagement {
func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod,
includedRoutes: [IPAddressRange],
excludedRoutes: [IPAddressRange],
isKillSwitchEnabled: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo)
isKillSwitchEnabled: Bool,
regenerateKey: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo)

}

Expand Down Expand Up @@ -124,9 +125,25 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement {
public func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod,
includedRoutes: [IPAddressRange],
excludedRoutes: [IPAddressRange],
isKillSwitchEnabled: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo) {
isKillSwitchEnabled: Bool,
regenerateKey: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo) {
var keyPair: KeyPair

if regenerateKey {
keyPair = keyStore.newKeyPair()
} else {
keyPair = keyStore.currentKeyPair() ?? keyStore.newKeyPair()
}

let (selectedServer, newExpiration) = try await register(keyPair: keyPair, selectionMethod: selectionMethod)
os_log("Server registration successul", log: .networkProtection)

let (selectedServer, keyPair) = try await register(selectionMethod: selectionMethod)
keyStore.updateKeyPair(keyPair)

if let newExpiration {
keyPair = KeyPair(privateKey: keyPair.privateKey, expirationDate: newExpiration)
keyStore.updateKeyPair(keyPair)
}

do {
let configuration = try tunnelConfiguration(interfacePrivateKey: keyPair.privateKey,
Expand All @@ -152,13 +169,11 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement {
// - keyPair: the key pair that was used to register with the server, and that should be used to configure the tunnel
//
// - Throws:`NetworkProtectionError`
// This cannot be a doc comment because of the swiftlint command below
// swiftlint:disable cyclomatic_complexity
private func register(selectionMethod: NetworkProtectionServerSelectionMethod) async throws -> (server: NetworkProtectionServer,
keyPair: KeyPair) {
private func register(keyPair: KeyPair,
selectionMethod: NetworkProtectionServerSelectionMethod) async throws -> (server: NetworkProtectionServer,
newExpiration: Date?) {

guard let token = try? tokenStore.fetchToken() else { throw NetworkProtectionError.noAuthTokenFound }
var keyPair = keyStore.currentKeyPair()

let serverSelection: RegisterServerSelection
let excludedServerName: String?
Expand Down Expand Up @@ -203,29 +218,18 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement {
errorEvents?.fire(NetworkProtectionError.serverListInconsistency)

let cachedServer = try cachedServer(registeredWith: keyPair)
return (cachedServer, keyPair)
return (cachedServer, nil)
}

selectedServer = registeredServer

// We should not need this IF condition here, because we know registered servers will give us an expiration date,
// but since the structure we're currently using makes the expiration date optional we need to have it.
// We should consider changing our server structure to not allow a missing expiration date here.
if let serverExpirationDate = selectedServer.expirationDate,
keyPair.expirationDate > serverExpirationDate {

keyPair = keyStore.updateCurrentKeyPair(newExpirationDate: serverExpirationDate)
}

return (selectedServer, keyPair)
return (selectedServer, selectedServer.expirationDate)
case .failure(let error):
handle(clientError: error)

let cachedServer = try cachedServer(registeredWith: keyPair)
return (cachedServer, keyPair)
return (cachedServer, nil)
}
}
// swiftlint:enable cyclomatic_complexity

/// Retrieves the first cached server that's registered with the specified key pair.
///
Expand Down
68 changes: 48 additions & 20 deletions Sources/NetworkProtection/PacketTunnelProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
// SPDX-License-Identifier: MIT
// Copyright © 2018-2021 WireGuard LLC. All Rights Reserved.

// swiftlint:disable file_length

import Combine
import Common
import Foundation
Expand Down Expand Up @@ -146,11 +148,18 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
}

private var isKeyExpired: Bool {
keyStore.currentKeyPair().expirationDate <= Date()
guard let currentExpirationDate = keyStore.currentExpirationDate else {
return true
}

return currentExpirationDate <= Date()
}

private func rekeyIfExpired() async {
os_log("Checking if rekey is necessary...", log: .networkProtectionKeyManagement)

guard isKeyExpired else {
os_log("The key is not expired", log: .networkProtectionKeyManagement)
return
}

Expand All @@ -162,16 +171,15 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {

// Experimental option to disable rekeying.
guard !settings.disableRekeying else {
os_log("Rekeying disabled", log: .networkProtectionKeyManagement)
return
}

os_log("Rekeying...", log: .networkProtectionKeyManagement)

providerEvents.fire(.rekeyCompleted)
self.resetRegistrationKey()

do {
try await updateTunnelConfiguration(reassert: false)
try await updateTunnelConfiguration(reassert: false, regenerateKey: true)
providerEvents.fire(.rekeyCompleted)
} catch {
os_log("Rekey attempt failed. This is not an error if you're using debug Key Management options: %{public}@", log: .networkProtectionKeyManagement, type: .error, String(describing: error))
}
Expand Down Expand Up @@ -203,13 +211,13 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
Task {
await updateBandwidthAnalyzer()

// This provides a more frequent active user pixel check
providerEvents.fire(.userBecameActive)

guard self.bandwidthAnalyzer.isConnectionIdle() else {
return
}

// This provides a more frequent active user pixel check
providerEvents.fire(.userBecameActive)

await rekeyIfExpired()
}
}
Expand Down Expand Up @@ -553,7 +561,8 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
let tunnelConfiguration = try await generateTunnelConfiguration(environment: environment,
serverSelectionMethod: currentServerSelectionMethod,
includedRoutes: includedRoutes ?? [],
excludedRoutes: settings.excludedRanges)
excludedRoutes: settings.excludedRanges,
regenerateKey: false)
startTunnel(with: tunnelConfiguration, onDemand: onDemand, completionHandler: completionHandler)
os_log("🔵 Done generating tunnel config", log: .networkProtection, type: .info)
} catch {
Expand Down Expand Up @@ -658,17 +667,26 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
// MARK: - Tunnel Configuration

@MainActor
public func updateTunnelConfiguration(reassert: Bool = true) async throws {
try await updateTunnelConfiguration(environment: settings.selectedEnvironment, serverSelectionMethod: currentServerSelectionMethod, reassert: reassert)
public func updateTunnelConfiguration(reassert: Bool = true, regenerateKey: Bool = false) async throws {
try await updateTunnelConfiguration(
environment: settings.selectedEnvironment,
serverSelectionMethod: currentServerSelectionMethod,
reassert: reassert,
regenerateKey: regenerateKey
)
}

@MainActor
public func updateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default, serverSelectionMethod: NetworkProtectionServerSelectionMethod, reassert: Bool = true) async throws {
public func updateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default,
serverSelectionMethod: NetworkProtectionServerSelectionMethod,
reassert: Bool = true,
regenerateKey: Bool = false) async throws {

let tunnelConfiguration = try await generateTunnelConfiguration(environment: environment,
serverSelectionMethod: serverSelectionMethod,
includedRoutes: includedRoutes ?? [],
excludedRoutes: settings.excludedRanges)
excludedRoutes: settings.excludedRanges,
regenerateKey: regenerateKey)

try await withCheckedThrowingContinuation { [weak self] (continuation: CheckedContinuation<Void, Error>) in
guard let self = self else {
Expand Down Expand Up @@ -699,7 +717,11 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
}

@MainActor
private func generateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default, serverSelectionMethod: NetworkProtectionServerSelectionMethod, includedRoutes: [IPAddressRange], excludedRoutes: [IPAddressRange]) async throws -> TunnelConfiguration {
private func generateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default,
serverSelectionMethod: NetworkProtectionServerSelectionMethod,
includedRoutes: [IPAddressRange],
excludedRoutes: [IPAddressRange],
regenerateKey: Bool) async throws -> TunnelConfiguration {

let configurationResult: (TunnelConfiguration, NetworkProtectionServerInfo)

Expand All @@ -710,7 +732,13 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
keyStore: keyStore,
errorEvents: debugEvents)

configurationResult = try await deviceManager.generateTunnelConfiguration(selectionMethod: serverSelectionMethod, includedRoutes: includedRoutes, excludedRoutes: excludedRoutes, isKillSwitchEnabled: isKillSwitchEnabled)
configurationResult = try await deviceManager.generateTunnelConfiguration(
selectionMethod: serverSelectionMethod,
includedRoutes: includedRoutes,
excludedRoutes: excludedRoutes,
isKillSwitchEnabled: isKillSwitchEnabled,
regenerateKey: regenerateKey
)
} catch {
throw TunnelError.couldNotGenerateTunnelConfiguration(internalError: error)
}
Expand All @@ -724,9 +752,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
selectedServerInfo.name)
os_log("🔵 Excluded routes: %{public}@", log: .networkProtection, type: .info, String(describing: excludedRoutes))

let tunnelConfiguration = configurationResult.0

return tunnelConfiguration
return configurationResult.0
}

// MARK: - App Messages
Expand All @@ -740,7 +766,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {

switch message {
case .request(let request):
handleRequest(request)
handleRequest(request, completionHandler: completionHandler)
case .expireRegistrationKey:
handleExpireRegistrationKey(completionHandler: completionHandler)
case .getLastErrorMessage:
Expand Down Expand Up @@ -846,7 +872,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
.setNetworkPathChange,
.setDisableRekeying:
// Intentional no-op, as some setting changes don't require any further operation
break
completionHandler?(nil)
}
}

Expand Down Expand Up @@ -1178,3 +1204,5 @@ extension WireGuardAdapterError: LocalizedError, CustomDebugStringConvertible {
}

}

// swiftlint:enable file_length
Loading

0 comments on commit 39a0ed6

Please sign in to comment.