Skip to content

Commit

Permalink
Merge pull request #99 from timescale/vperez/simplify-connection-stri…
Browse files Browse the repository at this point in the history
…ng-handling

Simplify connection string configuration
  • Loading branch information
MetalBlueberry authored Dec 10, 2024
2 parents d8a7672 + 3a46374 commit 0357e2e
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 405 deletions.
10 changes: 7 additions & 3 deletions cmd/timescaledb-parallel-copy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ var (

// Parse args
func init() {
// Documented https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
flag.StringVar(&postgresConnect, "connection", "host=localhost user=postgres sslmode=disable", "PostgreSQL connection url")
flag.StringVar(&dbName, "db-name", "", "Database where the destination table exists")
flag.StringVar(&dbName, "db-name", "", "(deprecated) Database where the destination table exists")
flag.StringVar(&tableName, "table", "test_table", "Destination table for insertions")
flag.StringVar(&schemaName, "schema", "public", "Destination table's schema")
flag.BoolVar(&truncate, "truncate", false, "Truncate the destination table before insert")
Expand Down Expand Up @@ -86,13 +87,16 @@ func (l csvCopierLogger) Infof(msg string, args ...interface{}) {

func main() {
if showVersion {
fmt.Printf("%s %s (%s %s)\n", binName, version, runtime.GOOS, runtime.GOARCH)
log.Printf("%s %s (%s %s)\n", binName, version, runtime.GOOS, runtime.GOARCH)
os.Exit(0)
}

if dbName != "" {
log.Fatalf("Error: Deprecated flag -db-name is being used. Update -connection to connect to the given database")
}

copier, err := csvcopy.NewCopier(
postgresConnect,
dbName,
schemaName,
tableName,
copyOptions,
Expand Down
140 changes: 2 additions & 138 deletions internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,151 +4,15 @@ import (
"context"
"fmt"
"io"
"os"
"regexp"
"strconv"
"strings"

"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/stdlib"
"github.com/jmoiron/sqlx"
)

// minimalConnConfig is the minimal settings we need for connection. More
// unusual options are currently not supported.
type minimalConnConfig struct {
host string
user string
password string
db string
port uint16
sslmode string
}

// DSN returns the PostgreSQL compatible DSN string that corresponds to mcc.
// This is expressed as a string of <key>=<value> separated by spaces.
func (mcc *minimalConnConfig) DSN() string {
var s strings.Builder
writeNonempty := func(key, val string) {
if val != "" {
_, err := s.WriteString(key + "=" + val + " ")
if err != nil {
panic(err)
}
}
}
writeNonempty("host", mcc.host)
writeNonempty("user", mcc.user)
writeNonempty("password", mcc.password)
writeNonempty("dbname", mcc.db)
if mcc.port != 0 {
writeNonempty("port", strconv.FormatUint(uint64(mcc.port), 10))
}
writeNonempty("sslmode", mcc.sslmode)
writeNonempty("application_name", "timescaledb-parallel-copy")

return strings.TrimSpace(s.String())
}

// Overrideable is an interface for defining ways to override PG settings
// outside of the usual manners (through the connection string/URL or env vars).
// An example would be having specific flags that can be used to set database
// connect parameters.
type Overrideable interface {
Override() string
}

// OverrideDBName is a type for overriding the database name used to connect.
// To use it, one casts a string of the database name as an OverrideDBName
type OverrideDBName string

func (o OverrideDBName) Override() string {
return string(o)
}

// parseConnStr uses an external lib (that backs pgx) to take care of parsing
// connection parameters for connecting to PostgreSQL. It handles the connStr
// being in DSN or URL form, as well as reading env vars for additional settings.
func parseConnStr(connStr string, overrides ...Overrideable) (*minimalConnConfig, error) {
config, err := pgconn.ParseConfig(connStr)
if err != nil {
return nil, err
}
sslmode, err := determineTLS(connStr)
if err != nil {
return nil, err
}

mcc := &minimalConnConfig{
host: config.Host,
user: config.User,
password: config.Password,
db: config.Database,
port: config.Port,
sslmode: sslmode,
}

for _, o := range overrides {
switch o.(type) {
case OverrideDBName:
mcc.db = o.Override()
default:
return nil, fmt.Errorf("unknown overrideable: %T=%s", o, o.Override())
}
}

return mcc, nil
}

// ErrInvalidSSLMode is the error when the provided SSL mode is not one of the
// values that PostgreSQL supports.
type ErrInvalidSSLMode struct {
given string
}

func (e *ErrInvalidSSLMode) Error() string {
return "invalid SSL mode: " + e.given
}

const (
// envSSLMode is the environment variable key for SSL mode.
envSSLMode = "PGSSLMODE"
)

var sslmodeRegex = regexp.MustCompile("sslmode=([a-zA-Z-]+)")

// determineTLS attempts to match SSL mode to a known PostgreSQL supported value.
func determineTLS(connStr string) (string, error) {
res := sslmodeRegex.FindStringSubmatch(connStr)
var sslmode string
if len(res) == 2 {
sslmode = res[1]
} else {
sslmode = os.Getenv(envSSLMode)
}

if sslmode == "" {
return "", nil
}

switch sslmode {
case "require", "disable", "allow", "prefer", "verify-ca", "verify-full":
return sslmode, nil
default:
return "", &ErrInvalidSSLMode{given: sslmode}
}
}

// Connect returns a SQLX database corresponding to the provided connection
// string/URL, env variables, and any provided overrides.
func Connect(connStr string, overrides ...Overrideable) (*sqlx.DB, error) {
mcc, err := parseConnStr(connStr, overrides...)
if err != nil {
return nil, fmt.Errorf("could not connect: %v", err)
}
// It is required to connect using pgx/v5
// otherwise it may use pgx/v4 if that librariy is registered first
db, err := sqlx.Connect("pgx/v5", mcc.DSN())
func Connect(connStr string) (*sqlx.DB, error) {
db, err := sqlx.Connect("pgx/v5", connStr)
if err != nil {
return nil, fmt.Errorf("could not connect: %v", err)
}
Expand Down
Loading

0 comments on commit 0357e2e

Please sign in to comment.