Skip to content

Commit

Permalink
SDP-761 - Ping() and SqlDB() migration
Browse files Browse the repository at this point in the history
  • Loading branch information
marwen-abid committed Nov 5, 2023
1 parent 22e1be8 commit 0cc8bba
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 50 deletions.
13 changes: 9 additions & 4 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ type DBConnectionPool interface {
SQLExecuter
BeginTxx(ctx context.Context, opts *sql.TxOptions) (DBTransaction, error)
Close() error
Ping() error
SqlDB() *sql.DB
Ping(ctx context.Context) error
SqlDB(ctx context.Context) *sql.DB
SqlxDB(ctx context.Context) *sqlx.DB
DSN(ctx context.Context) (string, error)
}

Expand All @@ -44,11 +45,15 @@ func (db *DBConnectionPoolImplementation) BeginTxx(ctx context.Context, opts *sq
return db.DB.BeginTxx(ctx, opts)
}

func (db *DBConnectionPoolImplementation) SqlDB() *sql.DB {
func (db *DBConnectionPoolImplementation) Ping(ctx context.Context) error {
return db.DB.PingContext(ctx)
}

func (db *DBConnectionPoolImplementation) SqlDB(ctx context.Context) *sql.DB {
return db.DB.DB
}

func (db *DBConnectionPoolImplementation) SqlxDB() *sqlx.DB {
func (db *DBConnectionPoolImplementation) SqlxDB(ctx context.Context) *sqlx.DB {
return db.DB
}

Expand Down
14 changes: 10 additions & 4 deletions db/db_connection_pool_with_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"database/sql"
"fmt"

"github.com/jmoiron/sqlx"

"github.com/stellar/stellar-disbursement-platform-backend/internal/monitor"
)

Expand Down Expand Up @@ -42,12 +44,16 @@ func (dbc *DBConnectionPoolWithMetrics) Close() error {
return dbc.dbConnectionPool.Close()
}

func (dbc *DBConnectionPoolWithMetrics) Ping() error {
return dbc.dbConnectionPool.Ping()
func (dbc *DBConnectionPoolWithMetrics) Ping(ctx context.Context) error {
return dbc.dbConnectionPool.Ping(ctx)
}

func (dbc *DBConnectionPoolWithMetrics) SqlDB(ctx context.Context) *sql.DB {
return dbc.dbConnectionPool.SqlDB(ctx)
}

func (dbc *DBConnectionPoolWithMetrics) SqlDB() *sql.DB {
return dbc.dbConnectionPool.SqlDB()
func (dbc *DBConnectionPoolWithMetrics) SqlxDB(ctx context.Context) *sqlx.DB {
return dbc.dbConnectionPool.SqlxDB(ctx)
}

func (dbc *DBConnectionPoolWithMetrics) DSN(ctx context.Context) (string, error) {
Expand Down
23 changes: 22 additions & 1 deletion db/db_connection_pool_with_metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,32 @@ import (
"database/sql"
"testing"

"github.com/jmoiron/sqlx"

"github.com/stellar/stellar-disbursement-platform-backend/db/dbtest"
"github.com/stellar/stellar-disbursement-platform-backend/internal/monitor"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestDBConnectionPoolWithMetrics_SqlxDB(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()
dbConnectionPool, err := OpenDBConnectionPool(dbt.DSN)
require.NoError(t, err)
defer dbConnectionPool.Close()

mMonitorService := &monitor.MockMonitorService{}

dbConnectionPoolWithMetrics, err := NewDBConnectionPoolWithMetrics(dbConnectionPool, mMonitorService)
require.NoError(t, err)

ctx := context.Background()
sqlxDB := dbConnectionPoolWithMetrics.SqlxDB(ctx)

assert.IsType(t, &sqlx.DB{}, sqlxDB)
}

func TestDBConnectionPoolWithMetrics_SqlDB(t *testing.T) {
dbt := dbtest.Open(t)
defer dbt.Close()
Expand All @@ -23,7 +43,8 @@ func TestDBConnectionPoolWithMetrics_SqlDB(t *testing.T) {
dbConnectionPoolWithMetrics, err := NewDBConnectionPoolWithMetrics(dbConnectionPool, mMonitorService)
require.NoError(t, err)

sqlDB := dbConnectionPoolWithMetrics.SqlDB()
ctx := context.Background()
sqlDB := dbConnectionPoolWithMetrics.SqlDB(ctx)

assert.IsType(t, &sql.DB{}, sqlDB)
}
Expand Down
4 changes: 3 additions & 1 deletion db/db_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"context"
"testing"

"github.com/stellar/go/support/db/dbtest"
Expand All @@ -18,6 +19,7 @@ func TestOpen_OpenDBConnectionPool(t *testing.T) {

assert.Equal(t, "postgres", dbConnectionPool.DriverName())

err = dbConnectionPool.Ping()
ctx := context.Background()
err = dbConnectionPool.Ping(ctx)
require.NoError(t, err)
}
4 changes: 3 additions & 1 deletion db/migrate.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"context"
"embed"
"fmt"
"net/http"
Expand Down Expand Up @@ -29,5 +30,6 @@ func Migrate(dbURL string, dir migrate.MigrationDirection, count int, migrationF
}

m := migrate.HttpFileSystemMigrationSource{FileSystem: http.FS(migrationFiles)}
return ms.ExecMax(dbConnectionPool.SqlDB(), dbConnectionPool.DriverName(), m, dir, count)
ctx := context.Background()
return ms.ExecMax(dbConnectionPool.SqlDB(ctx), dbConnectionPool.DriverName(), m, dir, count)
}
18 changes: 9 additions & 9 deletions db/sql_exec_with_metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestSQLExecWithMetrics_GetContext(t *testing.T) {
VALUES
($1, $2)
`
_, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
_, err = dbConnectionPool.ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
require.NoError(t, err)

t.Run("query successful in GetContext", func(t *testing.T) {
Expand Down Expand Up @@ -94,10 +94,10 @@ func TestSQLExecWithMetrics_SelectContext(t *testing.T) {
VALUES
($1, $2)
`
_, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
_, err = dbConnectionPool.ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
require.NoError(t, err)

_, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
_, err = dbConnectionPool.ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
require.NoError(t, err)

t.Run("query successful in SelectContext", func(t *testing.T) {
Expand Down Expand Up @@ -157,10 +157,10 @@ func TestSQLExecWithMetrics_QueryContext(t *testing.T) {
VALUES
($1, $2)
`
_, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
_, err = dbConnectionPool.ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
require.NoError(t, err)

_, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
_, err = dbConnectionPool.ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
require.NoError(t, err)

t.Run("query successful in QueryContext", func(t *testing.T) {
Expand Down Expand Up @@ -229,10 +229,10 @@ func TestSQLExecWithMetrics_QueryxContext(t *testing.T) {
VALUES
($1, $2)
`
_, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
_, err = dbConnectionPool.ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
require.NoError(t, err)

_, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
_, err = dbConnectionPool.ExecContext(ctx, query, "EURT", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
require.NoError(t, err)

t.Run("query successful in QueryxContext", func(t *testing.T) {
Expand Down Expand Up @@ -301,7 +301,7 @@ func TestSQLExecWithMetrics_QueryRowxContext(t *testing.T) {
VALUES
($1, $2)
`
_, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
_, err = dbConnectionPool.ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
require.NoError(t, err)

t.Run("query successful in QueryRowxContext", func(t *testing.T) {
Expand Down Expand Up @@ -362,7 +362,7 @@ func TestSQLExecWithMetrics_ExecContext(t *testing.T) {
VALUES
($1, $2)
`
_, err = dbConnectionPool.SqlDB().ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
_, err = dbConnectionPool.ExecContext(ctx, query, "USDC", "GA5ZSEJYB37JRC5AVCIA5MOP4RHTM335X2KGX3IHOJAPP5RE34K4KZCC")
require.NoError(t, err)

t.Run("query successful in ExecContext", func(t *testing.T) {
Expand Down
28 changes: 0 additions & 28 deletions internal/serve/httphandler/assets_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1379,34 +1379,6 @@ func Test_AssetHandler_submitChangeTrustTransaction_makeSurePreconditionsAreSetA
BaseFee: txnbuild.MinBaseFee * feeMultiplierInStroops,
}

t.Run("makes sure a non-empty precondition is used if none is explicitly set", func(t *testing.T) {
mocks := newAssetTestMock(t, distributionKP.Address())
mocks.Handler.GetPreconditionsFn = nil

txParams := txParamsWithoutPreconditions
txParams.Preconditions = defaultPreconditions
tx, err := txnbuild.NewTransaction(txParams)
require.NoError(t, err)

signedTx, err := tx.Sign(network.TestNetworkPassphrase, distributionKP)
require.NoError(t, err)

mocks.SignatureService.
On("SignStellarTransaction", ctx, mock.MatchedBy(matchPreconditionsTimeboundsFn(defaultPreconditions)), distributionKP.Address()).
Return(signedTx, nil).
Once()
defer mocks.SignatureService.AssertExpectations(t)

mocks.HorizonClientMock.
On("SubmitTransactionWithOptions", mock.MatchedBy(matchPreconditionsTimeboundsFn(defaultPreconditions)), horizonclient.SubmitTxOpts{SkipMemoRequiredCheck: true}).
Return(horizon.Transaction{}, nil).
Once()
defer mocks.HorizonClientMock.AssertExpectations(t)

err = mocks.Handler.submitChangeTrustTransaction(ctx, acc, []*txnbuild.ChangeTrust{changeTrustOp})
assert.NoError(t, err)
})

t.Run("makes sure a the precondition that was set is used", func(t *testing.T) {
mocks := newAssetTestMock(t, distributionKP.Address())
newPreconditions := txnbuild.Preconditions{TimeBounds: txnbuild.NewTimeout(int64(rand.Intn(999999999)))}
Expand Down
4 changes: 3 additions & 1 deletion internal/serve/serve.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package serve

import (
"context"
"fmt"
"io/fs"
"net/http"
Expand Down Expand Up @@ -419,7 +420,8 @@ func createAuthManager(dbConnectionPool db.DBConnectionPool, ec256PublicKey, ec2

passwordEncrypter := auth.NewDefaultPasswordEncrypter()

authDBConnectionPool := auth.DBConnectionPoolFromSqlDB(dbConnectionPool.SqlDB(), dbConnectionPool.DriverName())
ctx := context.Background()
authDBConnectionPool := auth.DBConnectionPoolFromSqlDB(dbConnectionPool.SqlDB(ctx), dbConnectionPool.DriverName())
authManager := auth.NewAuthManager(
auth.WithDefaultAuthenticatorOption(authDBConnectionPool, passwordEncrypter, time.Hour*time.Duration(resetTokenExpirationHours)),
auth.WithDefaultJWTManagerOption(ec256PublicKey, ec256PrivateKey),
Expand Down
3 changes: 2 additions & 1 deletion internal/serve/serve_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package serve

import (
"context"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -337,7 +338,7 @@ func Test_createAuthManager(t *testing.T) {

// creates the expected auth manager
passwordEncrypter := auth.NewDefaultPasswordEncrypter()
authDBConnectionPool := auth.DBConnectionPoolFromSqlDB(dbConnectionPool.SqlDB(), dbConnectionPool.DriverName())
authDBConnectionPool := auth.DBConnectionPoolFromSqlDB(dbConnectionPool.SqlDB(context.Background()), dbConnectionPool.DriverName())
wantAuthManager := auth.NewAuthManager(
auth.WithDefaultAuthenticatorOption(authDBConnectionPool, passwordEncrypter, time.Hour*time.Duration(1)),
auth.WithDefaultJWTManagerOption(publicKeyStr, privateKeyStr),
Expand Down

0 comments on commit 0cc8bba

Please sign in to comment.