Skip to content

Commit

Permalink
feat: Support key versionless
Browse files Browse the repository at this point in the history
Retrieve latest key version from akv and
put key version into annotation for decryption.

Signed-off-by: Zhecheng Li <[email protected]>
  • Loading branch information
lzhecheng committed Oct 16, 2024
1 parent 2b68d2f commit cf924af
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 73 deletions.
35 changes: 20 additions & 15 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@ import (
)

var (
listenAddr = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
keyName = flag.String("key-name", "", "Azure Key Vault KMS key name")
keyVersion = flag.String("key-version", "", "Azure Key Vault KMS key version")
managedHSM = flag.Bool("managed-hsm", false, "Azure Key Vault Managed HSM. Refer to https://docs.microsoft.com/en-us/azure/key-vault/managed-hsm/overview for more details.")
logFormatJSON = flag.Bool("log-format-json", false, "set log formatter to json")
logLevel = flag.Int("v", 0, "In order of increasing verbosity: 0=warning/error, 2=info, 4=debug, 6=trace, 10=all")
listenAddr = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
keyName = flag.String("key-name", "", "Azure Key Vault KMS key name")
keyVersion = flag.String("key-version", "", "Azure Key Vault KMS key version")
// keyVersion and keyVersionlessEnabled are not mutually exclusive for compatibility.
// When keyVersionlessEnabled is enabled on an existing-KMS cluster, only new secrets will be encrypted with versionless key.
keyVersionlessEnabled = flag.Bool("key-versionless-enabled", false, "Azure Key Vault KMS key versionless enabled")
managedHSM = flag.Bool("managed-hsm", false, "Azure Key Vault Managed HSM. Refer to https://docs.microsoft.com/en-us/azure/key-vault/managed-hsm/overview for more details.")
logFormatJSON = flag.Bool("log-format-json", false, "set log formatter to json")
logLevel = flag.Int("v", 0, "In order of increasing verbosity: 0=warning/error, 2=info, 4=debug, 6=trace, 10=all")
// TODO remove this flag in future release.
_ = flag.String("configFilePath", "/etc/kubernetes/azure.json", "[DEPRECATED] Path for Azure Cloud Provider config file")
configFilePath = flag.String("config-file-path", "/etc/kubernetes/azure.json", "Path for Azure Cloud Provider config file")
Expand Down Expand Up @@ -90,14 +93,15 @@ func setupKMSPlugin() error {
mlog.Always("Starting KeyManagementServiceServer service", "version", version.BuildVersion, "buildDate", version.BuildDate)

pluginConfig := &plugin.Config{
KeyVaultName: *keyvaultName,
KeyName: *keyName,
KeyVersion: *keyVersion,
ManagedHSM: *managedHSM,
ProxyMode: *proxyMode,
ProxyAddress: *proxyAddress,
ProxyPort: *proxyPort,
ConfigFilePath: *configFilePath,
KeyVaultName: *keyvaultName,
KeyName: *keyName,
KeyVersion: *keyVersion,
KeyVersionlessEnabled: *keyVersionlessEnabled,
ManagedHSM: *managedHSM,
ProxyMode: *proxyMode,
ProxyAddress: *proxyAddress,
ProxyPort: *proxyPort,
ConfigFilePath: *configFilePath,
}

azureConfig, err := config.GetAzureConfig(pluginConfig.ConfigFilePath)
Expand All @@ -110,6 +114,7 @@ func setupKMSPlugin() error {
pluginConfig.KeyVaultName,
pluginConfig.KeyName,
pluginConfig.KeyVersion,
pluginConfig.KeyVersionlessEnabled,
pluginConfig.ProxyMode,
pluginConfig.ProxyAddress,
pluginConfig.ProxyPort,
Expand Down
108 changes: 80 additions & 28 deletions pkg/plugin/keyvault.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ const (
keyvaultRegionAnnotationKey = "x-ms-keyvault-region.azure.akv.io"
versionAnnotationKey = "version.azure.akv.io"
algorithmAnnotationKey = "algorithm.azure.akv.io"
keyVersionAnnotationKey = "keyversion.azure.akv.io"
keyIDHashAnnotationKey = "keyidhash.azure.akv.io"
dateAnnotationValue = "Date"
requestIDAnnotationValue = "X-Ms-Request-Id"
keyvaultRegionAnnotationValue = "X-Ms-Keyvault-Region"
Expand All @@ -64,20 +66,22 @@ type Client interface {

// KeyVaultClient is a client for interacting with Keyvault.
type KeyVaultClient struct {
baseClient kv.BaseClient
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
vaultURL string
keyIDHash string
azureEnvironment *azure.Environment
baseClient kv.BaseClient
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
keyVersionlessEnabled bool
vaultURL string
keyIDHash string
azureEnvironment *azure.Environment
}

// NewKeyVaultClient returns a new key vault client to use for kms operations.
func NewKeyVaultClient(
config *config.AzureConfig,
vaultName, keyName, keyVersion string,
keyVersionlessEnabled bool,
proxyMode bool,
proxyAddress string,
proxyPort int,
Expand All @@ -90,9 +94,10 @@ func NewKeyVaultClient(

// this should be the case for bring your own key, clusters bootstrapped with
// aks-engine or aks and standalone kms plugin deployments
if len(vaultName) == 0 || len(keyName) == 0 || len(keyVersion) == 0 {
return nil, fmt.Errorf("key vault name, key name and key version are required")
if len(vaultName) == 0 || len(keyName) == 0 || (!keyVersionlessEnabled && len(keyVersion) == 0) {
return nil, fmt.Errorf("key vault name, key name and key version (not key versionless enabled) are required")
}

kvClient := kv.New()
err := kvClient.AddToUserAgent(version.GetUserAgent())
if err != nil {
Expand Down Expand Up @@ -121,9 +126,12 @@ func NewKeyVaultClient(
return nil, fmt.Errorf("failed to get vault url, error: %+v", err)
}

keyIDHash, err := getKeyIDHash(*vaultURL, keyName, keyVersion)
if err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
var keyIDHash string
if len(keyVersion) != 0 {
keyIDHash, err = getKeyIDHash(*vaultURL, keyName, keyVersion)
if err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
}
}

if proxyMode {
Expand All @@ -134,18 +142,35 @@ func NewKeyVaultClient(
mlog.Always("using kms key for encrypt/decrypt", "vaultURL", *vaultURL, "keyName", keyName, "keyVersion", keyVersion)

client := &KeyVaultClient{
baseClient: kvClient,
config: config,
vaultName: vaultName,
keyName: keyName,
keyVersion: keyVersion,
vaultURL: *vaultURL,
azureEnvironment: env,
keyIDHash: keyIDHash,
baseClient: kvClient,
config: config,
vaultName: vaultName,
keyName: keyName,
keyVersion: keyVersion,
keyVersionlessEnabled: keyVersionlessEnabled,
vaultURL: *vaultURL,
azureEnvironment: env,
keyIDHash: keyIDHash,
}
return client, nil
}

func (kvc *KeyVaultClient) GetLatestKeyVersion(ctx context.Context) (string, error) {
keyBundle, err := kvc.baseClient.GetKey(ctx, kvc.vaultURL, kvc.keyName, "")
if err != nil {
return "", fmt.Errorf("failed to get key, error: %+v", err)
}
if keyBundle.Key == nil || keyBundle.Key.Kid == nil {
return "", fmt.Errorf("failed to get latest key version, key bundle is empty for keyvault %q, key %q", kvc.vaultName, kvc.keyName)
}
kidSplitted := strings.Split(*keyBundle.Key.Kid, "/")
if len(kidSplitted) == 0 {
return "", fmt.Errorf("failed to get latest key version, key id is invalid %q", *keyBundle.Key.Kid)
}
latestKeyVersion := kidSplitted[len(kidSplitted)-1]
return latestKeyVersion, nil
}

// Encrypt encrypts the given plain text using the keyvault key.
func (kvc *KeyVaultClient) Encrypt(
ctx context.Context,
Expand All @@ -158,15 +183,29 @@ func (kvc *KeyVaultClient) Encrypt(
Algorithm: encryptionAlgorithm,
Value: &value,
}
result, err := kvc.baseClient.Encrypt(ctx, kvc.vaultURL, kvc.keyName, kvc.keyVersion, params)

keyVersion := kvc.keyVersion
keyIDHash := kvc.keyIDHash
if kvc.keyVersionlessEnabled {
var err error
if keyVersion, err = kvc.GetLatestKeyVersion(ctx); err != nil {
return nil, fmt.Errorf("failed to get latest key version, error: %+v", err)
}

if keyIDHash, err = getKeyIDHash(kvc.vaultURL, kvc.keyName, keyVersion); err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
}
}

result, err := kvc.baseClient.Encrypt(ctx, kvc.vaultURL, kvc.keyName, keyVersion, params)
if err != nil {
return nil, fmt.Errorf("failed to encrypt, error: %+v", err)
}

if kvc.keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
if keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
return nil, fmt.Errorf(
"key id initialized does not match with the key id from encryption result, expected: %s, got: %s",
kvc.keyIDHash,
keyIDHash,
*result.Kid,
)
}
Expand All @@ -177,11 +216,14 @@ func (kvc *KeyVaultClient) Encrypt(
keyvaultRegionAnnotationKey: []byte(result.Header.Get(keyvaultRegionAnnotationValue)),
versionAnnotationKey: []byte(encryptionResponseVersion),
algorithmAnnotationKey: []byte(encryptionAlgorithm),
keyVersionAnnotationKey: []byte(keyVersion),
keyIDHashAnnotationKey: []byte(keyIDHash),
}

mlog.Info("Encryption succeeded", "vaultName", kvc.vaultName, "keyName", kvc.keyName, "keyVersion", keyVersion)
return &service.EncryptResponse{
Ciphertext: []byte(*result.Result),
KeyID: kvc.keyIDHash,
KeyID: keyIDHash,
Annotations: annotations,
}, nil
}
Expand All @@ -208,7 +250,12 @@ func (kvc *KeyVaultClient) Decrypt(
Value: &value,
}

result, err := kvc.baseClient.Decrypt(ctx, kvc.vaultURL, kvc.keyName, kvc.keyVersion, params)
keyVersion := kvc.keyVersion
if len(annotations[keyVersionAnnotationKey]) != 0 {
keyVersion = string(annotations[keyVersionAnnotationKey])
}

result, err := kvc.baseClient.Decrypt(ctx, kvc.vaultURL, kvc.keyName, keyVersion, params)
if err != nil {
return nil, fmt.Errorf("failed to decrypt, error: %+v", err)
}
Expand All @@ -217,6 +264,7 @@ func (kvc *KeyVaultClient) Decrypt(
return nil, fmt.Errorf("failed to base64 decode result, error: %+v", err)
}

mlog.Info("Decryption succeeded", "vaultName", kvc.vaultName, "keyName", kvc.keyName, "keyVersion", keyVersion)
return bytes, nil
}

Expand All @@ -241,11 +289,15 @@ func (kvc *KeyVaultClient) validateAnnotations(
return fmt.Errorf("invalid annotations, annotations cannot be empty")
}

if keyID != kvc.keyIDHash {
expectedKeyIDHash := kvc.keyIDHash
if len(annotations[keyIDHashAnnotationKey]) != 0 {
expectedKeyIDHash = string(annotations[keyIDHashAnnotationKey])
}
if keyID != expectedKeyIDHash {
return fmt.Errorf(
"key id %s does not match expected key id %s used for encryption",
keyID,
kvc.keyIDHash,
expectedKeyIDHash,
)
}

Expand Down
56 changes: 34 additions & 22 deletions pkg/plugin/keyvault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ var (

func TestNewKeyVaultClientError(t *testing.T) {
tests := []struct {
desc string
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
proxyMode bool
proxyAddress string
proxyPort int
managedHSM bool
desc string
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
keyVersionlessEnabled bool
proxyMode bool
proxyAddress string
proxyPort int
managedHSM bool
}{
{
desc: "vault name not provided",
Expand All @@ -43,7 +44,7 @@ func TestNewKeyVaultClientError(t *testing.T) {
proxyMode: false,
},
{
desc: "key version not provided",
desc: "key version not provided when not keyVersionlessEnabled",
config: &config.AzureConfig{},
vaultName: "testkv",
keyName: "k8s",
Expand All @@ -68,7 +69,7 @@ func TestNewKeyVaultClientError(t *testing.T) {

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
if _, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM); err == nil {
if _, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.keyVersionlessEnabled, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM); err == nil {
t.Fatalf("newKeyVaultClient() expected error, got nil")
}
})
Expand All @@ -77,16 +78,17 @@ func TestNewKeyVaultClientError(t *testing.T) {

func TestNewKeyVaultClient(t *testing.T) {
tests := []struct {
desc string
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
proxyMode bool
proxyAddress string
proxyPort int
managedHSM bool
expectedVaultURL string
desc string
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
keyVersionlessEnabled bool
proxyMode bool
proxyAddress string
proxyPort int
managedHSM bool
expectedVaultURL string
}{
{
desc: "no error",
Expand Down Expand Up @@ -127,11 +129,21 @@ func TestNewKeyVaultClient(t *testing.T) {
proxyMode: false,
expectedVaultURL: "https://testkv.managedhsm.azure.net/",
},
{
desc: "no error when no key version with keyVersionlessEnabled",
config: &config.AzureConfig{ClientID: "clientid", ClientSecret: "clientsecret"},
vaultName: "testkv",
keyName: "key1",
keyVersion: "",
keyVersionlessEnabled: true,
proxyMode: false,
expectedVaultURL: "https://testkv.vault.azure.net/",
},
}

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
kvClient, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM)
kvClient, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.keyVersionlessEnabled, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM)
if err != nil {
t.Fatalf("newKeyVaultClient() failed with error: %v", err)
}
Expand Down
17 changes: 9 additions & 8 deletions pkg/plugin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ type KeyManagementServiceServer struct {

// Config is the configuration for the KMS plugin.
type Config struct {
ConfigFilePath string
KeyVaultName string
KeyName string
KeyVersion string
ManagedHSM bool
ProxyMode bool
ProxyAddress string
ProxyPort int
ConfigFilePath string
KeyVaultName string
KeyName string
KeyVersion string
KeyVersionlessEnabled bool
ManagedHSM bool
ProxyMode bool
ProxyAddress string
ProxyPort int
}

// NewKMSv1Server creates an instance of the KMS Service Server.
Expand Down

0 comments on commit cf924af

Please sign in to comment.