Skip to content

Commit

Permalink
SDP-761 - Add Multi-tenant Connection Pool, sql exec and router
Browse files Browse the repository at this point in the history
  • Loading branch information
marwen-abid committed Nov 6, 2023
1 parent 0cc8bba commit 7eb0cc5
Show file tree
Hide file tree
Showing 13 changed files with 1,179 additions and 25 deletions.
18 changes: 12 additions & 6 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ type DBConnectionPool interface {
BeginTxx(ctx context.Context, opts *sql.TxOptions) (DBTransaction, error)
Close() error
Ping(ctx context.Context) error
SqlDB(ctx context.Context) *sql.DB
SqlxDB(ctx context.Context) *sqlx.DB
SqlDB(ctx context.Context) (*sql.DB, error)
SqlxDB(ctx context.Context) (*sqlx.DB, error)
DSN(ctx context.Context) (string, error)
}

Expand All @@ -49,12 +49,18 @@ func (db *DBConnectionPoolImplementation) Ping(ctx context.Context) error {
return db.DB.PingContext(ctx)
}

func (db *DBConnectionPoolImplementation) SqlDB(ctx context.Context) *sql.DB {
return db.DB.DB
func (db *DBConnectionPoolImplementation) SqlDB(ctx context.Context) (*sql.DB, error) {
if db.DB.DB == nil {
return nil, fmt.Errorf("sql.DB is not initialized")
}
return db.DB.DB, nil
}

func (db *DBConnectionPoolImplementation) SqlxDB(ctx context.Context) *sqlx.DB {
return db.DB
func (db *DBConnectionPoolImplementation) SqlxDB(ctx context.Context) (*sqlx.DB, error) {
if db.DB == nil {
return nil, fmt.Errorf("sqlx.DB is not initialized")
}
return db.DB, nil
}

func (db *DBConnectionPoolImplementation) DSN(ctx context.Context) (string, error) {
Expand Down
4 changes: 2 additions & 2 deletions db/db_connection_pool_with_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ func (dbc *DBConnectionPoolWithMetrics) Ping(ctx context.Context) error {
return dbc.dbConnectionPool.Ping(ctx)
}

func (dbc *DBConnectionPoolWithMetrics) SqlDB(ctx context.Context) *sql.DB {
func (dbc *DBConnectionPoolWithMetrics) SqlDB(ctx context.Context) (*sql.DB, error) {
return dbc.dbConnectionPool.SqlDB(ctx)
}

func (dbc *DBConnectionPoolWithMetrics) SqlxDB(ctx context.Context) *sqlx.DB {
func (dbc *DBConnectionPoolWithMetrics) SqlxDB(ctx context.Context) (*sqlx.DB, error) {
return dbc.dbConnectionPool.SqlxDB(ctx)
}

Expand Down
6 changes: 4 additions & 2 deletions db/db_connection_pool_with_metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ func TestDBConnectionPoolWithMetrics_SqlxDB(t *testing.T) {
require.NoError(t, err)

ctx := context.Background()
sqlxDB := dbConnectionPoolWithMetrics.SqlxDB(ctx)
sqlxDB, err := dbConnectionPoolWithMetrics.SqlxDB(ctx)
require.NoError(t, err)

assert.IsType(t, &sqlx.DB{}, sqlxDB)
}
Expand All @@ -44,7 +45,8 @@ func TestDBConnectionPoolWithMetrics_SqlDB(t *testing.T) {
require.NoError(t, err)

ctx := context.Background()
sqlDB := dbConnectionPoolWithMetrics.SqlDB(ctx)
sqlDB, err := dbConnectionPoolWithMetrics.SqlDB(ctx)
require.NoError(t, err)

assert.IsType(t, &sql.DB{}, sqlDB)
}
Expand Down
85 changes: 85 additions & 0 deletions db/db_connection_pool_with_router.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package db

import (
"context"
"database/sql"
"fmt"

"github.com/jmoiron/sqlx"
)

// ConnectionPoolWithRouter implements the DBConnectionPool interface
type ConnectionPoolWithRouter struct {
SQLExecutorWithRouter
}

// NewConnectionPoolWithRouter creates a new ConnectionPoolWithRouter
func NewConnectionPoolWithRouter(dataSourceRouter DataSourceRouter) (*ConnectionPoolWithRouter, error) {
sqlExecutor, err := NewSQLExecutorWithRouter(dataSourceRouter)
if err != nil {
return nil, fmt.Errorf("creating new sqlExecutor for connection pool with router: %w", err)
}
return &ConnectionPoolWithRouter{
SQLExecutorWithRouter: *sqlExecutor,
}, nil
}

func (m ConnectionPoolWithRouter) BeginTxx(ctx context.Context, opts *sql.TxOptions) (DBTransaction, error) {
dbcpl, err := m.dataSourceRouter.GetDataSource(ctx)
if err != nil {
return nil, fmt.Errorf("getting data source from context in BeginTxx: %w", err)
}
return dbcpl.BeginTxx(ctx, opts)
}

func (m ConnectionPoolWithRouter) Close() error {
dbcpls, err := m.dataSourceRouter.GetAllDataSources()
if err != nil {
return fmt.Errorf("getting all data sources in Close: %w", err)
}
if len(dbcpls) == 0 {
return fmt.Errorf("no data sources found in Close")
}
for _, dbcpl := range dbcpls {
err = dbcpl.Close()
if err != nil {
return fmt.Errorf("closing data source in Close: %w", err)
}
}
return nil
}

func (m ConnectionPoolWithRouter) Ping(ctx context.Context) error {
dbcpl, err := m.dataSourceRouter.GetDataSource(ctx)
if err != nil {
return fmt.Errorf("getting data source from context in Ping: %w", err)
}
return dbcpl.Ping(ctx)
}

func (m ConnectionPoolWithRouter) SqlDB(ctx context.Context) (*sql.DB, error) {
dbcpl, err := m.dataSourceRouter.GetDataSource(ctx)
if err != nil {
return nil, fmt.Errorf("getting data source from context in SqlDB: %w", err)
}
return dbcpl.SqlDB(ctx)
}

func (m ConnectionPoolWithRouter) SqlxDB(ctx context.Context) (*sqlx.DB, error) {
dbcpl, err := m.dataSourceRouter.GetDataSource(ctx)
if err != nil {
return nil, fmt.Errorf("getting data source from context in SqlxDB: %w", err)
}
return dbcpl.SqlxDB(ctx)
}

func (m ConnectionPoolWithRouter) DSN(ctx context.Context) (string, error) {
dbcpl, err := m.dataSourceRouter.GetDataSource(ctx)
if err != nil {
return "", fmt.Errorf("getting data source from context in DSN: %w", err)
}
return dbcpl.DSN(ctx)
}

// make sure *ConnectionPoolWithRouter implements DBConnectionPool:
var _ DBConnectionPool = (*ConnectionPoolWithRouter)(nil)
Loading

0 comments on commit 7eb0cc5

Please sign in to comment.