Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve VPN rekeying reliability #664

Merged
merged 12 commits into from
Feb 18, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ final class NetworkProtectionConnectionTester {
self?.timer = nil
}

isRunning = true
samsymons marked this conversation as resolved.
Show resolved Hide resolved
samsymons marked this conversation as resolved.
Show resolved Hide resolved
timer.resume()
}

Expand Down Expand Up @@ -242,7 +243,7 @@ final class NetworkProtectionConnectionTester {
let onlyVPNIsDown = simulateFailure || (!vpnIsConnected && localIsConnected)
simulateFailure = false

// After completing the conection tests we check if the tester is still supposed to be running
// After completing the connection tests we check if the tester is still supposed to be running
// to avoid giving results when it should not be running.
guard isRunning else {
os_log("Tester skipped returning results as it was stopped while running the tests", log: log, type: .info)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,26 @@ public protocol NetworkProtectionKeyStore {
///
func currentKeyPair() -> KeyPair

/// Sets the validity interval for keys
/// Create a new `KeyPair`.
///
func newKeyPair() -> KeyPair

/// Sets the validity interval for keys
///
func setValidityInterval(_ validityInterval: TimeInterval?)

/// Updates the existing KeyPair.
///
func updateKeyPair(_ newKeyPair: KeyPair)

/// Updates the current `KeyPair` to have the specified expiration date
///
/// - Parameters:
/// - newExpirationDate: the new expiration date for the keypair
/// - newExpirationDate: the new expiration date for the KeyPair
///
/// - Returns: a new keypair with the specified updates
/// - Returns: a new KeyPair with the specified updates
///
func updateCurrentKeyPair(newExpirationDate: Date) -> KeyPair
func updateKeyPairExpirationDate(_ newDate: Date) -> KeyPair
samsymons marked this conversation as resolved.
Show resolved Hide resolved

/// Resets the current `KeyPair` so a new one will be generated when requested.
///
Expand Down Expand Up @@ -103,6 +110,10 @@ public final class NetworkProtectionKeychainKeyStore: NetworkProtectionKeyStore
return KeyPair(privateKey: currentPrivateKey, expirationDate: currentExpirationDate)
}

public func newKeyPair() -> KeyPair {
return newCurrentKeyPair()
samsymons marked this conversation as resolved.
Show resolved Hide resolved
}

private var validityInterval = Defaults.validityInterval

public func setValidityInterval(_ validityInterval: TimeInterval?) {
Expand All @@ -123,8 +134,13 @@ public final class NetworkProtectionKeychainKeyStore: NetworkProtectionKeyStore
return KeyPair(privateKey: currentPrivateKey, expirationDate: currentExpirationDate)
}

public func updateCurrentKeyPair(newExpirationDate: Date) -> KeyPair {
currentExpirationDate = newExpirationDate
public func updateKeyPair(_ newKeyPair: KeyPair) {
self.currentPrivateKey = newKeyPair.privateKey
self.currentExpirationDate = newKeyPair.expirationDate
}

public func updateKeyPairExpirationDate(_ newDate: Date) -> KeyPair {
self.currentExpirationDate = Date().addingTimeInterval(.seconds(30)) // newDate
return currentKeyPair()
}

Expand Down
49 changes: 27 additions & 22 deletions Sources/NetworkProtection/NetworkProtectionDeviceManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ public protocol NetworkProtectionDeviceManagement {
func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod,
includedRoutes: [IPAddressRange],
excludedRoutes: [IPAddressRange],
isKillSwitchEnabled: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo)
isKillSwitchEnabled: Bool,
regenerateKey: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo)

}

Expand Down Expand Up @@ -124,9 +125,26 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement {
public func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod,
includedRoutes: [IPAddressRange],
excludedRoutes: [IPAddressRange],
isKillSwitchEnabled: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo) {
isKillSwitchEnabled: Bool,
regenerateKey: Bool) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo) {
var keyPair: KeyPair

if regenerateKey {
keyPair = keyStore.newKeyPair()
} else {
keyPair = keyStore.currentKeyPair()
}

let (selectedServer, newExpiration) = try await register(keyPair: keyPair, selectionMethod: selectionMethod)

let (selectedServer, keyPair) = try await register(selectionMethod: selectionMethod)
// If we're regenerating the key, then we know at this point it has been successfully registered. It's now safe to replace the old key.
if regenerateKey {
keyStore.updateKeyPair(keyPair)
}

if let newExpiration {
keyPair = keyStore.updateKeyPairExpirationDate(newExpiration)
}

do {
let configuration = try tunnelConfiguration(interfacePrivateKey: keyPair.privateKey,
Expand All @@ -152,13 +170,11 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement {
// - keyPair: the key pair that was used to register with the server, and that should be used to configure the tunnel
//
// - Throws:`NetworkProtectionError`
// This cannot be a doc comment because of the swiftlint command below
// swiftlint:disable cyclomatic_complexity
private func register(selectionMethod: NetworkProtectionServerSelectionMethod) async throws -> (server: NetworkProtectionServer,
keyPair: KeyPair) {
private func register(keyPair: KeyPair,
selectionMethod: NetworkProtectionServerSelectionMethod) async throws -> (server: NetworkProtectionServer,
newExpiration: Date?) {

guard let token = try? tokenStore.fetchToken() else { throw NetworkProtectionError.noAuthTokenFound }
var keyPair = keyStore.currentKeyPair()

let serverSelection: RegisterServerSelection
let excludedServerName: String?
Expand Down Expand Up @@ -203,29 +219,18 @@ public actor NetworkProtectionDeviceManager: NetworkProtectionDeviceManagement {
errorEvents?.fire(NetworkProtectionError.serverListInconsistency)

let cachedServer = try cachedServer(registeredWith: keyPair)
return (cachedServer, keyPair)
return (cachedServer, nil)
}

selectedServer = registeredServer

// We should not need this IF condition here, because we know registered servers will give us an expiration date,
// but since the structure we're currently using makes the expiration date optional we need to have it.
// We should consider changing our server structure to not allow a missing expiration date here.
if let serverExpirationDate = selectedServer.expirationDate,
keyPair.expirationDate > serverExpirationDate {

keyPair = keyStore.updateCurrentKeyPair(newExpirationDate: serverExpirationDate)
}

return (selectedServer, keyPair)
return (selectedServer, selectedServer.expirationDate)
samsymons marked this conversation as resolved.
Show resolved Hide resolved
case .failure(let error):
handle(clientError: error)

let cachedServer = try cachedServer(registeredWith: keyPair)
return (cachedServer, keyPair)
return (cachedServer, nil)
}
}
// swiftlint:enable cyclomatic_complexity

/// Retrieves the first cached server that's registered with the specified key pair.
///
Expand Down
50 changes: 33 additions & 17 deletions Sources/NetworkProtection/PacketTunnelProvider.swift
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,9 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {

os_log("Rekeying...", log: .networkProtectionKeyManagement)

providerEvents.fire(.rekeyCompleted)
samsymons marked this conversation as resolved.
Show resolved Hide resolved
self.resetRegistrationKey()

do {
try await updateTunnelConfiguration(reassert: false)
try await updateTunnelConfiguration(reassert: false, regenerateKey: true)
providerEvents.fire(.rekeyCompleted)
} catch {
os_log("Rekey attempt failed. This is not an error if you're using debug Key Management options: %{public}@", log: .networkProtectionKeyManagement, type: .error, String(describing: error))
}
Expand Down Expand Up @@ -203,13 +201,13 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
Task {
await updateBandwidthAnalyzer()

// This provides a more frequent active user pixel check
providerEvents.fire(.userBecameActive)
samsymons marked this conversation as resolved.
Show resolved Hide resolved

guard self.bandwidthAnalyzer.isConnectionIdle() else {
return
}

// This provides a more frequent active user pixel check
providerEvents.fire(.userBecameActive)

await rekeyIfExpired()
}
}
Expand Down Expand Up @@ -553,7 +551,8 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
let tunnelConfiguration = try await generateTunnelConfiguration(environment: environment,
serverSelectionMethod: currentServerSelectionMethod,
includedRoutes: includedRoutes ?? [],
excludedRoutes: settings.excludedRanges)
excludedRoutes: settings.excludedRanges,
regenerateKey: false)
startTunnel(with: tunnelConfiguration, onDemand: onDemand, completionHandler: completionHandler)
os_log("🔵 Done generating tunnel config", log: .networkProtection, type: .info)
} catch {
Expand Down Expand Up @@ -658,17 +657,26 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
// MARK: - Tunnel Configuration

@MainActor
public func updateTunnelConfiguration(reassert: Bool = true) async throws {
try await updateTunnelConfiguration(environment: settings.selectedEnvironment, serverSelectionMethod: currentServerSelectionMethod, reassert: reassert)
public func updateTunnelConfiguration(reassert: Bool = true, regenerateKey: Bool = false) async throws {
try await updateTunnelConfiguration(
environment: settings.selectedEnvironment,
serverSelectionMethod: currentServerSelectionMethod,
reassert: reassert,
regenerateKey: regenerateKey
)
}

@MainActor
public func updateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default, serverSelectionMethod: NetworkProtectionServerSelectionMethod, reassert: Bool = true) async throws {
public func updateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default,
serverSelectionMethod: NetworkProtectionServerSelectionMethod,
reassert: Bool = true,
regenerateKey: Bool = false) async throws {

let tunnelConfiguration = try await generateTunnelConfiguration(environment: environment,
serverSelectionMethod: serverSelectionMethod,
includedRoutes: includedRoutes ?? [],
excludedRoutes: settings.excludedRanges)
excludedRoutes: settings.excludedRanges,
regenerateKey: regenerateKey)

try await withCheckedThrowingContinuation { [weak self] (continuation: CheckedContinuation<Void, Error>) in
guard let self = self else {
Expand Down Expand Up @@ -699,7 +707,11 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
}

@MainActor
private func generateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default, serverSelectionMethod: NetworkProtectionServerSelectionMethod, includedRoutes: [IPAddressRange], excludedRoutes: [IPAddressRange]) async throws -> TunnelConfiguration {
private func generateTunnelConfiguration(environment: VPNSettings.SelectedEnvironment = .default,
serverSelectionMethod: NetworkProtectionServerSelectionMethod,
includedRoutes: [IPAddressRange],
excludedRoutes: [IPAddressRange],
regenerateKey: Bool) async throws -> TunnelConfiguration {

let configurationResult: (TunnelConfiguration, NetworkProtectionServerInfo)

Expand All @@ -710,7 +722,13 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
keyStore: keyStore,
errorEvents: debugEvents)

configurationResult = try await deviceManager.generateTunnelConfiguration(selectionMethod: serverSelectionMethod, includedRoutes: includedRoutes, excludedRoutes: excludedRoutes, isKillSwitchEnabled: isKillSwitchEnabled)
configurationResult = try await deviceManager.generateTunnelConfiguration(
selectionMethod: serverSelectionMethod,
includedRoutes: includedRoutes,
excludedRoutes: excludedRoutes,
isKillSwitchEnabled: isKillSwitchEnabled,
regenerateKey: regenerateKey
)
} catch {
throw TunnelError.couldNotGenerateTunnelConfiguration(internalError: error)
}
Expand All @@ -724,9 +742,7 @@ open class PacketTunnelProvider: NEPacketTunnelProvider {
selectedServerInfo.name)
os_log("🔵 Excluded routes: %{public}@", log: .networkProtection, type: .info, String(describing: excludedRoutes))

let tunnelConfiguration = configurationResult.0

return tunnelConfiguration
return configurationResult.0
}

// MARK: - App Messages
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,15 @@ final class NetworkProtectionKeyStoreMock: NetworkProtectionKeyStore {
}
}

func updateCurrentKeyPair(newExpirationDate: Date) -> NetworkProtection.KeyPair {
func newKeyPair() -> NetworkProtection.KeyPair {
return KeyPair(privateKey: PrivateKey(), expirationDate: Date().addingTimeInterval(.day))
}

public func updateKeyPair(_ newKeyPair: KeyPair) {
self.keyPair = newKeyPair
}

func updateKeyPairExpirationDate(_ newExpirationDate: Date) -> NetworkProtection.KeyPair {
let keyPair = KeyPair(privateKey: keyPair?.privateKey ?? PrivateKey(), expirationDate: newExpirationDate)
self.keyPair = keyPair
return keyPair
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ extension NetworkProtectionServerInfo {

extension NetworkProtectionServer {

static let mockBaseServer = NetworkProtectionServer(registeredPublicKey: nil, allowedIPs: nil, serverInfo: .mock, expirationDate: nil)
static let mockBaseServer = NetworkProtectionServer(registeredPublicKey: nil, allowedIPs: nil, serverInfo: .mock, expirationDate: Date())
static let mockRegisteredServer = NetworkProtectionServer(registeredPublicKey: "ovn9RpzUuvQ4XLQt6B3RKuEXGIxa5QpTnehjduZlcSE=",
allowedIPs: ["0.0.0.0/0", "::/0"],
serverInfo: .mock,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,13 @@ final class NetworkProtectionDeviceManagerTests: XCTestCase {
extension NetworkProtectionDeviceManager {

func generateTunnelConfiguration(selectionMethod: NetworkProtectionServerSelectionMethod) async throws -> (TunnelConfiguration, NetworkProtectionServerInfo) {
try await generateTunnelConfiguration(selectionMethod: selectionMethod, includedRoutes: [], excludedRoutes: [], isKillSwitchEnabled: false)
try await generateTunnelConfiguration(
selectionMethod: selectionMethod,
includedRoutes: [],
excludedRoutes: [],
isKillSwitchEnabled: false,
regenerateKey: false
)
}

}
Loading