Skip to content

Commit

Permalink
Add connection option WithSkipTLSHostVerify for privatelink host (#225)
Browse files Browse the repository at this point in the history
## Background
The driver cannot connect to the workspace whose hostname is an internal
private link hostname, this is because its domain is not added into
Databricks workspace certificate. See
#223

## Description
This change adds connection option `WithSkipTLSHostVerify` which can
disable verifying the hostname in the server certificate. In this mode,
TLS is susceptible to machine-in-the-middle attacks. Please only use
this option when the hostname is an internal private link hostname.

Here is the usage
```go
connector, err := dbsql.NewConnector(
   dbsql.WithServerHostname("<hostname>"),
   dbsql.WithHTTPPath("<http_path>"),
   dbsql.WithAccessToken("<token>"),
   dbsql.WithSkipTLSHostVerify(),
)
```
  • Loading branch information
jackyhu-db authored May 28, 2024
2 parents 697ea4f + 34b7340 commit 683e392
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 2 deletions.
15 changes: 15 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dbsql

import (
"context"
"crypto/tls"
"database/sql/driver"
"fmt"
"net/http"
Expand Down Expand Up @@ -233,6 +234,20 @@ func WithSessionParams(params map[string]string) connOption {
}
}

// WithSkipTLSHostVerify disables the verification of the hostname in the TLS certificate.
// WARNING:
// When this option is used, TLS is susceptible to machine-in-the-middle attacks.
// Please only use this option when the hostname is an internal private link hostname
func WithSkipTLSHostVerify() connOption {
return func(c *config.Config) {
if c.TLSConfig == nil {
c.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true} // #nosec G402
} else {
c.TLSConfig.InsecureSkipVerify = true // #nosec G402
}
}
}

// WithAuthenticator sets up the Authentication. Mandatory if access token is not provided.
func WithAuthenticator(authr auth.Authenticator) connOption {
return func(c *config.Config) {
Expand Down
26 changes: 26 additions & 0 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"time"

"github.com/databricks/databricks-sql-go/auth/pat"
"github.com/databricks/databricks-sql-go/internal/client"
"github.com/databricks/databricks-sql-go/internal/config"
"github.com/hashicorp/go-retryablehttp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -38,6 +40,7 @@ func TestNewConnector(t *testing.T) {
WithTransport(roundTripper),
WithCloudFetch(true),
WithMaxDownloadThreads(15),
WithSkipTLSHostVerify(),
)
expectedCloudFetchConfig := config.CloudFetchConfig{
UseCloudFetch: true,
Expand Down Expand Up @@ -67,6 +70,7 @@ func TestNewConnector(t *testing.T) {
expectedCfg := config.WithDefaults()
expectedCfg.DriverVersion = DriverVersion
expectedCfg.UserConfig = expectedUserConfig
expectedCfg.TLSConfig.InsecureSkipVerify = true
coni, ok := con.(*connector)
require.True(t, ok)
assert.Nil(t, err)
Expand Down Expand Up @@ -184,6 +188,28 @@ func TestNewConnector(t *testing.T) {
}

})

t.Run("Connector test WithSkipTLSHostVerify with PoolClient", func(t *testing.T) {
hostname := "databricks-host"
con, err := NewConnector(
WithServerHostname(hostname),
WithSkipTLSHostVerify(),
)
assert.Nil(t, err)

coni, ok := con.(*connector)
require.True(t, ok)
userConfig := coni.cfg.UserConfig
require.Equal(t, hostname, userConfig.Host)

httpClient, ok := coni.client.Transport.(*retryablehttp.RoundTripper)
require.True(t, ok)
poolClient, ok := httpClient.Client.HTTPClient.Transport.(*client.Transport)
require.True(t, ok)
internalClient, ok := poolClient.Base.(*http.Transport)
require.True(t, ok)
require.True(t, internalClient.TLSClientConfig.InsecureSkipVerify)
})
}

type mockRoundTripper struct{}
Expand Down
11 changes: 9 additions & 2 deletions internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package client

import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -545,14 +546,20 @@ func RetryableClient(cfg *config.Config) *http.Client {
return retryableClient.StandardClient()
}

func PooledTransport() *http.Transport {
func PooledTransport(cfg *config.Config) *http.Transport {
var tlsConfig *tls.Config
if (cfg.TLSConfig != nil) && cfg.TLSConfig.InsecureSkipVerify {
tlsConfig = cfg.TLSConfig
}

transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
TLSClientConfig: tlsConfig,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 180 * time.Second,
Expand All @@ -577,7 +584,7 @@ func PooledClient(cfg *config.Config) *http.Client {
}
} else {
tr = &Transport{
Base: PooledTransport(),
Base: PooledTransport(cfg),
Authr: cfg.Authenticator,
}
}
Expand Down

0 comments on commit 683e392

Please sign in to comment.