diff --git a/connector.go b/connector.go index 96a8831..bab81fd 100644 --- a/connector.go +++ b/connector.go @@ -2,6 +2,7 @@ package dbsql import ( "context" + "crypto/tls" "database/sql/driver" "fmt" "net/http" @@ -233,6 +234,20 @@ func WithSessionParams(params map[string]string) connOption { } } +// WithSkipTLSHostVerify disables the verification of the hostname in the TLS certificate. +// WARNING: +// When this option is used, TLS is susceptible to machine-in-the-middle attacks. +// Please only use this option when the hostname is an internal private link hostname +func WithSkipTLSHostVerify() connOption { + return func(c *config.Config) { + if c.TLSConfig == nil { + c.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true} // #nosec G402 + } else { + c.TLSConfig.InsecureSkipVerify = true // #nosec G402 + } + } +} + // WithAuthenticator sets up the Authentication. Mandatory if access token is not provided. func WithAuthenticator(authr auth.Authenticator) connOption { return func(c *config.Config) { diff --git a/connector_test.go b/connector_test.go index f24f767..8e0c7cb 100644 --- a/connector_test.go +++ b/connector_test.go @@ -6,7 +6,9 @@ import ( "time" "github.com/databricks/databricks-sql-go/auth/pat" + "github.com/databricks/databricks-sql-go/internal/client" "github.com/databricks/databricks-sql-go/internal/config" + "github.com/hashicorp/go-retryablehttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -38,6 +40,7 @@ func TestNewConnector(t *testing.T) { WithTransport(roundTripper), WithCloudFetch(true), WithMaxDownloadThreads(15), + WithSkipTLSHostVerify(), ) expectedCloudFetchConfig := config.CloudFetchConfig{ UseCloudFetch: true, @@ -67,6 +70,7 @@ func TestNewConnector(t *testing.T) { expectedCfg := config.WithDefaults() expectedCfg.DriverVersion = DriverVersion expectedCfg.UserConfig = expectedUserConfig + expectedCfg.TLSConfig.InsecureSkipVerify = true coni, ok := con.(*connector) require.True(t, ok) assert.Nil(t, err) @@ -184,6 +188,28 @@ func TestNewConnector(t *testing.T) { } }) + + t.Run("Connector test WithSkipTLSHostVerify with PoolClient", func(t *testing.T) { + hostname := "databricks-host" + con, err := NewConnector( + WithServerHostname(hostname), + WithSkipTLSHostVerify(), + ) + assert.Nil(t, err) + + coni, ok := con.(*connector) + require.True(t, ok) + userConfig := coni.cfg.UserConfig + require.Equal(t, hostname, userConfig.Host) + + httpClient, ok := coni.client.Transport.(*retryablehttp.RoundTripper) + require.True(t, ok) + poolClient, ok := httpClient.Client.HTTPClient.Transport.(*client.Transport) + require.True(t, ok) + internalClient, ok := poolClient.Base.(*http.Transport) + require.True(t, ok) + require.True(t, internalClient.TLSClientConfig.InsecureSkipVerify) + }) } type mockRoundTripper struct{} diff --git a/internal/client/client.go b/internal/client/client.go index fda1053..febab52 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -2,6 +2,7 @@ package client import ( "context" + "crypto/tls" "crypto/x509" "encoding/json" "fmt" @@ -545,7 +546,12 @@ func RetryableClient(cfg *config.Config) *http.Client { return retryableClient.StandardClient() } -func PooledTransport() *http.Transport { +func PooledTransport(cfg *config.Config) *http.Transport { + var tlsConfig *tls.Config + if (cfg.TLSConfig != nil) && cfg.TLSConfig.InsecureSkipVerify { + tlsConfig = cfg.TLSConfig + } + transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ @@ -553,6 +559,7 @@ func PooledTransport() *http.Transport { KeepAlive: 30 * time.Second, DualStack: true, }).DialContext, + TLSClientConfig: tlsConfig, ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 180 * time.Second, @@ -577,7 +584,7 @@ func PooledClient(cfg *config.Config) *http.Client { } } else { tr = &Transport{ - Base: PooledTransport(), + Base: PooledTransport(cfg), Authr: cfg.Authenticator, } }