Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(utils): add goroutine lock #1665

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions internal/utils/goroutine_lock/goroutine_lock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Package goroutinelock implements goroutine locks.
package goroutinelock

import (
"sync"
"sync/atomic"

"github.com/cloudwego/kitex/pkg/klog"
)

// Wg is used to implement goroutine locks.
var Wg = &WaitGroup{}

// WaitGroup defines a wait group with counter and status.
type WaitGroup struct {
sync.WaitGroup
count int64
shutdownStarted uint32
}

// Add adds delta and bumps counter.
func (wg *WaitGroup) Add(delta int) {
atomic.AddInt64(&wg.count, int64(delta))
wg.WaitGroup.Add(delta)
if atomic.LoadUint32(&wg.shutdownStarted) > 0 {
klog.Warn("KITEX: shutdown started but a new goroutine lock is added")
}
}

// Done decrease wait group counter by 1.
func (wg *WaitGroup) Done() {
atomic.AddInt64(&wg.count, -1)
wg.WaitGroup.Done()
if atomic.LoadUint32(&wg.shutdownStarted) > 0 {
count := wg.GetCount()
if count > 0 {
klog.Infof("KITEX: waiting for goroutine locks to be released, remaining %d...", count)
} else {
klog.Info("KITEX: all goroutine locks have been released")
}
}
}

// GetCount gets wait group counter.
func (wg *WaitGroup) GetCount() int {
return int(atomic.LoadInt64(&wg.count))
}

// StartShutdown sets the shutdown status to true.
func (wg *WaitGroup) StartShutdown() {
atomic.AddUint32(&wg.shutdownStarted, 1)
}
159 changes: 159 additions & 0 deletions internal/utils/goroutine_lock/goroutine_lock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package goroutinelock

import (
"sync/atomic"
"testing"
"time"
)

// TestWaitGroup_AddAndDone tests the basic functionality of the Add and Done methods.
func TestWaitGroup_AddAndDone(t *testing.T) {
t.Parallel()

wg := &WaitGroup{}

// Test initial count.
if count := wg.GetCount(); count != 0 {
t.Errorf("Expected initial count to be 0, got %d", count)
}

// Increase count.
wg.Add(1)
if count := wg.GetCount(); count != 1 {
t.Errorf("Expected count after Add(1) to be 1, got %d", count)
}

// Decrease count.
wg.Done()
if count := wg.GetCount(); count != 0 {
t.Errorf("Expected count after Done() to be 0, got %d", count)
}
}

// TestWaitGroup_Concurrency tests the WaitGroup in a concurrent environment.
func TestWaitGroup_Concurrency(t *testing.T) {
t.Parallel()

wg := &WaitGroup{}
goroutineCount := 100
doneCh := make(chan struct{})

// Launch multiple goroutines, each calling Add and Done.
for i := 0; i < goroutineCount; i++ {
wg.Add(1)
go func() {
// Simulate some work.
time.Sleep(10 * time.Millisecond)
wg.Done()
}()
}

// Wait for all goroutines to complete.
go func() {
wg.Wait()
close(doneCh)
}()

<-doneCh

// Verify that the count has returned to zero.
if count := wg.GetCount(); count != 0 {
t.Errorf("Expected count to be 0 after all goroutines done, got %d", count)
}
}

// TestWaitGroup_StartShutdown tests the StartShutdown method.
func TestWaitGroup_StartShutdown(t *testing.T) {
t.Parallel()

wg := &WaitGroup{}

// Start shutdown.
wg.StartShutdown()
if shutdown := atomic.LoadUint32(&wg.shutdownStarted); shutdown != 1 {
t.Errorf("Expected shutdownStarted to be 1, got %d", shutdown)
}

// Attempt to add count after shutdown.
wg.Add(1)
if count := wg.GetCount(); count != 1 {
t.Errorf("Expected count to be 1 after Add(1) post-shutdown, got %d", count)
}

// Should receive warning logs. We can't capture log content here, but we can ensure the code path is executed.

// Complete work.
wg.Done()
if count := wg.GetCount(); count != 0 {
t.Errorf("Expected count to be 0 after Done(), got %d", count)
}
}

