Skip to content

Commit

Permalink
feat: Support key versionless
Browse files Browse the repository at this point in the history
Retrieve latest key version from akv and
put key version into annotation for decryption.

Signed-off-by: Zhecheng Li <[email protected]>
  • Loading branch information
lzhecheng committed Oct 16, 2024
1 parent 2b68d2f commit 2953169
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 73 deletions.
35 changes: 20 additions & 15 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@ import (
)

var (
listenAddr = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
keyName = flag.String("key-name", "", "Azure Key Vault KMS key name")
keyVersion = flag.String("key-version", "", "Azure Key Vault KMS key version")
managedHSM = flag.Bool("managed-hsm", false, "Azure Key Vault Managed HSM. Refer to https://docs.microsoft.com/en-us/azure/key-vault/managed-hsm/overview for more details.")
logFormatJSON = flag.Bool("log-format-json", false, "set log formatter to json")
logLevel = flag.Int("v", 0, "In order of increasing verbosity: 0=warning/error, 2=info, 4=debug, 6=trace, 10=all")
listenAddr = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
keyName = flag.String("key-name", "", "Azure Key Vault KMS key name")
keyVersion = flag.String("key-version", "", "Azure Key Vault KMS key version")
// keyVersion and keyVersionlessEnabled are not mutually exclusive for compatibility.
// When keyVersionlessEnabled is enabled on an existing-KMS cluster, only new secrets will be encrypted with versionless key.
keyVersionlessEnabled = flag.Bool("key-versionless-enabled", false, "Azure Key Vault KMS key versionless enabled")
managedHSM = flag.Bool("managed-hsm", false, "Azure Key Vault Managed HSM. Refer to https://docs.microsoft.com/en-us/azure/key-vault/managed-hsm/overview for more details.")
logFormatJSON = flag.Bool("log-format-json", false, "set log formatter to json")
logLevel = flag.Int("v", 0, "In order of increasing verbosity: 0=warning/error, 2=info, 4=debug, 6=trace, 10=all")
// TODO remove this flag in future release.
_ = flag.String("configFilePath", "/etc/kubernetes/azure.json", "[DEPRECATED] Path for Azure Cloud Provider config file")
configFilePath = flag.String("config-file-path", "/etc/kubernetes/azure.json", "Path for Azure Cloud Provider config file")
Expand Down Expand Up @@ -90,14 +93,15 @@ func setupKMSPlugin() error {
mlog.Always("Starting KeyManagementServiceServer service", "version", version.BuildVersion, "buildDate", version.BuildDate)

pluginConfig := &plugin.Config{
KeyVaultName: *keyvaultName,
KeyName: *keyName,
KeyVersion: *keyVersion,
ManagedHSM: *managedHSM,
ProxyMode: *proxyMode,
ProxyAddress: *proxyAddress,
ProxyPort: *proxyPort,
ConfigFilePath: *configFilePath,
KeyVaultName: *keyvaultName,
KeyName: *keyName,
KeyVersion: *keyVersion,
KeyVersionlessEnabled: *keyVersionlessEnabled,
ManagedHSM: *managedHSM,
ProxyMode: *proxyMode,
ProxyAddress: *proxyAddress,
ProxyPort: *proxyPort,
ConfigFilePath: *configFilePath,
}

azureConfig, err := config.GetAzureConfig(pluginConfig.ConfigFilePath)
Expand All @@ -110,6 +114,7 @@ func setupKMSPlugin() error {
pluginConfig.KeyVaultName,
pluginConfig.KeyName,
pluginConfig.KeyVersion,
pluginConfig.KeyVersionlessEnabled,
pluginConfig.ProxyMode,
pluginConfig.ProxyAddress,
pluginConfig.ProxyPort,
Expand Down
125 changes: 97 additions & 28 deletions pkg/plugin/keyvault.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"path"
"regexp"
"strings"
"time"

"github.com/Azure/kubernetes-kms/pkg/auth"
"github.com/Azure/kubernetes-kms/pkg/config"
Expand All @@ -38,6 +39,8 @@ const (
keyvaultRegionAnnotationKey = "x-ms-keyvault-region.azure.akv.io"
versionAnnotationKey = "version.azure.akv.io"
algorithmAnnotationKey = "algorithm.azure.akv.io"
keyVersionAnnotationKey = "keyversion.azure.akv.io"
keyIDHashAnnotationKey = "keyidhash.azure.akv.io"
dateAnnotationValue = "Date"
requestIDAnnotationValue = "X-Ms-Request-Id"
keyvaultRegionAnnotationValue = "X-Ms-Keyvault-Region"
Expand All @@ -64,20 +67,22 @@ type Client interface {

// KeyVaultClient is a client for interacting with Keyvault.
type KeyVaultClient struct {
baseClient kv.BaseClient
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
vaultURL string
keyIDHash string
azureEnvironment *azure.Environment
baseClient kv.BaseClient
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
keyVersionlessEnabled bool
vaultURL string
keyIDHash string
azureEnvironment *azure.Environment
}

// NewKeyVaultClient returns a new key vault client to use for kms operations.
func NewKeyVaultClient(
config *config.AzureConfig,
vaultName, keyName, keyVersion string,
keyVersionlessEnabled bool,
proxyMode bool,
proxyAddress string,
proxyPort int,
Expand All @@ -90,9 +95,10 @@ func NewKeyVaultClient(

// this should be the case for bring your own key, clusters bootstrapped with
// aks-engine or aks and standalone kms plugin deployments
if len(vaultName) == 0 || len(keyName) == 0 || len(keyVersion) == 0 {
return nil, fmt.Errorf("key vault name, key name and key version are required")
if len(vaultName) == 0 || len(keyName) == 0 || (!keyVersionlessEnabled && len(keyVersion) == 0) {
return nil, fmt.Errorf("key vault name, key name and key version (not key versionless enabled) are required")
}

kvClient := kv.New()
err := kvClient.AddToUserAgent(version.GetUserAgent())
if err != nil {
Expand Down Expand Up @@ -121,9 +127,12 @@ func NewKeyVaultClient(
return nil, fmt.Errorf("failed to get vault url, error: %+v", err)
}

keyIDHash, err := getKeyIDHash(*vaultURL, keyName, keyVersion)
if err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
var keyIDHash string
if len(keyVersion) != 0 {
keyIDHash, err = getKeyIDHash(*vaultURL, keyName, keyVersion)
if err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
}
}

if proxyMode {
Expand All @@ -134,18 +143,51 @@ func NewKeyVaultClient(
mlog.Always("using kms key for encrypt/decrypt", "vaultURL", *vaultURL, "keyName", keyName, "keyVersion", keyVersion)

client := &KeyVaultClient{
baseClient: kvClient,
config: config,
vaultName: vaultName,
keyName: keyName,
keyVersion: keyVersion,
vaultURL: *vaultURL,
azureEnvironment: env,
keyIDHash: keyIDHash,
baseClient: kvClient,
config: config,
vaultName: vaultName,
keyName: keyName,
keyVersion: keyVersion,
keyVersionlessEnabled: keyVersionlessEnabled,
vaultURL: *vaultURL,
azureEnvironment: env,
keyIDHash: keyIDHash,
}
return client, nil
}

func (kvc *KeyVaultClient) GetLatestKeyVersion(ctx context.Context) (string, error) {
keyVersionResultPage, err := kvc.baseClient.GetKeyVersions(ctx, kvc.vaultURL, kvc.keyName, nil)
if err != nil {
return "", fmt.Errorf("failed to get key versions, error: %+v", err)
}
var latestKeyVersionItem kv.KeyItem
for keyVersionResultPage.NotDone() {
for _, value := range keyVersionResultPage.Values() {
if latestKeyVersionItem.Kid == nil {
latestKeyVersionItem = value
} else {
updatedTimeCurrent := time.Time(*value.Attributes.Updated)
updatedTimeLatest := time.Time(*latestKeyVersionItem.Attributes.Updated)
if updatedTimeCurrent.After(updatedTimeLatest) {
latestKeyVersionItem = value
}
}
}
keyVersionResultPage.Next()
}

if latestKeyVersionItem.Kid == nil {
return "", fmt.Errorf("failed to get latest key version, key id is nil")
}
kidSplitted := strings.Split(*latestKeyVersionItem.Kid, "/")
if len(kidSplitted) == 0 {
return "", fmt.Errorf("failed to get latest key version, key id is invalid %q", *latestKeyVersionItem.Kid)
}
latestKeyVersion := kidSplitted[len(kidSplitted)-1]
return latestKeyVersion, nil
}

// Encrypt encrypts the given plain text using the keyvault key.
func (kvc *KeyVaultClient) Encrypt(
ctx context.Context,
Expand All @@ -158,15 +200,29 @@ func (kvc *KeyVaultClient) Encrypt(
Algorithm: encryptionAlgorithm,
Value: &value,
}
result, err := kvc.baseClient.Encrypt(ctx, kvc.vaultURL, kvc.keyName, kvc.keyVersion, params)

keyVersion := kvc.keyVersion
keyIDHash := kvc.keyIDHash
if kvc.keyVersionlessEnabled {
var err error
if keyVersion, err = kvc.GetLatestKeyVersion(ctx); err != nil {
return nil, fmt.Errorf("failed to get latest key version, error: %+v", err)
}

if keyIDHash, err = getKeyIDHash(kvc.vaultURL, kvc.keyName, keyVersion); err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
}
}

result, err := kvc.baseClient.Encrypt(ctx, kvc.vaultURL, kvc.keyName, keyVersion, params)
if err != nil {
return nil, fmt.Errorf("failed to encrypt, error: %+v", err)
}

if kvc.keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
if keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
return nil, fmt.Errorf(
"key id initialized does not match with the key id from encryption result, expected: %s, got: %s",
kvc.keyIDHash,
keyIDHash,
*result.Kid,
)
}
Expand All @@ -177,11 +233,14 @@ func (kvc *KeyVaultClient) Encrypt(
keyvaultRegionAnnotationKey: []byte(result.Header.Get(keyvaultRegionAnnotationValue)),
versionAnnotationKey: []byte(encryptionResponseVersion),
algorithmAnnotationKey: []byte(encryptionAlgorithm),
keyVersionAnnotationKey: []byte(keyVersion),
keyIDHashAnnotationKey: []byte(keyIDHash),
}

mlog.Info("Encryption succeeded", "vaultName", kvc.vaultName, "keyName", kvc.keyName, "keyVersion", keyVersion)
return &service.EncryptResponse{
Ciphertext: []byte(*result.Result),
KeyID: kvc.keyIDHash,
KeyID: keyIDHash,
Annotations: annotations,
}, nil
}
Expand All @@ -208,7 +267,12 @@ func (kvc *KeyVaultClient) Decrypt(
Value: &value,
}

result, err := kvc.baseClient.Decrypt(ctx, kvc.vaultURL, kvc.keyName, kvc.keyVersion, params)
keyVersion := kvc.keyVersion
if len(annotations[keyVersionAnnotationKey]) != 0 {
keyVersion = string(annotations[keyVersionAnnotationKey])
}

result, err := kvc.baseClient.Decrypt(ctx, kvc.vaultURL, kvc.keyName, keyVersion, params)
if err != nil {
return nil, fmt.Errorf("failed to decrypt, error: %+v", err)
}
Expand All @@ -217,6 +281,7 @@ func (kvc *KeyVaultClient) Decrypt(
return nil, fmt.Errorf("failed to base64 decode result, error: %+v", err)
}

mlog.Info("Decryption succeeded", "vaultName", kvc.vaultName, "keyName", kvc.keyName, "keyVersion", keyVersion)
return bytes, nil
}

Expand All @@ -241,11 +306,15 @@ func (kvc *KeyVaultClient) validateAnnotations(
return fmt.Errorf("invalid annotations, annotations cannot be empty")
}

if keyID != kvc.keyIDHash {
expectedKeyIDHash := kvc.keyIDHash
if len(annotations[keyIDHashAnnotationKey]) != 0 {
expectedKeyIDHash = string(annotations[keyIDHashAnnotationKey])
}
if keyID != expectedKeyIDHash {
return fmt.Errorf(
"key id %s does not match expected key id %s used for encryption",
keyID,
kvc.keyIDHash,
expectedKeyIDHash,
)
}

Expand Down
56 changes: 34 additions & 22 deletions pkg/plugin/keyvault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ var (

func TestNewKeyVaultClientError(t *testing.T) {
tests := []struct {
desc string
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
proxyMode bool
proxyAddress string
proxyPort int
managedHSM bool
desc string
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
keyVersionlessEnabled bool
proxyMode bool
proxyAddress string
proxyPort int
managedHSM bool
}{
{
desc: "vault name not provided",
Expand All @@ -43,7 +44,7 @@ func TestNewKeyVaultClientError(t *testing.T) {
proxyMode: false,
},
{
desc: "key version not provided",
desc: "key version not provided when not keyVersionlessEnabled",
config: &config.AzureConfig{},
vaultName: "testkv",
keyName: "k8s",
Expand All @@ -68,7 +69,7 @@ func TestNewKeyVaultClientError(t *testing.T) {

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
if _, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM); err == nil {
if _, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.keyVersionlessEnabled, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM); err == nil {
t.Fatalf("newKeyVaultClient() expected error, got nil")
}
})
Expand All @@ -77,16 +78,17 @@ func TestNewKeyVaultClientError(t *testing.T) {

func TestNewKeyVaultClient(t *testing.T) {
tests := []struct {
desc string
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
proxyMode bool
proxyAddress string
proxyPort int
managedHSM bool
expectedVaultURL string
desc string
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
keyVersionlessEnabled bool
proxyMode bool
proxyAddress string
proxyPort int
managedHSM bool
expectedVaultURL string
}{
{
desc: "no error",
Expand Down Expand Up @@ -127,11 +129,21 @@ func TestNewKeyVaultClient(t *testing.T) {
proxyMode: false,
expectedVaultURL: "https://testkv.managedhsm.azure.net/",
},
{
desc: "no error when no key version with keyVersionlessEnabled",
config: &config.AzureConfig{ClientID: "clientid", ClientSecret: "clientsecret"},
vaultName: "testkv",
keyName: "key1",
keyVersion: "",
keyVersionlessEnabled: true,
proxyMode: false,
expectedVaultURL: "https://testkv.vault.azure.net/",
},
}

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
kvClient, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM)
kvClient, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.keyVersionlessEnabled, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM)
if err != nil {
t.Fatalf("newKeyVaultClient() failed with error: %v", err)
}
Expand Down
Loading

0 comments on commit 2953169

Please sign in to comment.