diff --git a/internal/utils/goroutine_lock/goroutine_lock.go b/internal/utils/goroutine_lock/goroutine_lock.go new file mode 100644 index 0000000000..8298bf903e --- /dev/null +++ b/internal/utils/goroutine_lock/goroutine_lock.go @@ -0,0 +1,68 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package goroutinelock implements goroutine locks on top of a counted wait group. +package goroutinelock + +import ( + "sync" + "sync/atomic" + + "github.com/cloudwego/kitex/pkg/klog" +) + +// Wg is the package-level WaitGroup used to implement goroutine locks. +var Wg = &WaitGroup{} + +// WaitGroup embeds sync.WaitGroup and adds an atomically updated counter and a shutdown-started marker. +type WaitGroup struct { + sync.WaitGroup + count int64 + shutdownStarted uint32 +} + +// Add adds delta to the counter and to the embedded sync.WaitGroup, then logs a warning if shutdown has already started. +func (wg *WaitGroup) Add(delta int) { + atomic.AddInt64(&wg.count, int64(delta)) + wg.WaitGroup.Add(delta) + if atomic.LoadUint32(&wg.shutdownStarted) > 0 { + klog.Warn("KITEX: shutdown started but a new goroutine lock is added") + } +} + +// Done decrements the counter and the embedded sync.WaitGroup by 1; once shutdown has started it also logs the remaining count. +func (wg *WaitGroup) Done() { + atomic.AddInt64(&wg.count, -1) + wg.WaitGroup.Done() + if atomic.LoadUint32(&wg.shutdownStarted) > 0 { + count := wg.GetCount() + if count > 0 { + klog.Infof("KITEX: waiting for goroutine locks to be released, remaining %d...", count) + } else { + klog.Info("KITEX: all goroutine locks have been released") + } + } +} + +// GetCount returns the current value of the wait group counter. 
+func (wg *WaitGroup) GetCount() int { + return int(atomic.LoadInt64(&wg.count)) +} + +// StartShutdown marks shutdown as started by atomically incrementing shutdownStarted; every call adds 1. +func (wg *WaitGroup) StartShutdown() { + atomic.AddUint32(&wg.shutdownStarted, 1) +} diff --git a/internal/utils/goroutine_lock/goroutine_lock_test.go b/internal/utils/goroutine_lock/goroutine_lock_test.go new file mode 100644 index 0000000000..b1c096828a --- /dev/null +++ b/internal/utils/goroutine_lock/goroutine_lock_test.go @@ -0,0 +1,159 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package goroutinelock + +import ( + "sync/atomic" + "testing" + "time" +) + +// TestWaitGroup_AddAndDone tests the basic functionality of the Add and Done methods. +func TestWaitGroup_AddAndDone(t *testing.T) { + t.Parallel() + + wg := &WaitGroup{} + + // A zero-value WaitGroup starts with a count of 0. + if count := wg.GetCount(); count != 0 { + t.Errorf("Expected initial count to be 0, got %d", count) + } + + // Add(1) raises the count to 1. + wg.Add(1) + if count := wg.GetCount(); count != 1 { + t.Errorf("Expected count after Add(1) to be 1, got %d", count) + } + + // Done() brings the count back to 0. + wg.Done() + if count := wg.GetCount(); count != 0 { + t.Errorf("Expected count after Done() to be 0, got %d", count) + } +} + +// TestWaitGroup_Concurrency tests the WaitGroup in a concurrent environment. 
+func TestWaitGroup_Concurrency(t *testing.T) { + t.Parallel() + + wg := &WaitGroup{} + goroutineCount := 100 + doneCh := make(chan struct{}) + + // Call Add(1) before launching each goroutine; every goroutine calls Done when its work finishes. + for i := 0; i < goroutineCount; i++ { + wg.Add(1) + go func() { + // Simulate some work. + time.Sleep(10 * time.Millisecond) + wg.Done() + }() + } + + // Wait in a helper goroutine and signal completion by closing doneCh. + go func() { + wg.Wait() + close(doneCh) + }() + + <-doneCh + + // Verify that the count has returned to zero. + if count := wg.GetCount(); count != 0 { + t.Errorf("Expected count to be 0 after all goroutines done, got %d", count) + } +} + +// TestWaitGroup_StartShutdown tests the StartShutdown method. +func TestWaitGroup_StartShutdown(t *testing.T) { + t.Parallel() + + wg := &WaitGroup{} + + // Start shutdown. + wg.StartShutdown() + if shutdown := atomic.LoadUint32(&wg.shutdownStarted); shutdown != 1 { + t.Errorf("Expected shutdownStarted to be 1, got %d", shutdown) + } + + // Add still increases the count after shutdown has started. + wg.Add(1) + if count := wg.GetCount(); count != 1 { + t.Errorf("Expected count to be 1 after Add(1) post-shutdown, got %d", count) + } + + // The Add above runs the post-shutdown warning branch. The log content is not captured here; the test only ensures the code path executes. + + // Complete work. + wg.Done() + if count := wg.GetCount(); count != 0 { + t.Errorf("Expected count to be 0 after Done(), got %d", count) + } +} + +// TestWaitGroup_MultipleShutdown verifies that each StartShutdown call increments shutdownStarted. +func TestWaitGroup_MultipleShutdown(t *testing.T) { + t.Parallel() + + wg := &WaitGroup{} + + wg.StartShutdown() + wg.StartShutdown() + + if shutdown := atomic.LoadUint32(&wg.shutdownStarted); shutdown != 2 { + t.Errorf("Expected shutdownStarted to be 2 after calling StartShutdown twice, got %d", shutdown) + } +} + +// TestWaitGroup_AddAfterShutdown tests adding new counts after shutdown has started. 
+func TestWaitGroup_AddAfterShutdown(t *testing.T) { + t.Parallel() + + wg := &WaitGroup{} + + // Start shutdown. + wg.StartShutdown() + + // Add still increases the count after shutdown has started. + wg.Add(1) + if count := wg.GetCount(); count != 1 { + t.Errorf("Expected count to be 1 after Add(1) post-shutdown, got %d", count) + } + + // Complete work. + wg.Done() + if count := wg.GetCount(); count != 0 { + t.Errorf("Expected count to be 0 after Done(), got %d", count) + } +} + +// TestWaitGroup_GetCount checks that GetCount tracks Add(5) followed by a single Done. +func TestWaitGroup_GetCount(t *testing.T) { + t.Parallel() + + wg := &WaitGroup{} + + wg.Add(5) + if count := wg.GetCount(); count != 5 { + t.Errorf("Expected count to be 5, got %d", count) + } + + wg.Done() + if count := wg.GetCount(); count != 4 { + t.Errorf("Expected count to be 4 after Done(), got %d", count) + } +} diff --git a/pkg/utils/goroutine_lock.go b/pkg/utils/goroutine_lock.go new file mode 100644 index 0000000000..bbef680e3b --- /dev/null +++ b/pkg/utils/goroutine_lock.go @@ -0,0 +1,31 @@ +/* + * Copyright 2021 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import goroutinelock "github.com/cloudwego/kitex/internal/utils/goroutine_lock" + +// GoroutineLock adds one goroutine lock to the shared wait group so that graceful shutdown waits until it is released. +func GoroutineLock() { + goroutinelock.Wg.Add(1) +} + +// GoroutineUnlock releases one goroutine lock so that graceful shutdown can continue. 
+// NOTE: This function should be called with defer so that a panic in the middle does not leave the lock +// unreleased. +func GoroutineUnlock() { + goroutinelock.Wg.Done() +} diff --git a/pkg/utils/goroutine_lock_test.go b/pkg/utils/goroutine_lock_test.go new file mode 100644 index 0000000000..8a94cf43b6 --- /dev/null +++ b/pkg/utils/goroutine_lock_test.go @@ -0,0 +1,40 @@ +/* + * Copyright 2021 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils_test + +import ( + "testing" + "time" + + goroutinelock "github.com/cloudwego/kitex/internal/utils/goroutine_lock" + "github.com/cloudwego/kitex/pkg/utils" +) + +func TestGoroutineLockAndUnlock(t *testing.T) { + t.Parallel() + startTime := time.Now() + utils.GoroutineLock() + go func() { + time.Sleep(time.Second) + utils.GoroutineUnlock() + }() + goroutinelock.Wg.Wait() + diff := time.Since(startTime) + if diff < time.Second { + t.Errorf("Expect diff >= 1s, get %v", diff) + } +} diff --git a/server/server.go b/server/server.go index 025e493173..c4cf6b0a06 100644 --- a/server/server.go +++ b/server/server.go @@ -30,6 +30,7 @@ import ( "github.com/cloudwego/localsession/backup" internal_server "github.com/cloudwego/kitex/internal/server" + goroutinelock "github.com/cloudwego/kitex/internal/utils/goroutine_lock" "github.com/cloudwego/kitex/pkg/acl" "github.com/cloudwego/kitex/pkg/diagnosis" "github.com/cloudwego/kitex/pkg/discovery" @@ -319,6 +320,13 @@ func (s *server) 
Stop() (err error) { } s.svr = nil } + // Goroutine Locks must wait after all connections are closed, otherwise new connections might be created while + // waiting for goroutine locks. + goroutinelock.Wg.StartShutdown() + if count := goroutinelock.Wg.GetCount(); count > 0 { + klog.Infof("KITEX: waiting for goroutine locks to be released, remaining %d...", count) + goroutinelock.Wg.Wait() + } }) return }