From e0561eba7c3503b50950ec4d071f380a6b3962c7 Mon Sep 17 00:00:00 2001 From: eaudetcobello Date: Wed, 11 Dec 2024 16:58:28 -0500 Subject: [PATCH] Unit test ValidateCAPIAuthTokenAccessHandler (#885) --- .../pkg/k8sd/api/capi_access_handler_test.go | 90 +++++++++++++++++++ src/k8s/pkg/k8sd/database/capi_auth_test.go | 10 ++- .../pkg/k8sd/database/cluster_config_test.go | 16 ++-- .../pkg/k8sd/database/feature_status_test.go | 6 +- .../database/kubernetes_auth_tokens_test.go | 14 +-- src/k8s/pkg/k8sd/database/worker_test.go | 6 +- .../microcluster/state.go} | 29 +++--- 7 files changed, 133 insertions(+), 38 deletions(-) create mode 100644 src/k8s/pkg/k8sd/api/capi_access_handler_test.go rename src/k8s/pkg/{k8sd/database/util_test.go => utils/microcluster/state.go} (77%) diff --git a/src/k8s/pkg/k8sd/api/capi_access_handler_test.go b/src/k8s/pkg/k8sd/api/capi_access_handler_test.go new file mode 100644 index 000000000..00b12bc47 --- /dev/null +++ b/src/k8s/pkg/k8sd/api/capi_access_handler_test.go @@ -0,0 +1,90 @@ +package api_test + +import ( + "context" + "database/sql" + "net/http" + "testing" + + "github.com/canonical/k8s/pkg/k8sd/api" + "github.com/canonical/k8s/pkg/k8sd/database" + testenv "github.com/canonical/k8s/pkg/utils/microcluster" + "github.com/canonical/microcluster/v2/state" + . "github.com/onsi/gomega" +) + +func TestValidateCAPIAuthTokenAccessHandler(t *testing.T) { + g := NewWithT(t) + + for _, tc := range []struct { + name string + tokenHeaderContent string + tokenDBContent string + expectErr bool + }{ + { + name: "valid token", + tokenHeaderContent: "test-token", + tokenDBContent: "test-token", + expectErr: false, + }, + { + name: "wrong token in header", + tokenHeaderContent: "invalid-token", + tokenDBContent: "expected-token", + expectErr: true, + }, + { + name: "wrong token in db", + tokenHeaderContent: "expected-token", + tokenDBContent: "invalid-token", + expectErr: true, + }, + { + name: "empty token in header", + tokenHeaderContent: "", + tokenDBContent: "test-token", + expectErr: true, + }, + { + name: "empty token in db", + tokenHeaderContent: "test-token", + tokenDBContent: "", + expectErr: true, + }, + { + name: "empty token in header and db", + tokenHeaderContent: "", + tokenDBContent: "", + expectErr: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + testenv.WithState(t, func(ctx context.Context, s state.State) { + var err error + if tc.tokenDBContent != "" { + err = s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + return database.SetClusterAPIToken(ctx, tx, tc.tokenDBContent) + }) + g.Expect(err).To(Not(HaveOccurred())) + } + + req := &http.Request{ + Header: make(http.Header), + } + req.Header.Set("Capi-Auth-Token", tc.tokenHeaderContent) + + handler := api.ValidateCAPIAuthTokenAccessHandler("Capi-Auth-Token") + valid, resp := handler(s, req) + + if tc.expectErr { + g.Expect(valid).To(BeFalse()) + g.Expect(resp).To(Not(BeNil())) + } else { + g.Expect(valid).To(BeTrue()) + g.Expect(resp).To(BeNil()) + } + }) + }) + } +} diff --git a/src/k8s/pkg/k8sd/database/capi_auth_test.go b/src/k8s/pkg/k8sd/database/capi_auth_test.go index 2ffbcdf46..4a1d8c78b 100644 --- a/src/k8s/pkg/k8sd/database/capi_auth_test.go +++ b/src/k8s/pkg/k8sd/database/capi_auth_test.go @@ -6,16 +6,18 @@ import ( "testing" "github.com/canonical/k8s/pkg/k8sd/database" + testenv "github.com/canonical/k8s/pkg/utils/microcluster" + "github.com/canonical/microcluster/v2/state" . "github.com/onsi/gomega" ) func TestClusterAPIAuthTokens(t *testing.T) { - WithDB(t, func(ctx context.Context, db DB) { + testenv.WithState(t, func(ctx context.Context, s state.State) { var token string = "test-token" t.Run("SetAuthToken", func(t *testing.T) { g := NewWithT(t) - err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { err := database.SetClusterAPIToken(ctx, tx, token) g.Expect(err).To(Not(HaveOccurred())) return nil @@ -26,7 +28,7 @@ func TestClusterAPIAuthTokens(t *testing.T) { t.Run("CheckAuthToken", func(t *testing.T) { t.Run("ValidToken", func(t *testing.T) { g := NewWithT(t) - err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { valid, err := database.ValidateClusterAPIToken(ctx, tx, token) g.Expect(err).To(Not(HaveOccurred())) g.Expect(valid).To(BeTrue()) @@ -37,7 +39,7 @@ func TestClusterAPIAuthTokens(t *testing.T) { t.Run("InvalidToken", func(t *testing.T) { g := NewWithT(t) - err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { valid, err := database.ValidateClusterAPIToken(ctx, tx, "invalid-token") g.Expect(err).To(Not(HaveOccurred())) g.Expect(valid).To(BeFalse()) diff --git a/src/k8s/pkg/k8sd/database/cluster_config_test.go b/src/k8s/pkg/k8sd/database/cluster_config_test.go index e4b21d220..2b36ede6c 100644 --- a/src/k8s/pkg/k8sd/database/cluster_config_test.go +++ b/src/k8s/pkg/k8sd/database/cluster_config_test.go @@ -8,11 +8,13 @@ import ( "github.com/canonical/k8s/pkg/k8sd/database" "github.com/canonical/k8s/pkg/k8sd/types" "github.com/canonical/k8s/pkg/utils" + testenv "github.com/canonical/k8s/pkg/utils/microcluster" + "github.com/canonical/microcluster/v2/state" . "github.com/onsi/gomega" ) func TestClusterConfig(t *testing.T) { - WithDB(t, func(ctx context.Context, d DB) { + testenv.WithState(t, func(ctx context.Context, s state.State) { t.Run("Set", func(t *testing.T) { g := NewWithT(t) expectedClusterConfig := types.ClusterConfig{ @@ -24,7 +26,7 @@ func TestClusterConfig(t *testing.T) { expectedClusterConfig.SetDefaults() // Write some config to the database - err := d.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { _, err := database.SetClusterConfig(context.Background(), tx, expectedClusterConfig) g.Expect(err).To(Not(HaveOccurred())) return nil @@ -32,7 +34,7 @@ func TestClusterConfig(t *testing.T) { g.Expect(err).To(Not(HaveOccurred())) // Retrieve it and map it to the struct - err = d.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + err = s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { clusterConfig, err := database.GetClusterConfig(ctx, tx) g.Expect(err).To(Not(HaveOccurred())) g.Expect(clusterConfig).To(Equal(expectedClusterConfig)) @@ -52,7 +54,7 @@ func TestClusterConfig(t *testing.T) { } expectedClusterConfig.SetDefaults() - err := d.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { _, err := database.SetClusterConfig(context.Background(), tx, types.ClusterConfig{ Certificates: types.Certificates{ CACert: utils.Pointer("CA CERT NEW DATA"), @@ -63,7 +65,7 @@ func TestClusterConfig(t *testing.T) { }) g.Expect(err).To(HaveOccurred()) - err = d.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + err = s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { clusterConfig, err := database.GetClusterConfig(ctx, tx) g.Expect(err).To(Not(HaveOccurred())) g.Expect(clusterConfig).To(Equal(expectedClusterConfig)) @@ -90,7 +92,7 @@ func TestClusterConfig(t *testing.T) { } expectedClusterConfig.SetDefaults() - err := d.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { returnedConfig, err := database.SetClusterConfig(context.Background(), tx, types.ClusterConfig{ Kubelet: types.Kubelet{ ClusterDNS: utils.Pointer("10.152.183.10"), @@ -109,7 +111,7 @@ func TestClusterConfig(t *testing.T) { }) g.Expect(err).To(Not(HaveOccurred())) - err = d.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + err = s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { clusterConfig, err := database.GetClusterConfig(ctx, tx) g.Expect(err).To(Not(HaveOccurred())) g.Expect(clusterConfig).To(Equal(expectedClusterConfig)) diff --git a/src/k8s/pkg/k8sd/database/feature_status_test.go b/src/k8s/pkg/k8sd/database/feature_status_test.go index 200e57f3b..af59431f7 100644 --- a/src/k8s/pkg/k8sd/database/feature_status_test.go +++ b/src/k8s/pkg/k8sd/database/feature_status_test.go @@ -9,12 +9,14 @@ import ( "github.com/canonical/k8s/pkg/k8sd/database" "github.com/canonical/k8s/pkg/k8sd/features" "github.com/canonical/k8s/pkg/k8sd/types" + testenv "github.com/canonical/k8s/pkg/utils/microcluster" + "github.com/canonical/microcluster/v2/state" . "github.com/onsi/gomega" ) func TestFeatureStatus(t *testing.T) { - WithDB(t, func(ctx context.Context, db DB) { - _ = db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + testenv.WithState(t, func(ctx context.Context, s state.State) { + _ = s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { t0, _ := time.Parse(time.RFC3339, time.Now().Format(time.RFC3339)) networkStatus := types.FeatureStatus{ Enabled: true, diff --git a/src/k8s/pkg/k8sd/database/kubernetes_auth_tokens_test.go b/src/k8s/pkg/k8sd/database/kubernetes_auth_tokens_test.go index dd7d8ce9e..0878dffc3 100644 --- a/src/k8s/pkg/k8sd/database/kubernetes_auth_tokens_test.go +++ b/src/k8s/pkg/k8sd/database/kubernetes_auth_tokens_test.go @@ -6,16 +6,18 @@ import ( "testing" "github.com/canonical/k8s/pkg/k8sd/database" + testenv "github.com/canonical/k8s/pkg/utils/microcluster" + "github.com/canonical/microcluster/v2/state" . "github.com/onsi/gomega" ) func TestKubernetesAuthTokens(t *testing.T) { - WithDB(t, func(ctx context.Context, db DB) { + testenv.WithState(t, func(ctx context.Context, s state.State) { var token1, token2 string t.Run("GetOrCreateToken", func(t *testing.T) { g := NewWithT(t) - err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { var err error token1, err = database.GetOrCreateToken(ctx, tx, "user1", []string{"group1", "group2"}) @@ -33,7 +35,7 @@ func TestKubernetesAuthTokens(t *testing.T) { t.Run("Existing", func(t *testing.T) { g := NewWithT(t) - err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { token, err := database.GetOrCreateToken(ctx, tx, "user1", []string{"group1", "group2"}) g.Expect(err).To(Not(HaveOccurred())) g.Expect(token).To(Equal(token1)) @@ -46,7 +48,7 @@ func TestKubernetesAuthTokens(t *testing.T) { t.Run("CheckToken", func(t *testing.T) { t.Run("user1", func(t *testing.T) { g := NewWithT(t) - err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { username, groups, err := database.CheckToken(ctx, tx, token1) g.Expect(err).To(Not(HaveOccurred())) g.Expect(username).To(Equal("user1")) @@ -57,7 +59,7 @@ func TestKubernetesAuthTokens(t *testing.T) { }) t.Run("user2", func(t *testing.T) { g := NewWithT(t) - err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { username, groups, err := database.CheckToken(ctx, tx, token2) g.Expect(err).To(Not(HaveOccurred())) g.Expect(username).To(Equal("user2")) @@ -70,7 +72,7 @@ func TestKubernetesAuthTokens(t *testing.T) { t.Run("DeleteToken", func(t *testing.T) { g := NewWithT(t) - err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + err := s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { err := database.DeleteToken(ctx, tx, token2) g.Expect(err).To(Not(HaveOccurred())) diff --git a/src/k8s/pkg/k8sd/database/worker_test.go b/src/k8s/pkg/k8sd/database/worker_test.go index ce154fcdb..0c9e16e53 100644 --- a/src/k8s/pkg/k8sd/database/worker_test.go +++ b/src/k8s/pkg/k8sd/database/worker_test.go @@ -7,12 +7,14 @@ import ( "time" "github.com/canonical/k8s/pkg/k8sd/database" + testenv "github.com/canonical/k8s/pkg/utils/microcluster" + "github.com/canonical/microcluster/v2/state" . "github.com/onsi/gomega" ) func TestWorkerNodeToken(t *testing.T) { - WithDB(t, func(ctx context.Context, db DB) { - _ = db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { + testenv.WithState(t, func(ctx context.Context, s state.State) { + _ = s.Database().Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { tokenExpiry := time.Now().Add(time.Hour) t.Run("Default", func(t *testing.T) { g := NewWithT(t) diff --git a/src/k8s/pkg/k8sd/database/util_test.go b/src/k8s/pkg/utils/microcluster/state.go similarity index 77% rename from src/k8s/pkg/k8sd/database/util_test.go rename to src/k8s/pkg/utils/microcluster/state.go index 93f7efe27..1b2055b70 100644 --- a/src/k8s/pkg/k8sd/database/util_test.go +++ b/src/k8s/pkg/utils/microcluster/state.go @@ -1,8 +1,7 @@ -package database_test +package testenv import ( "context" - "database/sql" "fmt" "testing" "time" @@ -21,21 +20,17 @@ const ( // nextIdx is used to pick different listen ports for each microcluster instance. var nextIdx int -// DB is an interface for the internal microcluster DB type. -type DB interface { - Transaction(ctx context.Context, f func(context.Context, *sql.Tx) error) error -} - -// WithDB can be used to run isolated tests against the microcluster database. +// WithState can be used to run isolated tests against the microcluster database. +// The Database() can be accessed by calling s.Database(). // // Example usage: // // func TestKubernetesAuthTokens(t *testing.T) { // t.Run("ValidToken", func(t *testing.T) { // g := NewWithT(t) -// WithDB(t, func(ctx context.Context, db DB) { +// WithState(t, func(ctx context.Context, s state.State) { // err := db.Transaction(ctx, func(ctx context.Context, tx *sql.Tx) error { -// token, err := database.GetOrCreateToken(ctx, tx, "user1", []string{"group1", "group2"}) +// token, err := s.Database().GetOrCreateToken(ctx, tx, "user1", []string{"group1", "group2"}) // if !g.Expect(err).To(Not(HaveOccurred())) { // return err // } @@ -46,7 +41,7 @@ type DB interface { // }) // }) // } -func WithDB(t *testing.T, f func(context.Context, DB)) { +func WithState(t *testing.T, f func(context.Context, state.State)) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -57,16 +52,16 @@ func WithDB(t *testing.T, f func(context.Context, DB)) { t.Fatalf("failed to create microcluster app: %v", err) } - databaseCh := make(chan DB, 1) + stateChan := make(chan state.State, 1) doneCh := make(chan error, 1) - defer close(databaseCh) + defer close(stateChan) defer close(doneCh) - // app.Run() is blocking, so we get the database handle through a channel + // app.Run() is blocking, so we get the state handle through a channel go func() { doneCh <- app.Run(ctx, &state.Hooks{ PostBootstrap: func(ctx context.Context, s state.State, initConfig map[string]string) error { - databaseCh <- s.Database() + stateChan <- s return nil }, OnStart: func(ctx context.Context, s state.State) error { @@ -95,8 +90,8 @@ func WithDB(t *testing.T, f func(context.Context, DB)) { select { case <-time.After(microclusterDatabaseInitTimeout): t.Fatalf("timed out waiting for microcluster to start") - case db := <-databaseCh: - f(ctx, db) + case state := <-stateChan: + f(ctx, state) } // cancel context to stop the microcluster instance, and wait for it to shutdown