// TestWaitGroup_MultipleShutdown tests calling StartShutdown multiple times.
func TestWaitGroup_MultipleShutdown(t *testing.T) {
t.Parallel()

wg := &WaitGroup{}

wg.StartShutdown()
wg.StartShutdown()

if shutdown := atomic.LoadUint32(&wg.shutdownStarted); shutdown != 2 {
t.Errorf("Expected shutdownStarted to be 2 after calling StartShutdown twice, got %d", shutdown)
}
}

// TestWaitGroup_AddAfterShutdown tests adding new counts after shutdown has started.
func TestWaitGroup_AddAfterShutdown(t *testing.T) {
t.Parallel()

wg := &WaitGroup{}

// Start shutdown.
wg.StartShutdown()

// Add count after shutdown.
wg.Add(1)
if count := wg.GetCount(); count != 1 {
t.Errorf("Expected count to be 1 after Add(1) post-shutdown, got %d", count)
}

// Complete work.
wg.Done()
if count := wg.GetCount(); count != 0 {
t.Errorf("Expected count to be 0 after Done(), got %d", count)
}
}

// TestWaitGroup_GetCount tests the accuracy of the GetCount method.
func TestWaitGroup_GetCount(t *testing.T) {
t.Parallel()

wg := &WaitGroup{}

wg.Add(5)
if count := wg.GetCount(); count != 5 {
t.Errorf("Expected count to be 5, got %d", count)
}

wg.Done()
if count := wg.GetCount(); count != 4 {
t.Errorf("Expected count to be 4 after Done(), got %d", count)
}
}
31 changes: 31 additions & 0 deletions pkg/utils/goroutine_lock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright 2021 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package utils

import goroutinelock "github.com/cloudwego/kitex/internal/utils/goroutine_lock"

// GoroutineLock locks the goroutine so that graceful shutdown will wait until the lock is released
func GoroutineLock() {
goroutinelock.Wg.Add(1)
}

// GoroutineUnlock unlocks the goroutine to allow graceful shutdown to continue.
// NOTE: This function should be executed using defer to avoid panic in the middle and causing the lock to not be
// released.
func GoroutineUnlock() {
goroutinelock.Wg.Done()
}
40 changes: 40 additions & 0 deletions pkg/utils/goroutine_lock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2021 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package utils_test

import (
"testing"
"time"

goroutinelock "github.com/cloudwego/kitex/internal/utils/goroutine_lock"
"github.com/cloudwego/kitex/pkg/utils"
)

func TestGoroutineLockAndUnlock(t *testing.T) {
t.Parallel()
startTime := time.Now()
utils.GoroutineLock()
go func() {
time.Sleep(time.Second)
utils.GoroutineUnlock()
}()
goroutinelock.Wg.Wait()
diff := time.Since(startTime)
if diff < time.Second {
t.Errorf("Expect diff >= 1s, get %v", diff)
}
}
8 changes: 8 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/cloudwego/localsession/backup"

internal_server "github.com/cloudwego/kitex/internal/server"
goroutinelock "github.com/cloudwego/kitex/internal/utils/goroutine_lock"
"github.com/cloudwego/kitex/pkg/acl"
"github.com/cloudwego/kitex/pkg/diagnosis"
"github.com/cloudwego/kitex/pkg/discovery"
Expand Down Expand Up @@ -319,6 +320,13 @@ func (s *server) Stop() (err error) {
}
s.svr = nil
}
// Goroutine Locks must wait after all connections are closed, otherwise new connections might be created while
// waiting for goroutine locks.
goroutinelock.Wg.StartShutdown()
if count := goroutinelock.Wg.GetCount(); count > 0 {
klog.Infof("KITEX: waiting for goroutine locks to be released, remaining %d...", count)
goroutinelock.Wg.Wait()
}
})
return
}
Expand Down