Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify connection string configuration #99

Merged
merged 11 commits into from
Dec 10, 2024
10 changes: 7 additions & 3 deletions cmd/timescaledb-parallel-copy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ var (

// Parse args
func init() {
// Documented https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING
flag.StringVar(&postgresConnect, "connection", "host=localhost user=postgres sslmode=disable", "PostgreSQL connection url")
flag.StringVar(&dbName, "db-name", "", "Database where the destination table exists")
flag.StringVar(&dbName, "db-name", "", "(deprecated) Database where the destination table exists")
flag.StringVar(&tableName, "table", "test_table", "Destination table for insertions")
flag.StringVar(&schemaName, "schema", "public", "Destination table's schema")
flag.BoolVar(&truncate, "truncate", false, "Truncate the destination table before insert")
Expand Down Expand Up @@ -86,13 +87,16 @@ func (l csvCopierLogger) Infof(msg string, args ...interface{}) {

func main() {
if showVersion {
fmt.Printf("%s %s (%s %s)\n", binName, version, runtime.GOOS, runtime.GOARCH)
log.Printf("%s %s (%s %s)\n", binName, version, runtime.GOOS, runtime.GOARCH)
os.Exit(0)
}

if dbName != "" {
log.Fatalf("Error: Deprecated flag -db-name is being used. Update -connection to connect to the given database")
}

copier, err := csvcopy.NewCopier(
postgresConnect,
dbName,
schemaName,
tableName,
copyOptions,
Expand Down
140 changes: 2 additions & 138 deletions internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,151 +4,15 @@ import (
"context"
"fmt"
"io"
"os"
"regexp"
"strconv"
"strings"

"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/stdlib"
"github.com/jmoiron/sqlx"
)

// minimalConnConfig is the minimal settings we need for connection. More
// unusual options are currently not supported.
type minimalConnConfig struct {
host string
user string
password string
db string
port uint16
sslmode string
}

// DSN returns the PostgreSQL compatible DSN string that corresponds to mcc.
// This is expressed as a string of <key>=<value> separated by spaces.
func (mcc *minimalConnConfig) DSN() string {
var s strings.Builder
writeNonempty := func(key, val string) {
if val != "" {
_, err := s.WriteString(key + "=" + val + " ")
if err != nil {
panic(err)
}
}
}
writeNonempty("host", mcc.host)
writeNonempty("user", mcc.user)
writeNonempty("password", mcc.password)
writeNonempty("dbname", mcc.db)
if mcc.port != 0 {
writeNonempty("port", strconv.FormatUint(uint64(mcc.port), 10))
}
writeNonempty("sslmode", mcc.sslmode)
writeNonempty("application_name", "timescaledb-parallel-copy")

return strings.TrimSpace(s.String())
}

// Overrideable is an interface for defining ways to override PG settings
// outside of the usual manners (through the connection string/URL or env vars).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MetalBlueberry pgx supports both the key=value and URI right? I'm like 99% sure, just want to double check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// An example would be having specific flags that can be used to set database
// connect parameters.
type Overrideable interface {
Override() string
}

// OverrideDBName is a type for overriding the database name used to connect.
// To use it, one casts a string of the database name as an OverrideDBName
type OverrideDBName string

func (o OverrideDBName) Override() string {
return string(o)
}

// parseConnStr uses an external lib (that backs pgx) to take care of parsing
// connection parameters for connecting to PostgreSQL. It handles the connStr
// being in DSN or URL form, as well as reading env vars for additional settings.
func parseConnStr(connStr string, overrides ...Overrideable) (*minimalConnConfig, error) {
config, err := pgconn.ParseConfig(connStr)
if err != nil {
return nil, err
}
sslmode, err := determineTLS(connStr)
if err != nil {
return nil, err
}

mcc := &minimalConnConfig{
host: config.Host,
user: config.User,
password: config.Password,
db: config.Database,
port: config.Port,
sslmode: sslmode,
}

for _, o := range overrides {
switch o.(type) {
case OverrideDBName:
mcc.db = o.Override()
default:
return nil, fmt.Errorf("unknown overrideable: %T=%s", o, o.Override())
}
}

return mcc, nil
}

// ErrInvalidSSLMode is the error when the provided SSL mode is not one of the
// values that PostgreSQL supports.
type ErrInvalidSSLMode struct {
given string
}

func (e *ErrInvalidSSLMode) Error() string {
return "invalid SSL mode: " + e.given
}

const (
// envSSLMode is the environment variable key for SSL mode.
envSSLMode = "PGSSLMODE"
)

var sslmodeRegex = regexp.MustCompile("sslmode=([a-zA-Z-]+)")

// determineTLS attempts to match SSL mode to a known PostgreSQL supported value.
func determineTLS(connStr string) (string, error) {
res := sslmodeRegex.FindStringSubmatch(connStr)
var sslmode string
if len(res) == 2 {
sslmode = res[1]
} else {
sslmode = os.Getenv(envSSLMode)
}

if sslmode == "" {
return "", nil
}

switch sslmode {
case "require", "disable", "allow", "prefer", "verify-ca", "verify-full":
return sslmode, nil
default:
return "", &ErrInvalidSSLMode{given: sslmode}
}
}

// Connect returns a SQLX database corresponding to the provided connection
// string/URL, env variables, and any provided overrides.
func Connect(connStr string, overrides ...Overrideable) (*sqlx.DB, error) {
mcc, err := parseConnStr(connStr, overrides...)
if err != nil {
return nil, fmt.Errorf("could not connect: %v", err)
}
// It is required to connect using pgx/v5
// otherwise it may use pgx/v4 if that librariy is registered first
db, err := sqlx.Connect("pgx/v5", mcc.DSN())
func Connect(connStr string) (*sqlx.DB, error) {
db, err := sqlx.Connect("pgx/v5", connStr)
if err != nil {
return nil, fmt.Errorf("could not connect: %v", err)
}
Expand Down
Loading
Loading