diff --git a/db/db.go b/db/db.go index 19c5b8ea6..4ceadee20 100644 --- a/db/db.go +++ b/db/db.go @@ -30,8 +30,8 @@ type DBConnectionPool interface { BeginTxx(ctx context.Context, opts *sql.TxOptions) (DBTransaction, error) Close() error Ping(ctx context.Context) error - SqlDB(ctx context.Context) *sql.DB - SqlxDB(ctx context.Context) *sqlx.DB + SqlDB(ctx context.Context) (*sql.DB, error) + SqlxDB(ctx context.Context) (*sqlx.DB, error) DSN(ctx context.Context) (string, error) } @@ -49,12 +49,18 @@ func (db *DBConnectionPoolImplementation) Ping(ctx context.Context) error { return db.DB.PingContext(ctx) } -func (db *DBConnectionPoolImplementation) SqlDB(ctx context.Context) *sql.DB { - return db.DB.DB +func (db *DBConnectionPoolImplementation) SqlDB(ctx context.Context) (*sql.DB, error) { + if db.DB.DB == nil { + return nil, fmt.Errorf("sql.DB is not initialized") + } + return db.DB.DB, nil } -func (db *DBConnectionPoolImplementation) SqlxDB(ctx context.Context) *sqlx.DB { - return db.DB +func (db *DBConnectionPoolImplementation) SqlxDB(ctx context.Context) (*sqlx.DB, error) { + if db.DB == nil { + return nil, fmt.Errorf("sqlx.DB is not initialized") + } + return db.DB, nil } func (db *DBConnectionPoolImplementation) DSN(ctx context.Context) (string, error) { diff --git a/db/db_connection_pool_with_metrics.go b/db/db_connection_pool_with_metrics.go index aa47653c2..eec952a7f 100644 --- a/db/db_connection_pool_with_metrics.go +++ b/db/db_connection_pool_with_metrics.go @@ -48,11 +48,11 @@ func (dbc *DBConnectionPoolWithMetrics) Ping(ctx context.Context) error { return dbc.dbConnectionPool.Ping(ctx) } -func (dbc *DBConnectionPoolWithMetrics) SqlDB(ctx context.Context) *sql.DB { +func (dbc *DBConnectionPoolWithMetrics) SqlDB(ctx context.Context) (*sql.DB, error) { return dbc.dbConnectionPool.SqlDB(ctx) } -func (dbc *DBConnectionPoolWithMetrics) SqlxDB(ctx context.Context) *sqlx.DB { +func (dbc *DBConnectionPoolWithMetrics) SqlxDB(ctx context.Context) (*sqlx.DB, error) { return dbc.dbConnectionPool.SqlxDB(ctx) } diff --git a/db/db_connection_pool_with_metrics_test.go b/db/db_connection_pool_with_metrics_test.go index d619f033d..9ed78b324 100644 --- a/db/db_connection_pool_with_metrics_test.go +++ b/db/db_connection_pool_with_metrics_test.go @@ -26,7 +26,8 @@ func TestDBConnectionPoolWithMetrics_SqlxDB(t *testing.T) { require.NoError(t, err) ctx := context.Background() - sqlxDB := dbConnectionPoolWithMetrics.SqlxDB(ctx) + sqlxDB, err := dbConnectionPoolWithMetrics.SqlxDB(ctx) + require.NoError(t, err) assert.IsType(t, &sqlx.DB{}, sqlxDB) } @@ -44,7 +45,8 @@ func TestDBConnectionPoolWithMetrics_SqlDB(t *testing.T) { require.NoError(t, err) ctx := context.Background() - sqlDB := dbConnectionPoolWithMetrics.SqlDB(ctx) + sqlDB, err := dbConnectionPoolWithMetrics.SqlDB(ctx) + require.NoError(t, err) assert.IsType(t, &sql.DB{}, sqlDB) } diff --git a/db/db_connection_pool_with_router.go b/db/db_connection_pool_with_router.go new file mode 100644 index 000000000..27af8a868 --- /dev/null +++ b/db/db_connection_pool_with_router.go @@ -0,0 +1,85 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + + "github.com/jmoiron/sqlx" +) + +// ConnectionPoolWithRouter implements the DBConnectionPool interface +type ConnectionPoolWithRouter struct { + SQLExecutorWithRouter +} + +// NewConnectionPoolWithRouter creates a new ConnectionPoolWithRouter +func NewConnectionPoolWithRouter(dataSourceRouter DataSourceRouter) (*ConnectionPoolWithRouter, error) { + sqlExecutor, err := NewSQLExecutorWithRouter(dataSourceRouter) + if err != nil { + return nil, fmt.Errorf("creating new sqlExecutor for connection pool with router: %w", err) + } + return &ConnectionPoolWithRouter{ + SQLExecutorWithRouter: *sqlExecutor, + }, nil +} + +func (m ConnectionPoolWithRouter) BeginTxx(ctx context.Context, opts *sql.TxOptions) (DBTransaction, error) { + dbcpl, err := m.dataSourceRouter.GetDataSource(ctx) + if err != nil { + return nil, fmt.Errorf("getting data source from context in BeginTxx: %w", err) + } + return dbcpl.BeginTxx(ctx, opts) +} + +func (m ConnectionPoolWithRouter) Close() error { + dbcpls, err := m.dataSourceRouter.GetAllDataSources() + if err != nil { + return fmt.Errorf("getting all data sources in Close: %w", err) + } + if len(dbcpls) == 0 { + return fmt.Errorf("no data sources found in Close") + } + for _, dbcpl := range dbcpls { + err = dbcpl.Close() + if err != nil { + return fmt.Errorf("closing data source in Close: %w", err) + } + } + return nil +} + +func (m ConnectionPoolWithRouter) Ping(ctx context.Context) error { + dbcpl, err := m.dataSourceRouter.GetDataSource(ctx) + if err != nil { + return fmt.Errorf("getting data source from context in Ping: %w", err) + } + return dbcpl.Ping(ctx) +} + +func (m ConnectionPoolWithRouter) SqlDB(ctx context.Context) (*sql.DB, error) { + dbcpl, err := m.dataSourceRouter.GetDataSource(ctx) + if err != nil { + return nil, fmt.Errorf("getting data source from context in SqlDB: %w", err) + } + return dbcpl.SqlDB(ctx) +} + +func (m ConnectionPoolWithRouter) SqlxDB(ctx context.Context) (*sqlx.DB, error) { + dbcpl, err := m.dataSourceRouter.GetDataSource(ctx) + if err != nil { + return nil, fmt.Errorf("getting data source from context in SqlxDB: %w", err) + } + return dbcpl.SqlxDB(ctx) +} + +func (m ConnectionPoolWithRouter) DSN(ctx context.Context) (string, error) { + dbcpl, err := m.dataSourceRouter.GetDataSource(ctx) + if err != nil { + return "", fmt.Errorf("getting data source from context in DSN: %w", err) + } + return dbcpl.DSN(ctx) +} + +// make sure *ConnectionPoolWithRouter implements DBConnectionPool: +var _ DBConnectionPool = (*ConnectionPoolWithRouter)(nil) diff --git a/db/db_connection_pool_with_router_test.go b/db/db_connection_pool_with_router_test.go new file mode 100644 index 000000000..183a895f6 --- /dev/null +++ b/db/db_connection_pool_with_router_test.go @@ -0,0 +1,260 @@ +package db + +import ( + "context" + "database/sql" + "testing" + + "github.com/jmoiron/sqlx" + "github.com/stellar/stellar-disbursement-platform-backend/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConnectionPoolWithRouter_BeginTxx(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + connectionPoolWithRouter, outerErr := NewConnectionPoolWithRouter(mockRouter) + require.NoError(t, outerErr) + ctx := context.Background() + + t.Run("BeginTxx successful", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(dbConnectionPool, nil). + Once() + + dbTx, err := connectionPoolWithRouter.BeginTxx(ctx, nil) + + // Defer a rollback in case anything fails. + defer func() { + err = dbTx.Rollback() + require.Error(t, err, "not in transaction") + }() + require.NoError(t, err) + + assert.IsType(t, &sqlx.Tx{}, dbTx) + + err = dbTx.Commit() + require.NoError(t, err) + + mockRouter.AssertExpectations(t) + }) + + t.Run("error getting data source", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(nil, assert.AnError). + Once() + + dbTx, err := connectionPoolWithRouter.BeginTxx(ctx, nil) + require.Error(t, err) + assert.Nil(t, dbTx) + + mockRouter.AssertExpectations(t) + }) +} + +func TestConnectionPoolWithRouter_Close(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + connectionPoolWithRouter, outerErr := NewConnectionPoolWithRouter(mockRouter) + require.NoError(t, outerErr) + + t.Run("Close successful", func(t *testing.T) { + mockRouter. + On("GetAllDataSources"). + Return([]DBConnectionPool{dbConnectionPool}, nil). + Once() + + err := connectionPoolWithRouter.Close() + require.NoError(t, err) + + mockRouter.AssertExpectations(t) + }) + + t.Run("error getting all data sources", func(t *testing.T) { + mockRouter. + On("GetAllDataSources"). + Return(nil, assert.AnError). + Once() + + err := connectionPoolWithRouter.Close() + require.Error(t, err) + + mockRouter.AssertExpectations(t) + }) +} + +func TestConnectionPoolWithRouter_Ping(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + connectionPoolWithRouter, outerErr := NewConnectionPoolWithRouter(mockRouter) + require.NoError(t, outerErr) + ctx := context.Background() + + t.Run("Ping successful", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(dbConnectionPool, nil). + Once() + + err := connectionPoolWithRouter.Ping(ctx) + require.NoError(t, err) + + mockRouter.AssertExpectations(t) + }) + + t.Run("error getting data source", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(nil, assert.AnError). + Once() + + err := connectionPoolWithRouter.Ping(ctx) + require.Error(t, err) + + mockRouter.AssertExpectations(t) + }) +} + +func TestConnectionPoolWithRouter_SqlDB(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + connectionPoolWithRouter, outerErr := NewConnectionPoolWithRouter(mockRouter) + require.NoError(t, outerErr) + ctx := context.Background() + + t.Run("SqlDB successful", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(dbConnectionPool, nil). + Once() + + db, err := connectionPoolWithRouter.SqlDB(ctx) + require.NoError(t, err) + + assert.IsType(t, &sql.DB{}, db) + + mockRouter.AssertExpectations(t) + }) + + t.Run("error getting data source", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(nil, assert.AnError). + Once() + + db, err := connectionPoolWithRouter.SqlDB(ctx) + require.Error(t, err) + assert.Nil(t, db) + + mockRouter.AssertExpectations(t) + }) +} + +func TestConnectionPoolWithRouter_SqlxDB(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + connectionPoolWithRouter, outerErr := NewConnectionPoolWithRouter(mockRouter) + require.NoError(t, outerErr) + ctx := context.Background() + + t.Run("SqlDB successful", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(dbConnectionPool, nil). + Once() + + db, err := connectionPoolWithRouter.SqlxDB(ctx) + require.NoError(t, err) + + assert.IsType(t, &sqlx.DB{}, db) + + mockRouter.AssertExpectations(t) + }) + + t.Run("error getting data source", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(nil, assert.AnError). + Once() + + db, err := connectionPoolWithRouter.SqlxDB(ctx) + require.Error(t, err) + assert.Nil(t, db) + + mockRouter.AssertExpectations(t) + }) +} + +func TestConnectionPoolWithRouter_DSN(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + connectionPoolWithRouter, outerErr := NewConnectionPoolWithRouter(mockRouter) + require.NoError(t, outerErr) + ctx := context.Background() + + t.Run("DSN successful", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(dbConnectionPool, nil). + Once() + + dsn, err := connectionPoolWithRouter.DSN(ctx) + require.NoError(t, err) + + assert.Equal(t, dbt.DSN, dsn) + + mockRouter.AssertExpectations(t) + }) + + t.Run("error getting data source", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(nil, assert.AnError). + Once() + + dsn, err := connectionPoolWithRouter.DSN(ctx) + require.Error(t, err) + assert.Equal(t, "", dsn) + + mockRouter.AssertExpectations(t) + }) +} diff --git a/db/migrate.go b/db/migrate.go index 5609760ac..8c410892d 100644 --- a/db/migrate.go +++ b/db/migrate.go @@ -31,5 +31,9 @@ func Migrate(dbURL string, dir migrate.MigrationDirection, count int, migrationF m := migrate.HttpFileSystemMigrationSource{FileSystem: http.FS(migrationFiles)} ctx := context.Background() - return ms.ExecMax(dbConnectionPool.SqlDB(ctx), dbConnectionPool.DriverName(), m, dir, count) + db, err := dbConnectionPool.SqlDB(ctx) + if err != nil { + return 0, fmt.Errorf("fetching sql.DB: %w", err) + } + return ms.ExecMax(db, dbConnectionPool.DriverName(), m, dir, count) } diff --git a/db/sql_exec_with_router.go b/db/sql_exec_with_router.go new file mode 100644 index 000000000..6a00e6d63 --- /dev/null +++ b/db/sql_exec_with_router.go @@ -0,0 +1,109 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + + "github.com/jmoiron/sqlx" +) + +type DataSourceRouter interface { + GetDataSource(ctx context.Context) (DBConnectionPool, error) + GetAllDataSources() ([]DBConnectionPool, error) + AnyDataSource() (DBConnectionPool, error) +} + +type SQLExecutorWithRouter struct { + dataSourceRouter DataSourceRouter +} + +func NewSQLExecutorWithRouter(router DataSourceRouter) (*SQLExecutorWithRouter, error) { + if router == nil { + return nil, fmt.Errorf("router is nil in NewSQLExecutorWithRouter") + } + return &SQLExecutorWithRouter{ + dataSourceRouter: router, + }, nil +} + +func (s SQLExecutorWithRouter) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + dbcpl, err := s.dataSourceRouter.GetDataSource(ctx) + if err != nil { + return fmt.Errorf("getting data source from context in GetContext: %w", err) + } + return dbcpl.GetContext(ctx, dest, query, args...) +} + +func (s SQLExecutorWithRouter) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + dbcpl, err := s.dataSourceRouter.GetDataSource(ctx) + if err != nil { + return fmt.Errorf("getting data source from context in SelectContext: %w", err) + } + + return dbcpl.SelectContext(ctx, dest, query, args...) +} + +func (s SQLExecutorWithRouter) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + dbcpl, err := s.dataSourceRouter.GetDataSource(ctx) + if err != nil { + return nil, fmt.Errorf("getting data source from context in ExecContext: %w", err) + } + + return dbcpl.ExecContext(ctx, query, args...) +} + +func (s SQLExecutorWithRouter) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + dbcpl, err := s.dataSourceRouter.GetDataSource(ctx) + if err != nil { + return nil, fmt.Errorf("getting data source from context in QueryContext: %w", err) + } + + return dbcpl.QueryContext(ctx, query, args...) +} + +func (s SQLExecutorWithRouter) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { + dbcpl, err := s.dataSourceRouter.GetDataSource(ctx) + if err != nil { + return nil, fmt.Errorf("getting data source from context in QueryxContext: %w", err) + } + + return dbcpl.QueryxContext(ctx, query, args...) +} + +func (s SQLExecutorWithRouter) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { + dbcpl, err := s.dataSourceRouter.GetDataSource(ctx) + if err != nil { + return nil, fmt.Errorf("getting data source from context in PrepareContext: %w", err) + } + + return dbcpl.PrepareContext(ctx, query) +} + +func (s SQLExecutorWithRouter) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row { + dbcpl, err := s.dataSourceRouter.GetDataSource(ctx) + if err != nil { + return nil + } + + return dbcpl.QueryRowxContext(ctx, query, args...) +} + +func (s SQLExecutorWithRouter) Rebind(query string) string { + dbcp, err := s.dataSourceRouter.AnyDataSource() + if err != nil { + return sqlx.Rebind(sqlx.DOLLAR, query) + } + return dbcp.Rebind(query) +} + +func (m SQLExecutorWithRouter) DriverName() string { + dbcp, err := m.dataSourceRouter.AnyDataSource() + if err != nil { + return "" + } + return dbcp.DriverName() +} + +// make sure *SQLExecutorWithRouter implements SQLExecuter: +var _ SQLExecuter = (*SQLExecutorWithRouter)(nil) diff --git a/db/sql_exec_with_router_test.go b/db/sql_exec_with_router_test.go new file mode 100644 index 000000000..072b8ed76 --- /dev/null +++ b/db/sql_exec_with_router_test.go @@ -0,0 +1,430 @@ +package db + +import ( + "context" + "fmt" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/db/dbtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type MockDataSourceRouter struct { + mock.Mock +} + +func (m *MockDataSourceRouter) GetAllDataSources() ([]DBConnectionPool, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]DBConnectionPool), args.Error(1) +} + +func (m *MockDataSourceRouter) AnyDataSource() (DBConnectionPool, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(DBConnectionPool), args.Error(1) +} + +func (m *MockDataSourceRouter) GetDataSource(ctx context.Context) (DBConnectionPool, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(DBConnectionPool), args.Error(1) +} + +func TestSQLExecutorWithRouter_GetContext(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + sqlExecWithRouter, outerErr := NewSQLExecutorWithRouter(mockRouter) + require.NoError(t, outerErr) + + ctx := context.Background() + query := "SELECT o.name FROM organizations o" + var dest string + + t.Run("query successful in GetContext", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(dbConnectionPool, nil). + Once() + + err := sqlExecWithRouter.GetContext(ctx, &dest, query) + require.NoError(t, err) + require.Equal(t, "MyCustomAid", dest) + + mockRouter.AssertExpectations(t) + }) + + t.Run("error getting data source in GetContext", func(t *testing.T) { + mockRouter.On("GetDataSource", ctx). + Return(nil, fmt.Errorf("data source error")). + Once() + + err := sqlExecWithRouter.GetContext(ctx, &dest, query) + require.Error(t, err) + assert.Contains(t, err.Error(), "getting data source from context in GetContext") + + mockRouter.AssertExpectations(t) + }) +} + +func TestSQLExecutorWithRouter_SelectContext(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + sqlExecWithRouter, outerErr := NewSQLExecutorWithRouter(mockRouter) + require.NoError(t, outerErr) + + ctx := context.Background() + query := "SELECT o.name FROM organizations o" + var dest []string + t.Run("query successful in SelectContext", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(dbConnectionPool, nil). + Once() + + err := sqlExecWithRouter.SelectContext(ctx, &dest, query) + require.NoError(t, err) + require.Equal(t, 1, len(dest)) + require.Equal(t, "MyCustomAid", dest[0]) + + mockRouter.AssertExpectations(t) + }) + + t.Run("error getting data source in SelectContext", func(t *testing.T) { + mockRouter.On("GetDataSource", ctx). + Return(nil, fmt.Errorf("data source error")). + Once() + + err := sqlExecWithRouter.SelectContext(ctx, &dest, query) + require.Error(t, err) + assert.Contains(t, err.Error(), "getting data source from context in SelectContext") + + mockRouter.AssertExpectations(t) + }) +} + +func TestSQLExecutorWithRouter_ExecContext(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + sqlExecWithRouter, outerErr := NewSQLExecutorWithRouter(mockRouter) + require.NoError(t, outerErr) + + ctx := context.Background() + query := "INSERT INTO assets (code, issuer) VALUES ('BTC', 'GCNSGHUCG5VMGLT5RIYYZSO7VQULQKAJ62QA33DBC5PPBSO57LFWVV6P')" + t.Run("query successful in ExecContext", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(dbConnectionPool, nil). + Once() + + result, err := sqlExecWithRouter.ExecContext(ctx, query) + require.NoError(t, err) + + rowsAffected, err := result.RowsAffected() + require.NoError(t, err) + + assert.Equal(t, rowsAffected, int64(1)) + require.NoError(t, err) + + mockRouter.AssertExpectations(t) + }) + + t.Run("error getting data source in ExecContext", func(t *testing.T) { + mockRouter.On("GetDataSource", ctx). + Return(nil, fmt.Errorf("data source error")). + Once() + + _, err := sqlExecWithRouter.ExecContext(ctx, query) + require.Error(t, err) + assert.Contains(t, err.Error(), "getting data source from context in ExecContext") + + mockRouter.AssertExpectations(t) + }) +} + +func TestSQLExecutorWithRouter_QueryContext(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + sqlExecWithRouter, outerErr := NewSQLExecutorWithRouter(mockRouter) + require.NoError(t, outerErr) + + ctx := context.Background() + query := "SELECT o.name FROM organizations o" + t.Run("query successful in QueryContext", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(dbConnectionPool, nil). + Once() + + rows, err := sqlExecWithRouter.QueryContext(ctx, query) + require.NoError(t, err) + + var dest []string + for rows.Next() { + var name string + err = rows.Scan(&name) + require.NoError(t, err) + dest = append(dest, name) + } + + require.Equal(t, 1, len(dest)) + require.Equal(t, "MyCustomAid", dest[0]) + + mockRouter.AssertExpectations(t) + }) + + t.Run("error getting data source in QueryContext", func(t *testing.T) { + mockRouter.On("GetDataSource", ctx). + Return(nil, fmt.Errorf("data source error")). + Once() + + _, err := sqlExecWithRouter.QueryContext(ctx, query) + require.Error(t, err) + assert.Contains(t, err.Error(), "getting data source from context in QueryContext") + + mockRouter.AssertExpectations(t) + }) +} + +func TestSQLExecutorWithRouter_QueryxContext(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + sqlExecWithRouter, outerErr := NewSQLExecutorWithRouter(mockRouter) + require.NoError(t, outerErr) + + ctx := context.Background() + query := "SELECT o.name FROM organizations o" + t.Run("query successful in QueryxContext", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(dbConnectionPool, nil). + Once() + + rows, err := sqlExecWithRouter.QueryxContext(ctx, query) + require.NoError(t, err) + + var dest []string + for rows.Next() { + var name string + err = rows.Scan(&name) + require.NoError(t, err) + dest = append(dest, name) + } + + require.Equal(t, 1, len(dest)) + require.Equal(t, "MyCustomAid", dest[0]) + + mockRouter.AssertExpectations(t) + }) + + t.Run("error getting data source in QueryxContext", func(t *testing.T) { + mockRouter.On("GetDataSource", ctx). + Return(nil, fmt.Errorf("data source error")). + Once() + + _, err := sqlExecWithRouter.QueryxContext(ctx, query) + require.Error(t, err) + assert.Contains(t, err.Error(), "getting data source from context in QueryxContext") + + mockRouter.AssertExpectations(t) + }) +} + +func TestSQLExecutorWithRouter_PrepareContext(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + sqlExecWithRouter, outerErr := NewSQLExecutorWithRouter(mockRouter) + require.NoError(t, outerErr) + + ctx := context.Background() + query := "SELECT o.name FROM organizations o" + t.Run("query successful in PrepareContext", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(dbConnectionPool, nil). + Once() + + stmt, err := sqlExecWithRouter.PrepareContext(ctx, query) + require.NoError(t, err) + + rows, err := stmt.Query() + require.NoError(t, err) + + var dest []string + for rows.Next() { + var name string + err = rows.Scan(&name) + require.NoError(t, err) + dest = append(dest, name) + } + + require.Equal(t, 1, len(dest)) + require.Equal(t, "MyCustomAid", dest[0]) + + mockRouter.AssertExpectations(t) + }) + + t.Run("error getting data source in PrepareContext", func(t *testing.T) { + mockRouter.On("GetDataSource", ctx). + Return(nil, fmt.Errorf("data source error")). + Once() + + _, err := sqlExecWithRouter.PrepareContext(ctx, query) + require.Error(t, err) + assert.Contains(t, err.Error(), "getting data source from context in PrepareContext") + + mockRouter.AssertExpectations(t) + }) +} + +func TestSQLExecutorWithRouter_QueryRowxContext(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + sqlExecWithRouter, outerErr := NewSQLExecutorWithRouter(mockRouter) + require.NoError(t, outerErr) + + ctx := context.Background() + query := "SELECT o.name FROM organizations o" + t.Run("query successful in QueryRowxContext", func(t *testing.T) { + mockRouter. + On("GetDataSource", ctx). + Return(dbConnectionPool, nil). + Once() + + row := sqlExecWithRouter.QueryRowxContext(ctx, query) + + var dest string + err := row.Scan(&dest) + require.NoError(t, err) + + require.Equal(t, "MyCustomAid", dest) + + mockRouter.AssertExpectations(t) + }) + + t.Run("error getting data source in QueryRowxContext", func(t *testing.T) { + mockRouter.On("GetDataSource", ctx). + Return(nil, fmt.Errorf("data source error")). + Once() + + row := sqlExecWithRouter.QueryRowxContext(ctx, query) + require.Nil(t, row) + + mockRouter.AssertExpectations(t) + }) +} + +func TestSQLExecutorWithRouter_Rebind(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + sqlExecWithRouter, outerErr := NewSQLExecutorWithRouter(mockRouter) + require.NoError(t, outerErr) + + query := "SELECT * FROM organizations o WHERE o.name = ?" + expected := "SELECT * FROM organizations o WHERE o.name = $1" + t.Run("query successful in Rebind", func(t *testing.T) { + mockRouter. + On("AnyDataSource"). + Return(dbConnectionPool, nil). + Once() + reboundQuery := sqlExecWithRouter.Rebind(query) + require.Equal(t, expected, reboundQuery) + }) + + t.Run("query successful in Rebind when there is no connectionPool", func(t *testing.T) { + mockRouter. + On("AnyDataSource"). + Return(nil, fmt.Errorf("data source error")). + Once() + reboundQuery := sqlExecWithRouter.Rebind(query) + require.Equal(t, expected, reboundQuery) + }) +} + +func TestSQLExecutorWithRouter_DriverName(t *testing.T) { + dbt := dbtest.Open(t) + defer dbt.Close() + dbConnectionPool, outerErr := OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + mockRouter := new(MockDataSourceRouter) + + sqlExecWithRouter, outerErr := NewSQLExecutorWithRouter(mockRouter) + require.NoError(t, outerErr) + + expected := "postgres" + t.Run("query successful in DriverName", func(t *testing.T) { + mockRouter. + On("AnyDataSource"). + Return(dbConnectionPool, nil). + Once() + driverName := sqlExecWithRouter.DriverName() + require.Equal(t, expected, driverName) + }) + + t.Run("empty when there is no connection pool", func(t *testing.T) { + mockRouter. + On("AnyDataSource"). + Return(nil, fmt.Errorf("data source error")). + Once() + driverName := sqlExecWithRouter.DriverName() + require.Empty(t, driverName) + }) +} diff --git a/internal/serve/serve.go b/internal/serve/serve.go index ed24bc431..fbbd8769c 100644 --- a/internal/serve/serve.go +++ b/internal/serve/serve.go @@ -421,7 +421,11 @@ func createAuthManager(dbConnectionPool db.DBConnectionPool, ec256PublicKey, ec2 passwordEncrypter := auth.NewDefaultPasswordEncrypter() ctx := context.Background() - authDBConnectionPool := auth.DBConnectionPoolFromSqlDB(dbConnectionPool.SqlDB(ctx), dbConnectionPool.DriverName()) + dbcp, err := dbConnectionPool.SqlDB(ctx) + if err != nil { + return nil, fmt.Errorf("getting sql db from db connection pool: %w", err) + } + authDBConnectionPool := auth.DBConnectionPoolFromSqlDB(dbcp, dbConnectionPool.DriverName()) authManager := auth.NewAuthManager( auth.WithDefaultAuthenticatorOption(authDBConnectionPool, passwordEncrypter, time.Hour*time.Duration(resetTokenExpirationHours)), auth.WithDefaultJWTManagerOption(ec256PublicKey, ec256PrivateKey), diff --git a/internal/serve/serve_test.go b/internal/serve/serve_test.go index 53f5143a0..53fae77e1 100644 --- a/internal/serve/serve_test.go +++ b/internal/serve/serve_test.go @@ -338,7 +338,9 @@ func Test_createAuthManager(t *testing.T) { // creates the expected auth manager passwordEncrypter := auth.NewDefaultPasswordEncrypter() - authDBConnectionPool := auth.DBConnectionPoolFromSqlDB(dbConnectionPool.SqlDB(context.Background()), dbConnectionPool.DriverName()) + dbcp, err := dbConnectionPool.SqlDB(context.Background()) + require.NoError(t, err) + authDBConnectionPool := auth.DBConnectionPoolFromSqlDB(dbcp, dbConnectionPool.DriverName()) wantAuthManager := auth.NewAuthManager( auth.WithDefaultAuthenticatorOption(authDBConnectionPool, passwordEncrypter, time.Hour*time.Duration(1)), auth.WithDefaultJWTManagerOption(publicKeyStr, privateKeyStr), diff --git a/stellar-multitenant/pkg/router/multitenant_data_source_router.go b/stellar-multitenant/pkg/router/multitenant_data_source_router.go new file mode 100644 index 000000000..5daf5da2c --- /dev/null +++ b/stellar-multitenant/pkg/router/multitenant_data_source_router.go @@ -0,0 +1,104 @@ +package router + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/stellar/stellar-disbursement-platform-backend/db" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-multitenant/pkg/tenant" +) + +var ( + ErrTenantNotFoundInContext = errors.New("tenant not found in context") + ErrNoDataSourcesAvailable = errors.New("no data sources are available") +) + +type tenantContextKey struct{} + +type MultiTenantDataSourceRouter struct { + dataSources sync.Map + tenantManager *tenant.Manager +} + +func NewMultiTenantDataSourceRouter(tenantManager *tenant.Manager) *MultiTenantDataSourceRouter { + return &MultiTenantDataSourceRouter{ + tenantManager: tenantManager, + } +} + +func (m *MultiTenantDataSourceRouter) GetDataSource(ctx context.Context) (db.DBConnectionPool, error) { + tenant, ok := GetTenantFromContext(ctx) + if !ok { + return nil, ErrTenantNotFoundInContext + } + + return m.GetDataSourceForTenant(ctx, *tenant) +} + +// GetDataSourceForTenant returns the database connection pool for the given tenant if it exists, otherwise create a new one. +func (m *MultiTenantDataSourceRouter) GetDataSourceForTenant(ctx context.Context, tenant tenant.Tenant) (db.DBConnectionPool, error) { + value, exists := m.dataSources.Load(tenant.ID) + if exists { + return value.(db.DBConnectionPool), nil + } + + u, err := m.tenantManager.GetDSNForTenant(ctx, tenant.Name) + if err != nil || u == "" { + return nil, fmt.Errorf("getting database DSN for tenant %s: %w", tenant.ID, err) + } + + dbcp, err := db.OpenDBConnectionPool(u) + if err != nil { + return nil, fmt.Errorf("opening database connection pool for tenant %s: %w", tenant.ID, err) + } + + // Store the new pool, but if another goroutine already stored a pool for this tenant, + // then use the existing one and close the newly created one. + actualValue, loaded := m.dataSources.LoadOrStore(tenant.ID, dbcp) + if loaded { + dbcp.Close() // Close the newly created pool if we're not using it + return actualValue.(db.DBConnectionPool), nil + } + + return dbcp, nil +} + +// GetAllDataSources returns all the database connection pools. +func (m *MultiTenantDataSourceRouter) GetAllDataSources() ([]db.DBConnectionPool, error) { + var pools []db.DBConnectionPool + m.dataSources.Range(func(_, value interface{}) bool { + pools = append(pools, value.(db.DBConnectionPool)) + return true + }) + return pools, nil +} + +func (m *MultiTenantDataSourceRouter) AnyDataSource() (db.DBConnectionPool, error) { + var anyDBCP db.DBConnectionPool + var found bool + m.dataSources.Range(func(_, value interface{}) bool { + anyDBCP = value.(db.DBConnectionPool) + found = true + return false + }) + if !found { + return nil, ErrNoDataSourcesAvailable + } + return anyDBCP, nil +} + +// SetTenantInContext stores the tenant information in the context. +func SetTenantInContext(ctx context.Context, tenant *tenant.Tenant) context.Context { + return context.WithValue(ctx, tenantContextKey{}, tenant) +} + +// GetTenantFromContext retrieves the tenant information from the context. +func GetTenantFromContext(ctx context.Context) (*tenant.Tenant, bool) { + tenant, ok := ctx.Value(tenantContextKey{}).(*tenant.Tenant) + return tenant, ok +} + +// make sure *MultiTenantDataSourceRouter implements DataSourceRouter: +var _ db.DataSourceRouter = (*MultiTenantDataSourceRouter)(nil) diff --git a/stellar-multitenant/pkg/router/multitenant_data_source_router_test.go b/stellar-multitenant/pkg/router/multitenant_data_source_router_test.go new file mode 100644 index 000000000..1b7b1852a --- /dev/null +++ b/stellar-multitenant/pkg/router/multitenant_data_source_router_test.go @@ -0,0 +1,127 @@ +package router + +import ( + "context" + "testing" + + "github.com/stellar/stellar-disbursement-platform-backend/db" + "github.com/stellar/stellar-disbursement-platform-backend/db/dbtest" + "github.com/stellar/stellar-disbursement-platform-backend/stellar-multitenant/pkg/tenant" + "github.com/stretchr/testify/require" +) + +func TestMultiTenantDataSourceRouter_GetDataSource(t *testing.T) { + dbt := dbtest.OpenWithTenantMigrationsOnly(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + tenantManager := tenant.NewManager(tenant.WithDatabase(dbConnectionPool)) + + router := NewMultiTenantDataSourceRouter(tenantManager) + + ctx := context.Background() + + t.Run("error tenant not found in context", func(t *testing.T) { + dbcp, err := router.GetDataSource(ctx) + require.Nil(t, dbcp) + require.EqualError(t, err, ErrTenantNotFoundInContext.Error()) + }) + + t.Run("successfully getting data source", func(t *testing.T) { + // Create a new context with tenant information + tenantInfo := &tenant.Tenant{ID: "95e788b6-c80e-4975-9d12-141001fe6e44", Name: "aid-org-1"} + ctx = SetTenantInContext(context.Background(), tenantInfo) + + dbcp, err := router.GetDataSource(ctx) + require.NotNil(t, dbcp) + require.NoError(t, err) + defer dbcp.Close() + + dsn, err := dbcp.DSN(ctx) + require.NoError(t, err) + require.Contains(t, dsn, tenantInfo.Name) + }) +} + +func TestMultiTenantDataSourceRouter_GetAllDataSources(t *testing.T) { + dbt := dbtest.OpenWithTenantMigrationsOnly(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + tenantManager := tenant.NewManager(tenant.WithDatabase(dbConnectionPool)) + + router := NewMultiTenantDataSourceRouter(tenantManager) + + t.Run("empty data sources", func(t *testing.T) { + dbcps, err := router.GetAllDataSources() + require.NoError(t, err) + require.Nil(t, dbcps) + require.Empty(t, dbcps) + }) + + t.Run("successfully getting data sources", func(t *testing.T) { + // Store DB Connection Pool for aid-org-1 + tenantInfo := &tenant.Tenant{ID: "95e788b6-c80e-4975-9d12-141001fe6e44", Name: "aid-org-1"} + ctx := SetTenantInContext(context.Background(), tenantInfo) + dbcp1, err := router.GetDataSource(ctx) + require.NoError(t, err) + require.NotNil(t, dbcp1) + defer dbcp1.Close() + + // Store DB Connection Pool for aid-org-2 + tenantInfo = &tenant.Tenant{ID: "95e788b6-c80e-4975-9d12-141001fe6e45", Name: "aid-org-2"} + ctx = SetTenantInContext(context.Background(), tenantInfo) + dbcp2, err := router.GetDataSource(ctx) + require.NoError(t, err) + require.NotNil(t, dbcp2) + defer dbcp2.Close() + + dbcps, err := router.GetAllDataSources() + require.NotNil(t, dbcps) + require.NoError(t, err) + + require.Equal(t, 2, len(dbcps)) + require.Contains(t, dbcps, dbcp1) + require.Contains(t, dbcps, dbcp2) + }) +} + +func TestMultiTenantDataSourceRouter_AnyDataSource(t *testing.T) { + dbt := dbtest.OpenWithTenantMigrationsOnly(t) + defer dbt.Close() + + dbConnectionPool, outerErr := db.OpenDBConnectionPool(dbt.DSN) + require.NoError(t, outerErr) + defer dbConnectionPool.Close() + + tenantManager := tenant.NewManager(tenant.WithDatabase(dbConnectionPool)) + + router := NewMultiTenantDataSourceRouter(tenantManager) + + t.Run("no data sources available", func(t *testing.T) { + dbcp, err := router.AnyDataSource() + require.Nil(t, dbcp) + require.EqualError(t, err, ErrNoDataSourcesAvailable.Error()) + }) + + t.Run("successfully getting data source", func(t *testing.T) { + // Store DB Connection Pool for aid-org-1 + tenantInfo := &tenant.Tenant{ID: "95e788b6-c80e-4975-9d12-141001fe6e44", Name: "aid-org-1"} + ctx := SetTenantInContext(context.Background(), tenantInfo) + dbcp1, err := router.GetDataSource(ctx) + require.NoError(t, err) + require.NotNil(t, dbcp1) + defer dbcp1.Close() + + dbcp, err := router.AnyDataSource() + require.NotNil(t, dbcp) + require.NoError(t, err) + require.Equal(t, dbcp1, dbcp) + }) +} diff --git a/stellar-multitenant/pkg/tenant/manager.go b/stellar-multitenant/pkg/tenant/manager.go index 9963dd359..936266c95 100644 --- a/stellar-multitenant/pkg/tenant/manager.go +++ b/stellar-multitenant/pkg/tenant/manager.go @@ -44,33 +44,26 @@ func (m *Manager) ProvisionNewTenant(ctx context.Context, name, userFirstName, u return nil, fmt.Errorf("creating a new database schema: %w", err) } - dataSourceName, err := m.db.DSN(ctx) - if err != nil { - return nil, fmt.Errorf("getting database DSN: %w", err) - } - u, err := url.Parse(dataSourceName) + u, err := m.GetDSNForTenant(ctx, t.Name) if err != nil { - return nil, fmt.Errorf("parsing database DSN: %w", err) + return nil, fmt.Errorf("getting database DSN for tenant %s: %w", t.Name, err) } - q := u.Query() - q.Set("search_path", schemaName) - u.RawQuery = q.Encode() // Applying migrations log.Infof("applying SDP migrations on the tenant %s schema", t.Name) - err = m.RunMigrationsForTenant(ctx, t, u.String(), migrate.Up, 0, sdpmigrations.FS, db.StellarSDPMigrationsTableName) + err = m.RunMigrationsForTenant(ctx, t, u, migrate.Up, 0, sdpmigrations.FS, db.StellarSDPMigrationsTableName) if err != nil { return nil, fmt.Errorf("applying SDP migrations: %w", err) } log.Infof("applying stellar-auth migrations on the tenant %s schema", t.Name) - err = m.RunMigrationsForTenant(ctx, t, u.String(), migrate.Up, 0, authmigrations.FS, db.StellarAuthMigrationsTableName) + err = m.RunMigrationsForTenant(ctx, t, u, migrate.Up, 0, authmigrations.FS, db.StellarAuthMigrationsTableName) if err != nil { return nil, fmt.Errorf("applying stellar-auth migrations: %w", err) } // Connecting to the tenant database schema - tenantSchemaConnectionPool, err := db.OpenDBConnectionPool(u.String()) + tenantSchemaConnectionPool, err := db.OpenDBConnectionPool(u) if err != nil { return nil, fmt.Errorf("opening database connection on tenant schema: %w", err) } @@ -110,6 +103,34 @@ func (m *Manager) ProvisionNewTenant(ctx context.Context, name, userFirstName, u return t, nil } +func (m *Manager) GetDSNForTenant(ctx context.Context, tenantName string) (string, error) { + dataSourceName, err := m.db.DSN(ctx) + if err != nil { + return "", fmt.Errorf("getting database DSN: %w", err) + } + u, err := url.Parse(dataSourceName) + if err != nil { + return "", fmt.Errorf("parsing database DSN: %w", err) + } + q := u.Query() + schemaName := fmt.Sprintf("sdp_%s", tenantName) + q.Set("search_path", schemaName) + u.RawQuery = q.Encode() + return u.String(), nil +} + +func (m *Manager) GetTenantByName(ctx context.Context, name string) (*Tenant, error) { + const q = "SELECT * FROM tenants WHERE name = $1" + var t Tenant + if err := m.db.GetContext(ctx, &t, q, name); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrTenantDoesNotExist + } + return nil, fmt.Errorf("getting tenant %s: %w", name, err) + } + return &t, nil +} + func (m *Manager) AddTenant(ctx context.Context, name string) (*Tenant, error) { if name == "" { return nil, ErrEmptyTenantName