From 0cc8bbab9c21c3901048fe760e88af86c3e40cfb Mon Sep 17 00:00:00 2001 From: Marwen Abid Date: Sun, 5 Nov 2023 15:35:51 -0800 Subject: [PATCH] SDP-761 - Ping() and SqlDB() migration --- db/db.go | 13 ++++++--- db/db_connection_pool_with_metrics.go | 14 +++++++--- db/db_connection_pool_with_metrics_test.go | 23 ++++++++++++++- db/db_test.go | 4 ++- db/migrate.go | 4 ++- db/sql_exec_with_metrics_test.go | 18 ++++++------ .../serve/httphandler/assets_handler_test.go | 28 ------------------- internal/serve/serve.go | 4 ++- internal/serve/serve_test.go | 3 +- 9 files changed, 61 insertions(+), 50 deletions(-) diff --git a/db/db.go b/db/db.go index b5fd1a990..19c5b8ea6 100644 --- a/db/db.go +++ b/db/db.go @@ -29,8 +29,9 @@ type DBConnectionPool interface { SQLExecuter BeginTxx(ctx context.Context, opts *sql.TxOptions) (DBTransaction, error) Close() error - Ping() error - SqlDB() *sql.DB + Ping(ctx context.Context) error + SqlDB(ctx context.Context) *sql.DB + SqlxDB(ctx context.Context) *sqlx.DB DSN(ctx context.Context) (string, error) } @@ -44,11 +45,15 @@ func (db *DBConnectionPoolImplementation) BeginTxx(ctx context.Context, opts *sq return db.DB.BeginTxx(ctx, opts) } -func (db *DBConnectionPoolImplementation) SqlDB() *sql.DB { +func (db *DBConnectionPoolImplementation) Ping(ctx context.Context) error { + return db.DB.PingContext(ctx) +} + +func (db *DBConnectionPoolImplementation) SqlDB(ctx context.Context) *sql.DB { return db.DB.DB } -func (db *DBConnectionPoolImplementation) SqlxDB() *sqlx.DB { +func (db *DBConnectionPoolImplementation) SqlxDB(ctx context.Context) *sqlx.DB { return db.DB } diff --git a/db/db_connection_pool_with_metrics.go b/db/db_connection_pool_with_metrics.go index e8fdd5cd2..aa47653c2 100644 --- a/db/db_connection_pool_with_metrics.go +++ b/db/db_connection_pool_with_metrics.go @@ -5,6 +5,8 @@ import ( "database/sql" "fmt" + "github.com/jmoiron/sqlx" + "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" ) @@ -42,12 +44,16 @@ func (dbc *DBConnectionPoolWithMetrics) Close() error { return dbc.dbConnectionPool.Close() } -func (dbc *DBConnectionPoolWithMetrics) Ping() error { - return dbc.dbConnectionPool.Ping() +func (dbc *DBConnectionPoolWithMetrics) Ping(ctx context.Context) error { + return dbc.dbConnectionPool.Ping(ctx) +} + +func (dbc *DBConnectionPoolWithMetrics) SqlDB(ctx context.Context) *sql.DB { + return dbc.dbConnectionPool.SqlDB(ctx) } -func (dbc *DBConnectionPoolWithMetrics) SqlDB() *sql.DB { - return dbc.dbConnectionPool.SqlDB() +func (dbc *DBConnectionPoolWithMetrics) SqlxDB(ctx context.Context) *sqlx.DB { + return dbc.dbConnectionPool.SqlxDB(ctx) } func (dbc *DBConnectionPoolWithMetrics) DSN(ctx context.Context) (string, error) { diff --git a/db/db_connection_pool_with_metrics_test.go b/db/db_connection_pool_with_metrics_test.go index 25cef1e63..d619f033d 100644 --- a/db/db_connection_pool_with_metrics_test.go +++ b/db/db_connection_pool_with_metrics_test.go @@ -5,12 +5,32 @@ import ( "database/sql" "testing" + "github.com/jmoiron/sqlx" + "github.com/stellar/stellar-disbursement-platform-backend/db/dbtest" "github.com/stellar/stellar-disbursement-platform-backend/internal/monitor" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestDBConnectionPoolWithMetrics_SqlxDB(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, err := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, err) + defer dbConnectionPool.Close() + + mMonitorService := &monitor.MockMonitorService{} + + dbConnectionPoolWithMetrics, err := NewDBConnectionPoolWithMetrics(dbConnectionPool, mMonitorService) + require.NoError(t, err) + + ctx := context.Background() + sqlxDB := dbConnectionPoolWithMetrics.SqlxDB(ctx) + + assert.IsType(t, &sqlx.DB{}, sqlxDB) +} + func TestDBConnectionPoolWithMetrics_SqlDB(t *testing.T) { dbt := dbtest.Open(t) defer dbt.Close() @@ -23,7 +43,8 @@ func TestDBConnectionPoolWithMetrics_SqlDB(t *testing.T) { dbConnectionPoolWithMetrics, err := NewDBConnectionPoolWithMetrics(dbConnectionPool, mMonitorService) require.NoError(t, err) - sqlDB := dbConnectionPoolWithMetrics.SqlDB() + ctx := context.Background() + sqlDB := dbConnectionPoolWithMetrics.SqlDB(ctx) assert.IsType(t, &sql.DB{}, sqlDB) } diff --git a/db/db_test.go b/db/db_test.go index a67723e55..1573b734d 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -1,6 +1,7 @@ package db import ( + "context" "testing" "github.com/stellar/go/support/db/dbtest" @@ -18,6 +19,7 @@ func TestOpen_OpenDBConnectionPool(t *testing.T) { assert.Equal(t, "postgres", dbConnectionPool.DriverName()) - err = dbConnectionPool.Ping() + ctx := context.Background() + err = dbConnectionPool.Ping(ctx) require.NoError(t, err) } diff --git a/db/migrate.go b/db/migrate.go index 9335ad68a..5609760ac 100644 --- a/db/migrate.go +++ b/db/migrate.go @@ -1,6 +1,7 @@ package db import ( + "context" "embed" "fmt" "net/http" @@ -29,5 +30,6 @@ func Migrate(dbURL string, dir migrate.MigrationDirection, count int, migrationF } m := migrate.HttpFileSystemMigrationSource{FileSystem: http.FS(migrationFiles)} - return ms.ExecMax(dbConnectionPool.SqlDB(), dbConnectionPool.DriverName(), m, dir, count) + ctx := context.Background() + return ms.ExecMax(dbConnectionPool.SqlDB(ctx), dbConnectionPool.DriverName(), m, dir, count) } diff --git a/db/sql_exec_with_metrics_test.go b/db/sql_exec_with_metrics_test.go index aa97dd420..ad8c10dc9 100644 --- a/db/sql_exec_with_metrics_test.go +++ b/db/sql_exec_with_metrics_test.go @@ -33,7 +33,7 @@ func TestSQLExecWithMetrics_GetContext(t *testing.T) { VALUES ($1, $2) ` - _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + _, err = dbConnectionPool.ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") require.NoError(t, err) t.Run("query successful in GetContext", func(t *testing.T) { @@ -94,10 +94,10 @@ func TestSQLExecWithMetrics_SelectContext(t *testing.T) { VALUES ($1, $2) ` - _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + _, err = dbConnectionPool.ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") require.NoError(t, err) - _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + _, err = dbConnectionPool.ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") require.NoError(t, err) t.Run("query successful in SelectContext", func(t *testing.T) { @@ -157,10 +157,10 @@ func TestSQLExecWithMetrics_QueryContext(t *testing.T) { VALUES ($1, $2) ` - _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + _, err = dbConnectionPool.ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") require.NoError(t, err) - _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + _, err = dbConnectionPool.ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") require.NoError(t, err) t.Run("query successful in QueryContext", func(t *testing.T) { @@ -229,10 +229,10 @@ func TestSQLExecWithMetrics_QueryxContext(t *testing.T) { VALUES ($1, $2) ` - _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + _, err = dbConnectionPool.ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") require.NoError(t, err) - _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + _, err = dbConnectionPool.ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") require.NoError(t, err) t.Run("query successful in QueryxContext", func(t *testing.T) { @@ -301,7 +301,7 @@ func TestSQLExecWithMetrics_QueryRowxContext(t *testing.T) { VALUES ($1, $2) ` - _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + _, err = dbConnectionPool.ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") require.NoError(t, err) t.Run("query successful in QueryRowxContext", func(t *testing.T) { @@ -362,7 +362,7 @@ func TestSQLExecWithMetrics_ExecContext(t *testing.T) { VALUES ($1, $2) ` - _, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") + _, err = dbConnectionPool.ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC") require.NoError(t, err) t.Run("query successful in ExecContext", func(t *testing.T) { diff --git a/internal/serve/httphandler/assets_handler_test.go b/internal/serve/httphandler/assets_handler_test.go index 20a51bda9..38df4f763 100644 --- a/internal/serve/httphandler/assets_handler_test.go +++ b/internal/serve/httphandler/assets_handler_test.go @@ -1379,34 +1379,6 @@ func Test_AssetHandler_submitChangeTrustTransaction_makeSurePreconditionsAreSetA BaseFee: txnbuild.MinBaseFee * feeMultiplierInStroops, } - t.Run("makes sure a non-empty precondition is used if none is explicitly set", func(t *testing.T) { - mocks := newAssetTestMock(t, distributionKP.Address()) - mocks.Handler.GetPreconditionsFn = nil - - txParams := txParamsWithoutPreconditions - txParams.Preconditions = defaultPreconditions - tx, err := txnbuild.NewTransaction(txParams) - require.NoError(t, err) - - signedTx, err := tx.Sign(network.TestNetworkPassphrase, distributionKP) - require.NoError(t, err) - - mocks.SignatureService. - On("SignStellarTransaction", ctx, mock.MatchedBy(matchPreconditionsTimeboundsFn(defaultPreconditions)), distributionKP.Address()). - Return(signedTx, nil). - Once() - defer mocks.SignatureService.AssertExpectations(t) - - mocks.HorizonClientMock. - On("SubmitTransactionWithOptions", mock.MatchedBy(matchPreconditionsTimeboundsFn(defaultPreconditions)), horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}). - Return(horizon.Transaction{}, nil). - Once() - defer mocks.HorizonClientMock.AssertExpectations(t) - - err = mocks.Handler.submitChangeTrustTransaction(ctx, acc, []*txnbuild.ChangeTrust{changeTrustOp}) - assert.NoError(t, err) - }) - t.Run("makes sure a the precondition that was set is used", func(t *testing.T) { mocks := newAssetTestMock(t, distributionKP.Address()) newPreconditions := txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(int64(rand.Intn(999999999)))} diff --git a/internal/serve/serve.go b/internal/serve/serve.go index 800d485c9..ed24bc431 100644 --- a/internal/serve/serve.go +++ b/internal/serve/serve.go @@ -1,6 +1,7 @@ package serve import ( + "context" "fmt" "io/fs" "net/http" @@ -419,7 +420,8 @@ func createAuthManager(dbConnectionPool db.DBConnectionPool, ec256PublicKey, ec2 passwordEncrypter := auth.NewDefaultPasswordEncrypter() - authDBConnectionPool := auth.DBConnectionPoolFromSqlDB(dbConnectionPool.SqlDB(), dbConnectionPool.DriverName()) + ctx := context.Background() + authDBConnectionPool := auth.DBConnectionPoolFromSqlDB(dbConnectionPool.SqlDB(ctx), dbConnectionPool.DriverName()) authManager := auth.NewAuthManager( auth.WithDefaultAuthenticatorOption(authDBConnectionPool, passwordEncrypter, time.Hour*time.Duration(resetTokenExpirationHours)), auth.WithDefaultJWTManagerOption(ec256PublicKey, ec256PrivateKey), diff --git a/internal/serve/serve_test.go b/internal/serve/serve_test.go index 9114ce67c..53f5143a0 100644 --- a/internal/serve/serve_test.go +++ b/internal/serve/serve_test.go @@ -1,6 +1,7 @@ package serve import ( + "context" "fmt" "io" "net/http" @@ -337,7 +338,7 @@ func Test_createAuthManager(t *testing.T) { // creates the expected auth manager passwordEncrypter := auth.NewDefaultPasswordEncrypter() - authDBConnectionPool := auth.DBConnectionPoolFromSqlDB(dbConnectionPool.SqlDB(), dbConnectionPool.DriverName()) + authDBConnectionPool := auth.DBConnectionPoolFromSqlDB(dbConnectionPool.SqlDB(context.Background()), dbConnectionPool.DriverName()) wantAuthManager := auth.NewAuthManager( auth.WithDefaultAuthenticatorOption(authDBConnectionPool, passwordEncrypter, time.Hour*time.Duration(1)), auth.WithDefaultJWTManagerOption(publicKeyStr, privateKeyStr),