Skip to content

Commit

Permalink
Implement shutdown callback
Browse files Browse the repository at this point in the history
  • Loading branch information
crobert-1 committed Feb 7, 2024
1 parent 599018f commit f805a97
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 35 deletions.
10 changes: 7 additions & 3 deletions .chloggen/goleak_configtls.yaml
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
# Use this changelog template to create an entry for release notes.

# One of 'breaking', 'deprecation', 'new_component', 'enhancement', 'bug_fix'
change_type: bug_fix
change_type: breaking

# The name of the component, or a single word describing the area of concern, (e.g. otlpreceiver)
component: config/configtls

# A brief description of the change. Surround your text with quotes ("") if it needs to start with a backtick (`).
note: Add Shutdown call to TLSServerSetting to fix leaking goroutine
note: Add shutdown call back to TLSServerSetting to fix memory leaks

# One or more tracking issues or pull requests related to the change
issues: [9165]

# (Optional) One or more lines of additional information to render under the primary note.
# These lines will be padded with 2 spaces and then inserted directly into the document.
# Use pipe (|) for multiline entries.
subtext:
subtext: |
The TLSServerSetting.LoadTLSConfig method signature has been modified to now return an extra
value, a function callback. This function callback must be used to ensure memory isn't leaked
on shutdown. Callers to TLSServerSetting.LoadTLSConfig should store every returned callback
and call them as soon as the relevant TLS config is no longer needed.
# Optional: The change log or logs in which this entry should be included.
# e.g. '[user]' or '[user, api]'
Expand Down
2 changes: 1 addition & 1 deletion config/configtls/clientcasfilereloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func (r *clientCAsFileReloader) handleWatcherEvents() {

func (r *clientCAsFileReloader) shutdown() error {
if r.shutdownCH == nil {
return fmt.Errorf("client CAs file watcher is not running")
return nil
}
r.shutdownCH <- true
close(r.shutdownCH)
Expand Down
4 changes: 2 additions & 2 deletions config/configtls/clientcasfilereloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ import (
"github.com/stretchr/testify/assert"
)

func TestCannotShutdownIfNotWatching(t *testing.T) {
func TestCanShutdownIfNotWatching(t *testing.T) {
reloader, _, _ := createReloader(t)
err := reloader.shutdown()
assert.Error(t, err)
assert.NoError(t, err)
}

func TestCannotStartIfAlreadyWatching(t *testing.T) {
Expand Down
36 changes: 19 additions & 17 deletions config/configtls/configtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,6 @@ type TLSServerSetting struct {
// Reload the ClientCAs file when it is modified
// (optional, default false)
ReloadClientCAFile bool `mapstructure:"client_ca_file_reload"`

// File reloader for the Client CA.
reloader *clientCAsFileReloader
}

// certReloader is a wrapper object for certificate reloading
Expand Down Expand Up @@ -329,39 +326,44 @@ func (c TLSClientSetting) LoadTLSConfig() (*tls.Config, error) {
return tlsCfg, nil
}

// LoadTLSConfig loads the TLS configuration.
func (c *TLSServerSetting) LoadTLSConfig() (*tls.Config, error) {
// LoadTLSConfig loads the TLS configuration. The returned function is a callback that should
// be used to signal shutdown.
func (c TLSServerSetting) LoadTLSConfig() (*tls.Config, func() error, error) {
tlsCfg, err := c.loadTLSConfig()
nopShutdown := func() error { return nil }
var reloader *clientCAsFileReloader

if err != nil {
return nil, fmt.Errorf("failed to load TLS config: %w", err)
return nil, nopShutdown, fmt.Errorf("failed to load TLS config: %w", err)
}
if c.ClientCAFile != "" {
var err error
c.reloader, err = newClientCAsReloader(c.ClientCAFile, c)
reloader, err = newClientCAsReloader(c.ClientCAFile, &c)
if err != nil {
return nil, err
return nil, nopShutdown, err
}
if c.ReloadClientCAFile {
err = c.reloader.startWatching()
err = reloader.startWatching()
if err != nil {
return nil, err
return nil, nopShutdown, err
}
tlsCfg.GetConfigForClient = func(t *tls.ClientHelloInfo) (*tls.Config, error) { return c.reloader.getClientConfig(tlsCfg) }
tlsCfg.GetConfigForClient = func(t *tls.ClientHelloInfo) (*tls.Config, error) { return reloader.getClientConfig(tlsCfg) }
}
tlsCfg.ClientCAs = c.reloader.certPool
tlsCfg.ClientCAs = reloader.certPool
tlsCfg.ClientAuth = tls.RequireAndVerifyClientCert
}
return tlsCfg, nil

if reloader != nil {
return tlsCfg, reloader.shutdown, nil
}

return tlsCfg, nopShutdown, nil
}

func (c TLSServerSetting) loadClientCAFile() (*x509.CertPool, error) {
return c.loadCert(c.ClientCAFile)
}

func (c TLSServerSetting) Shutdown() error {
if c.ReloadClientCAFile {
return c.reloader.shutdown()
}
return nil
}

Expand Down
60 changes: 48 additions & 12 deletions config/configtls/configtls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,22 +282,25 @@ func TestLoadTLSServerConfigError(t *testing.T) {
KeyFile: "doesnt/exist",
},
}
_, err := tlsSetting.LoadTLSConfig()
_, _, err := tlsSetting.LoadTLSConfig()
assert.Error(t, err)

tlsSetting = TLSServerSetting{
ClientCAFile: "doesnt/exist",
}
_, err = tlsSetting.LoadTLSConfig()
_, _, err = tlsSetting.LoadTLSConfig()
assert.Error(t, err)
}

func TestLoadTLSServerConfig(t *testing.T) {
tlsSetting := TLSServerSetting{}
tlsCfg, err := tlsSetting.LoadTLSConfig()
tlsCfg, shutdown, err := tlsSetting.LoadTLSConfig()
assert.NoError(t, err)
assert.NotNil(t, tlsCfg)
defer func() { assert.NoError(t, tlsSetting.Shutdown()) }()
defer func() {
shutdown()
assert.NoError(t, tlsSetting.Shutdown())
}()
}

func TestLoadTLSServerConfigReload(t *testing.T) {
Expand All @@ -311,10 +314,13 @@ func TestLoadTLSServerConfigReload(t *testing.T) {
ReloadClientCAFile: true,
}

tlsCfg, err := tlsSetting.LoadTLSConfig()
tlsCfg, shutdown, err := tlsSetting.LoadTLSConfig()
assert.NoError(t, err)
assert.NotNil(t, tlsCfg)
defer func() { assert.NoError(t, tlsSetting.Shutdown()) }()
defer func() {
shutdown()
assert.NoError(t, tlsSetting.Shutdown())
}()

firstClient, err := tlsCfg.GetConfigForClient(nil)
assert.NoError(t, err)
Expand All @@ -332,6 +338,30 @@ func TestLoadTLSServerConfigReload(t *testing.T) {
assert.NotEqual(t, firstClient.ClientCAs, secondClient.ClientCAs)
}

func TestLoadTLSServerMultipleConfigs(t *testing.T) {
tmpCaPath := createTempClientCaFile(t)

overwriteClientCA(t, tmpCaPath, "ca-1.crt")

tlsSetting := TLSServerSetting{
ClientCAFile: tmpCaPath,
ReloadClientCAFile: true,
}

var allShutdowns []func() error

for i := 0; i < 10; i++ {
tlsCfg, shutdown, err := tlsSetting.LoadTLSConfig()
assert.NoError(t, err)
assert.NotNil(t, tlsCfg)
allShutdowns = append(allShutdowns, shutdown)
}

for _, shutdown := range allShutdowns {
assert.NoError(t, shutdown())
}
}

func TestLoadTLSServerConfigFailingReload(t *testing.T) {

tmpCaPath := createTempClientCaFile(t)
Expand All @@ -343,10 +373,13 @@ func TestLoadTLSServerConfigFailingReload(t *testing.T) {
ReloadClientCAFile: true,
}

tlsCfg, err := tlsSetting.LoadTLSConfig()
tlsCfg, shutdown, err := tlsSetting.LoadTLSConfig()
assert.NoError(t, err)
assert.NotNil(t, tlsCfg)
defer func() { assert.NoError(t, tlsSetting.Shutdown()) }()
defer func() {
shutdown()
assert.NoError(t, tlsSetting.Shutdown())
}()

firstClient, err := tlsCfg.GetConfigForClient(nil)
assert.NoError(t, err)
Expand Down Expand Up @@ -375,7 +408,7 @@ func TestLoadTLSServerConfigFailingInitialLoad(t *testing.T) {
ReloadClientCAFile: true,
}

tlsCfg, err := tlsSetting.LoadTLSConfig()
tlsCfg, _, err := tlsSetting.LoadTLSConfig()
assert.Error(t, err)
assert.Nil(t, tlsCfg)
}
Expand All @@ -389,7 +422,7 @@ func TestLoadTLSServerConfigWrongPath(t *testing.T) {
ReloadClientCAFile: true,
}

tlsCfg, err := tlsSetting.LoadTLSConfig()
tlsCfg, _, err := tlsSetting.LoadTLSConfig()
assert.Error(t, err)
assert.Nil(t, tlsCfg)
}
Expand All @@ -405,10 +438,13 @@ func TestLoadTLSServerConfigFailing(t *testing.T) {
ReloadClientCAFile: true,
}

tlsCfg, err := tlsSetting.LoadTLSConfig()
tlsCfg, shutdown, err := tlsSetting.LoadTLSConfig()
assert.NoError(t, err)
assert.NotNil(t, tlsCfg)
defer func() { assert.NoError(t, tlsSetting.Shutdown()) }()
defer func() {
shutdown()
assert.NoError(t, tlsSetting.Shutdown())
}()

firstClient, err := tlsCfg.GetConfigForClient(nil)
assert.NoError(t, err)
Expand Down

0 comments on commit f805a97

Please sign in to comment.