Skip to content

Commit

Permalink
fix(daemon): close listener only once (#615)
Browse files Browse the repository at this point in the history
* fix(daemon): close listener only once

* refactor(daemon): rename Start to ListenAndServe and implement Serve

* fix(daemon): use atomic.Bool for server

* fix(daemon): attempt to fix idle timeout test
  • Loading branch information
aymanbagabas authored Dec 6, 2024
1 parent b450d10 commit 7c45a99
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 26 deletions.
2 changes: 1 addition & 1 deletion cmd/soft/serve/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (s *Server) Start() error {
errg, _ := errgroup.WithContext(s.ctx)
errg.Go(func() error {
s.logger.Print("Starting Git daemon", "addr", s.Config.Git.ListenAddr)
if err := s.GitDaemon.Start(); !errors.Is(err, daemon.ErrServerClosed) {
if err := s.GitDaemon.ListenAndServe(); !errors.Is(err, daemon.ErrServerClosed) {
return err
}
return nil
Expand Down
55 changes: 35 additions & 20 deletions pkg/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/charmbracelet/log"
Expand Down Expand Up @@ -43,7 +44,6 @@ var ErrServerClosed = fmt.Errorf("git: %w", net.ErrClosed)
// GitDaemon represents a Git daemon.
type GitDaemon struct {
ctx context.Context
listener net.Listener
addr string
finished chan struct{}
conns connections
Expand All @@ -52,6 +52,7 @@ type GitDaemon struct {
wg sync.WaitGroup
once sync.Once
logger *log.Logger
done atomic.Bool // indicates if the server has been closed
}

// NewDaemon returns a new Git daemon.
Expand All @@ -70,26 +71,31 @@ func NewGitDaemon(ctx context.Context) (*GitDaemon, error) {
return d, nil
}

// Start starts the Git TCP daemon.
func (d *GitDaemon) Start() error {
// listen on the socket
{
listener, err := net.Listen("tcp", d.addr)
if err != nil {
return err
}
d.listener = listener
// ListenAndServe starts the Git TCP daemon.
func (d *GitDaemon) ListenAndServe() error {
if d.done.Load() {
return ErrServerClosed
}
listener, err := net.Listen("tcp", d.addr)
if err != nil {
return err

Check failure on line 81 in pkg/daemon/daemon.go

View workflow job for this annotation

GitHub Actions / lint-soft

error returned from external package is unwrapped: sig: func net.Listen(network string, address string) (net.Listener, error) (wrapcheck)
}
return d.Serve(listener)
}

// close eventual connections to the socket
defer d.listener.Close() // nolint: errcheck
// Serve listens on the TCP network address and serves Git requests.
func (d *GitDaemon) Serve(listener net.Listener) error {
if d.done.Load() {
return ErrServerClosed
}

d.wg.Add(1)
defer d.wg.Done()
defer listener.Close() //nolint:errcheck

var tempDelay time.Duration
for {
conn, err := d.listener.Accept()
conn, err := listener.Accept()
if err != nil {
select {
case <-d.finished:
Expand Down Expand Up @@ -305,21 +311,30 @@ func (d *GitDaemon) handleClient(conn net.Conn) {

// Close closes the underlying listener.
func (d *GitDaemon) Close() error {
d.once.Do(func() { close(d.finished) })
err := d.listener.Close()
err := d.closeListener()
d.conns.CloseAll() // nolint: errcheck
return err
}

// closeListener closes the listener and the finished channel.
func (d *GitDaemon) closeListener() error {
if d.done.Load() {
return ErrServerClosed
}
d.once.Do(func() {
close(d.finished)
d.done.Store(true)
})
return nil
}

// Shutdown gracefully shuts down the daemon.
func (d *GitDaemon) Shutdown(ctx context.Context) error {
// in the case when git daemon was never started
if d.listener == nil {
return nil
if d.done.Load() {
return ErrServerClosed
}

d.once.Do(func() { close(d.finished) })
err := d.listener.Close()
err := d.closeListener()
finished := make(chan struct{}, 1)
go func() {
d.wg.Wait()
Expand Down
20 changes: 15 additions & 5 deletions pkg/daemon/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestMain(m *testing.M) {
}
testDaemon = d
go func() {
if err := d.Start(); err != ErrServerClosed {
if err := d.ListenAndServe(); err != ErrServerClosed {
log.Fatal(err)
}
}()
Expand All @@ -75,11 +75,21 @@ func TestMain(m *testing.M) {
}

func TestIdleTimeout(t *testing.T) {
c, err := net.Dial("tcp", testDaemon.addr)
if err != nil {
t.Fatal(err)
var err error
var c net.Conn
var tries int
for {
c, err = net.Dial("tcp", testDaemon.addr)
if err != nil && tries >= 3 {
t.Fatal(err)
}
tries++
if testDaemon.conns.Size() != 0 {
break
}
time.Sleep(10 * time.Millisecond)
}
time.Sleep(time.Second)
time.Sleep(2 * time.Second)
_, err = readPktline(c)
if err == nil {
t.Errorf("expected error, got nil")
Expand Down

0 comments on commit 7c45a99

Please sign in to comment.